# classifier_plugin.py — ClassifierPlugin: runs a wrapped classifier over Pod items
# and writes CategoricalPrediction items (linked to a Model item) back to the Pod.
import json
from typing import Any, Dict, List
from pymemri.data.itembase import Edge, Item
from pymemri.data.schema import CategoricalPrediction, EmailMessage, Message
from pymemri.data.schema import Model as ModelItem
from pymemri.plugin.pluginbase import PluginBase
from .model import Model
from .utils import item_to_data
class ClassifierPlugin(PluginBase):
    """Plugin that wraps an arbitrary classifier and handles all communication
    with the Pod, including conversion from/to ``Item`` objects.

    Workflow (see :meth:`run`): load the model, paginate over items matching
    ``self.query``, predict per batch, and sync predictions back to the Pod.
    """

    schema_classes = [Message, CategoricalPrediction, ModelItem]

    def __init__(
        self,
        item_type: str = "Message",
        item_service: str = None,
        model_name: str = "test_abcd123_model",
        model_version: str = "0.1",
        isMock: bool = True,
        **kwargs,
    ):
        """
        ClassifierPlugin is a plugin that wraps any classifier and handles all
        communication with the Pod and conversion from/to `Item`s.

        Args:
            item_type (str, optional): The Item type this plugin should make predictions on. Defaults to "Message".
            item_service (str, optional): The service of Items this plugin should make predictions on. Defaults to None.
            model_name (str, optional): Name of the model the plugin should use. Defaults to "test_abcd123_model".
            model_version (str, optional): Version of the model the plugin should use. Defaults to "0.1".
            isMock (bool, optional): If True, restrict the Pod query to mock items. Defaults to True.
        """
        super().__init__(**kwargs)
        self.batch_size = 512
        self.model_name = model_name
        self.model_version = model_version
        # Set by `load_model`. Must default to None: `sync_to_pod` reads it
        # unconditionally, and `load_model` returns early (without assigning)
        # when `model_name` is None.
        self.model_item = None
        self.query = {"type": item_type}
        if item_service is not None:
            self.query["service"] = item_service
        if isMock is True:
            self.query["isMock"] = True

    def run(self):
        """Run `self.model` on all data in `self.client.search(self.query)`"""
        print("Loading model...")
        self.load_model()
        print("Start predicting...")
        for i, item_batch in enumerate(
            self.client.search_paginate(self.query, limit=self.batch_size)
        ):
            print(f"Predicting batch {i:<4}")
            item_batch = self.filter_items(item_batch)
            prepared_batch = self.prepare_batch(item_batch)
            predictions = self.model.predict(prepared_batch)
            prediction_items = [self.prediction_to_item(p) for p in predictions]
            self.sync_to_pod(item_batch, prediction_items)
        print("Done")

    def load_model(self):
        """Instantiate `self.model` and find or create the matching `Model`
        item in the Pod (skipped when `self.model_name` is None)."""
        self.model = Model(client=self.client)

        # Without model name, do not create a model item
        if self.model_name is None:
            return

        # Search in pod for existing models with same name and version,
        # add a new model if it does not exist.
        model_items = self.client.search(
            {"type": "Model", "name": self.model_name, "version": self.model_version}
        )
        if model_items:
            self.model_item = model_items[0]
        else:
            self.model_item = ModelItem(
                name=self.model_name, version=self.model_version
            )
            self.client.create(self.model_item)

    def filter_items(self, items: List[Item]) -> List[Item]:
        """Keep only (Email)Message items with non-empty content.

        Raises:
            NotImplementedError: If an item is not a Message or EmailMessage.
        """
        result = []
        for item in items:
            if not isinstance(item, (EmailMessage, Message)):
                raise NotImplementedError()
            if item.content:
                result.append(item)
        return result

    def prepare_batch(self, batch: List[Item]) -> List[Any]:
        """Prepare a list of items for the model. See `utils.item_to_data` for more information.

        Args:
            batch (List[Item]): List of Items from the Pod.

        Returns:
            List[Any]: List of prepared data.
        """
        return [item_to_data(item, self.client) for item in batch]

    def prediction_to_item(
        self, prediction: List[Dict[str, Any]]
    ) -> CategoricalPrediction:
        """Converts a prediction returned by self.model to a CategoricalPrediction that can be added to the Pod.

        Args:
            prediction (List[Dict[str, float]]): List of predictions. For the correct format, see `.model.Model`.

        Returns:
            CategoricalPrediction: Formatted `CategoricalPrediction`.
        """
        # Get the label with the highest score
        max_label = max(prediction, key=lambda p: p["score"])["label"]
        return CategoricalPrediction(
            probs=json.dumps(prediction),
            value=max_label,
            source=f"{self.model.name}:{self.model.version}",
        )

    def sync_to_pod(self, items: List[Item], predictions: List[CategoricalPrediction]):
        """For each (item, prediction) pair, add the prediction to the Pod and
        create an edge called 'label' between Item and Prediction.

        Args:
            items (List[Item]): Items the predictions were made on.
            predictions (List[CategoricalPrediction]): One prediction per item, in order.
        """
        # Create edges between item and predictions
        edges = [
            Edge(item, prediction, "label")
            for item, prediction in zip(items, predictions)
        ]

        # Create edges between predictions and model
        if self.model_item is not None:
            edges += [
                Edge(prediction, self.model_item, "model") for prediction in predictions
            ]

        self.client.bulk_action(create_items=predictions, create_edges=edges)