# plugin.py
import json
from typing import Any, Dict, List
from pymemri.data.itembase import Edge, Item
from pymemri.data.schema import CategoricalPrediction, EmailMessage, Message
from pymemri.data.schema import Model as ModelItem
from pymemri.plugin.pluginbase import PluginBase
from .model import Model
from .utils import item_to_data
class ClassifierPlugin(PluginBase):
    """Plugin that wraps a classifier model, runs it over Pod items, and
    writes the resulting `CategoricalPrediction` items (plus edges) back."""

    schema_classes = [Message, CategoricalPrediction, ModelItem]

    def __init__(
        self,
        item_type: str = "Message",
        item_service: str = None,
        model_name: str = "test_abcd123_model",
        model_version: str = "0.1",
        isMock: bool = True,
        **kwargs,
    ):
        """
        ClassifierPlugin is a plugin that wraps any classifier and handles all
        communication with the Pod and conversion from/to `Item`s.

        Args:
            item_type (str, optional): The Item type this plugin should make predictions on. Defaults to "Message".
            item_service (str, optional): The service of Items this plugin should make predictions on. Defaults to None.
            model_name (str, optional): Name of the model the plugin should use. Defaults to "test_abcd123_model".
            model_version (str, optional): Version of the model the plugin should use. Defaults to "0.1".
            isMock (bool, optional): If True, restricts the query to mock items. Defaults to True.
        """
        super().__init__(**kwargs)
        self.batch_size = 512
        self.model_name = model_name
        self.model_version = model_version
        # Initialized here so sync_to_pod never hits an AttributeError when
        # load_model skips model-item creation (model_name is None).
        self.model_item = None
        self.query = {"type": item_type}
        if item_service is not None:
            self.query["service"] = item_service
        if isMock:
            self.query["isMock"] = True

    def run(self):
        """Run `self.model` on all data in `self.client.search(self.query)`"""
        print("Loading model...")
        self.load_model()
        print("Start predicting...")
        for i, item_batch in enumerate(
            self.client.search_paginate(self.query, limit=self.batch_size)
        ):
            print(f"Predicting batch {i:<4}")
            item_batch = self.filter_items(item_batch)
            prepared_batch = self.prepare_batch(item_batch)
            predictions = self.model.predict(prepared_batch)
            prediction_items = [self.prediction_to_item(p) for p in predictions]
            self.sync_to_pod(item_batch, prediction_items)
        print("Done")

    def load_model(self):
        """Initialize `self.model`, and find or create the corresponding
        `Model` item in the Pod (stored on `self.model_item`)."""
        self.model = Model(client=self.client)
        self.model_item = None
        # Without model name, do not create a model item
        if self.model_name is None:
            return
        # Search in pod for existing models with same name and version,
        # add a new model if it does not exist.
        model_items = self.client.search(
            {"type": "Model", "name": self.model_name, "version": self.model_version}
        )
        if model_items:
            self.model_item = model_items[0]
        else:
            self.model_item = ModelItem(
                name=self.model_name, version=self.model_version
            )
            self.client.create(self.model_item)

    def filter_items(self, items: List[Item]) -> List[Item]:
        """Keep only Message/EmailMessage items with non-empty content.

        Raises:
            NotImplementedError: If an item is not a Message or EmailMessage.
        """
        result = []
        for item in items:
            if not isinstance(item, (EmailMessage, Message)):
                raise NotImplementedError()
            if item.content:
                result.append(item)
        return result

    def prepare_batch(self, batch: List[Item]) -> List[Any]:
        """Prepare a list of items for the model. See `utils.item_to_data` for more information.

        Args:
            batch (List[Item]): List of Items from the Pod.

        Returns:
            List[Any]: List of prepared data.
        """
        return [item_to_data(item, self.client) for item in batch]

    def prediction_to_item(
        self, prediction: List[Dict[str, Any]]
    ) -> CategoricalPrediction:
        """Converts a prediction returned by self.model to a CategoricalPrediction
        that can be added to the Pod.

        Args:
            prediction (List[Dict[str, float]]): List of predictions. For the
                correct format, see `.model.Model`.

        Returns:
            CategoricalPrediction: Formatted `CategoricalPrediction`.
        """
        # Get the label with the highest score
        max_label = max(prediction, key=lambda p: p["score"])["label"]
        return CategoricalPrediction(
            probs=json.dumps(prediction),
            value=max_label,
            source=f"{self.model.name}:{self.model.version}",
        )

    def sync_to_pod(self, items: List[Item], predictions: List[CategoricalPrediction]):
        """For each (item, prediction) pair, add the prediction to the Pod and
        create a 'label' edge between the Item and its Prediction.

        Args:
            items (List[Item]): Items the predictions were made on.
            predictions (List[CategoricalPrediction]): One prediction per item.
        """
        # Create edges between item and predictions
        edges = [
            Edge(item, prediction, "label")
            for item, prediction in zip(items, predictions)
        ]
        # Create edges between predictions and model
        if self.model_item is not None:
            edges += [
                Edge(prediction, self.model_item, "model")
                for prediction in predictions
            ]
        self.client.bulk_action(create_items=predictions, create_edges=edges)