import json
import time
from typing import Any, Dict, List, Optional
from loguru import logger
from pymemri.data.schema import CategoricalPrediction, Edge, EmailMessage, Item, Message
from pymemri.data.schema import Model as ModelItem
from pymemri.data.schema import Trigger, Tweet
from pymemri.plugin.trigger_plugin_base import TriggerPluginBase
from pymemri.webserver.models.trigger import TriggerReq
from .model import Model
from .utils import batch_list, item_to_data
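
# Triggered items are processed in batches of this size. `utils.batch_list` is
# assumed to split a list into consecutive chunks, e.g. 40 items -> 16, 16, 8.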
TRIGGER_BATCH_SIZE = 16


class ClassifierPlugin(TriggerPluginBase):
schema_classes = [Message, CategoricalPrediction, ModelItem, Tweet, Trigger]

    def __init__(
self,
item_type: str = "Message",
        item_service: Optional[str] = None,
model_name: str = "test_release_3_model",
model_version: str = "0.1",
isMock: bool = True,
**kwargs,
):
"""
ClassifierPlugin is a plugin that wraps any classifier and handles all communication with the Pod and conversion from/to `Item`s
Args:
item_type (str, optional): The Item type this plugin should make predictions on. Defaults to "Message".
item_service (str, optional): The service of Items this plugin should make predictions on. Defaults to None.
model_name (str, optional): Name of the model the plugin should use. Defaults to None.
model_version (str, optional): Version of the model the plugin should use. Defaults to "0.1".
"""
super().__init__(**kwargs)
self.model_name = model_name
self.model_version = model_version
self.item_type = item_type
self.query = {"type": item_type}
if item_service is not None:
self.query["service"] = item_service
if isMock is True:
self.query["isMock"] = True

    def run(self):
logger.info("Loading model...")
self.load_model()
self.set_run_status("daemon")
logger.info("Done setting up.")

    def predict(self, item_batch: List[Item]):
item_batch = self.filter_items(item_batch)
prepared_batch = self.prepare_batch(item_batch)
predictions = self.model.predict(prepared_batch)
prediction_items = [self.prediction_to_item(p) for p in predictions]
self.sync_to_pod(item_batch, prediction_items)

    def load_model(self):
self.model = Model(client=self.client)
        # Without a model name, do not create a model item
        if self.model_name is None:
            self.model_item = None
            return
# Search in pod for existing models with same name and version,
# add a new model if it does not exist.
model_items = self.client.search(
{"type": "Model", "name": self.model_name, "version": self.model_version}
)
if model_items:
self.model_item = model_items[0]
else:
self.model_item = ModelItem(
name=self.model_name, version=self.model_version
)
self.client.create(self.model_item)

    def filter_items(self, items: List[Item]) -> List[Item]:
result = []
for item in items:
if item_to_data(item, self.client) is not None:
result.append(item)
return result

    def prepare_batch(self, batch: List[Item]) -> List[Any]:
        """Prepare a list of items for the model. See `utils.item_to_data` for more information.

Args:
batch (List[Item]): List of Items from the Pod.
Returns:
List[Any]: List of prepared data.
"""
return [item_to_data(item, self.client) for item in batch]

    def prediction_to_item(
        self, prediction: List[Dict[str, Any]]
    ) -> CategoricalPrediction:
        """Converts a prediction returned by `self.model` to a `CategoricalPrediction` that can be added to the Pod.

        Args:
            prediction (List[Dict[str, Any]]): A single prediction, as a list of label/score dicts. For the exact format, see `.model.Model`.

        Returns:
            CategoricalPrediction: The formatted `CategoricalPrediction`.
        """
# Get the label with the highest score
max_label = max(prediction, key=lambda p: p["score"])["label"]
return CategoricalPrediction(
probs=json.dumps(prediction),
value=max_label,
source=f"{self.model.name}:{self.model.version}",
)

    def sync_to_pod(self, items: List[Item], predictions: List[CategoricalPrediction]):
        """For each (item, prediction) pair, add the prediction to the Pod and create an edge named "label" between the Item and its prediction.

        Args:
            items (List[Item]): The Items that were classified.
            predictions (List[CategoricalPrediction]): The corresponding predictions, in the same order as `items`.
        """
# Create edges between item and predictions
edges = [
Edge(item, prediction, "label")
for item, prediction in zip(items, predictions)
]
# Create edges between predictions and model
if self.model_item is not None:
edges += [
Edge(prediction, self.model_item, "model") for prediction in predictions
]
self.client.bulk_action(create_items=predictions, create_edges=edges)

    def _model_has_loaded(self) -> bool:
        return hasattr(self, "model")

    def _wait_for_model_has_loaded(self) -> None:
while not self._model_has_loaded():
time.sleep(1)

    def trigger(self, req: TriggerReq):
if len(req.item_ids) == 0:
return
items = self.client.search({"ids": req.item_ids})
self._wait_for_model_has_loaded()
if len(items) < len(req.item_ids):
logger.warning("Could not find all `item_ids` in Pod from received trigger.")
for item_batch in batch_list(items, TRIGGER_BATCH_SIZE):
self.predict(item_batch)
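

# Example usage (sketch): running the plugin against a local Pod. `PodClient`,
# passing `client` through to the plugin base class, and `add_to_schema` follow
# common pymemri conventions, but are assumptions, not part of this module.
if __name__ == "__main__":
    from pymemri.pod.client import PodClient  # assumed import path

    client = PodClient()  # connects to the default local Pod
    plugin = ClassifierPlugin(client=client, item_type="Message", isMock=True)
    plugin.add_to_schema()  # assumed base-class helper that registers schema_classes
    plugin.run()  # loads the model and switches the run status to "daemon"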