import json
import time
from typing import Any, Dict, List, Optional
from loguru import logger
from pymemri.data.schema import CategoricalPrediction, Edge, EmailMessage, Item, Message
from pymemri.data.schema import Model as ModelItem
from pymemri.data.schema import Trigger, Tweet
from pymemri.plugin.trigger_plugin_base import TriggerPluginBase
from pymemri.webserver.models.trigger import TriggerReq
from .model import Model
from .utils import batch_list, item_to_data
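# Number of items sent to the model per batch when a trigger fires
# (see `trigger`, which splits incoming item ids with `batch_list`).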
TRIGGER_BATCH_SIZE = 16
class ClassifierPlugin(TriggerPluginBase):
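    # Schema classes this plugin registers with the Pod before running.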
    schema_classes = [Message, CategoricalPrediction, ModelItem, Tweet, Trigger]
    def __init__(
        self,
        item_type: str = "Message",
        item_service: Optional[str] = None,
        model_name: str = "test_release_3_model",
        model_version: str = "0.1",
        isMock: bool = True,
        **kwargs,
    ):
        """
        ClassifierPlugin wraps any classifier and handles all communication
        with the Pod, including the conversion from and to `Item`s.

        Args:
            item_type (str, optional): The Item type this plugin makes predictions on. Defaults to "Message".
            item_service (str, optional): The service of the Items this plugin makes predictions on. Defaults to None.
            model_name (str, optional): Name of the model the plugin should use. Defaults to "test_release_3_model".
            model_version (str, optional): Version of the model the plugin should use. Defaults to "0.1".
            isMock (bool, optional): If True, restrict the query to mock Items. Defaults to True.
        """
        super().__init__(**kwargs)
        self.model_name = model_name
        self.model_version = model_version
        self.item_type = item_type
        self.query = {"type": item_type}
        if item_service is not None:
            self.query["service"] = item_service
        if isMock:
            self.query["isMock"] = True
        # Set by `load_model` when a model name is configured; `sync_to_pod`
        # checks it before linking predictions to the model item.
        self.model_item: Optional[ModelItem] = None
    def run(self):
        logger.info("Loading model...")
        self.load_model()
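        # Report "daemon" status so the Pod keeps this run alive to receive triggers.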
        self.set_run_status("daemon")
        logger.info("Done setting up.")
    def predict(self, item_batch: List[Item]):
        """Run the classifier on a batch of Items and write the resulting
        predictions back to the Pod."""
        item_batch = self.filter_items(item_batch)
        prepared_batch = self.prepare_batch(item_batch)
        predictions = self.model.predict(prepared_batch)
        prediction_items = [self.prediction_to_item(p) for p in predictions]
        self.sync_to_pod(item_batch, prediction_items)
    def load_model(self):
        self.model = Model(client=self.client)
        # Without a model name, do not create a model item
        if self.model_name is None:
            return

        # Search the Pod for existing models with the same name and version,
        # add a new model if it does not exist.
        model_items = self.client.search(
            {"type": "Model", "name": self.model_name, "version": self.model_version}
        )
        if model_items:
            self.model_item = model_items[0]
        else:
            self.model_item = ModelItem(name=self.model_name, version=self.model_version)
            self.client.create(self.model_item)

    def filter_items(self, items: List[Item]) -> List[Item]:
        """Filter out Items that cannot be converted to model inputs."""
        result = []
        for item in items:
            if item_to_data(item, self.client) is not None:
                result.append(item)
        return result

    def prepare_batch(self, batch: List[Item]) -> List[Any]:
        """Prepare a list of items for the model. See `utils.item_to_data` for
        more information.

        Args:
            batch (List[Item]): List of Items from the Pod.

        Returns:
            List[Any]: List of prepared data.
        """
        return [item_to_data(item, self.client) for item in batch]

    def prediction_to_item(
        self, prediction: List[Dict[str, Any]]
    ) -> CategoricalPrediction:
        """Converts a prediction returned by `self.model` to a
        `CategoricalPrediction` that can be added to the Pod.

        Args:
            prediction (List[Dict[str, Any]]): List of label/score dicts,
                e.g. `[{"label": ..., "score": ...}, ...]`. For the exact
                format, see `.model.Model`.

        Returns:
            CategoricalPrediction: The formatted `CategoricalPrediction`.
        """
        # Get the label with the highest score
        max_label = max(prediction, key=lambda p: p["score"])["label"]
        return CategoricalPrediction(
            probs=json.dumps(prediction),
            value=max_label,
            source=f"{self.model.name}:{self.model.version}",
        )

    def sync_to_pod(self, items: List[Item], predictions: List[CategoricalPrediction]):
        """For each (item, prediction) pair, add the prediction to the Pod and
        create an edge named 'label' between the Item and the prediction.

        Args:
            items (List[Item]): The Items that were classified.
            predictions (List[CategoricalPrediction]): The corresponding predictions.
        """
        # Create edges between items and predictions
        edges = [
            Edge(item, prediction, "label")
            for item, prediction in zip(items, predictions)
        ]

        # Create edges between predictions and the model
        if self.model_item is not None:
            edges += [
                Edge(prediction, self.model_item, "model")
                for prediction in predictions
            ]

        self.client.bulk_action(create_items=predictions, create_edges=edges)

    def _model_has_loaded(self) -> bool:
        return hasattr(self, "model")

    def _wait_for_model_has_loaded(self) -> None:
        while not self._model_has_loaded():
            time.sleep(1)

    def trigger(self, req: TriggerReq):
        if len(req.item_ids) == 0:
            return

        items = self.client.search({"ids": req.item_ids})
        self._wait_for_model_has_loaded()
        if len(items) < len(req.item_ids):
            logger.warning("Could not find all `item_ids` in Pod from received trigger.")

        for item_batch in batch_list(items, TRIGGER_BATCH_SIZE):
            self.predict(item_batch)
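
# Illustrative usage sketch (not part of the plugin). In production the Pod
# launches the plugin, but for local testing it can be constructed directly.
# This assumes a Pod reachable with default settings and that
# `TriggerPluginBase` accepts the same `client` keyword as
# `pymemri.plugin.pluginbase.PluginBase`; adjust to your deployment.
#
#     from pymemri.pod.client import PodClient
#
#     client = PodClient(owner_key="...", database_key="...")
#     plugin = ClassifierPlugin(item_type="Message", client=client)
#     plugin.run()  # loads the model and reports the "daemon" run status
#
#     # The Pod later calls `plugin.trigger(req)` with the ids of newly
#     # created items, which are classified in batches of TRIGGER_BATCH_SIZE.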