# plugin.py (4.29 KiB)
import pathlib
from typing import Any, List
import json
import numpy as np
from pymemri.data.itembase import Edge, Item
from pymemri.plugin.pluginbase import PluginBase
from pymemri.pod.client import PodClient
from pymemri.data.schema import Message
from transformers import TextClassificationPipeline
from transformers import pipeline as AutoPipeline
import torch
from .schema import CategoricalPrediction, CategoricalLabel
import wandb
def get_predictions(item):
    """Collect the predictions attached to an item.

    Args:
        item (Item): Pod item whose ``label`` entries are inspected.

    Returns:
        List[CategoricalPrediction]: The entries of ``item.label`` that are
        ``CategoricalPrediction`` instances, in their original order.
    """

    def is_prediction(label):
        return isinstance(label, CategoricalPrediction)

    return list(filter(is_prediction, item.label))
class SentimentAnalysis(PluginBase):
    """Pod plugin that classifies the sentiment of text items and stores the results.

    Items of type ``item_type`` are loaded from the pod, and the text in their
    ``content_field`` attribute is run through a Hugging Face
    ``sentiment-analysis`` pipeline. Each result is written back as a
    ``CategoricalPrediction`` linked to its source item by a ``label`` edge.
    """

    # Settings recognised in the plugin run's JSON ``settings``. Any key that is
    # not provided falls back to the value here (see ``set_config``).
    _default_config = {
        # Attribute of each item that holds the text to classify.
        "content_field": "content",
        # Pod item type to query.
        "item_type": "Message",
        # Model / tokenizer identifier passed to the transformers pipeline.
        "model_name": "cardiffnlp/twitter-xlm-roberta-base-sentiment",
        # W&B artifact holding a replacement classifier head; None keeps the model's own head.
        "model_head": "eelcovdw/memri_sentiment/model_head:latest",
    }

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Run-specific settings arrive as a JSON string on the plugin run.
        if self.pluginRun and self.pluginRun.settings:
            config = json.loads(self.pluginRun.settings)
        else:
            config = dict()
        self.set_config(config)

    def set_config(self, config):
        """Set every key of ``_default_config`` as an attribute on ``self``.

        Args:
            config (dict): Overrides. Keys missing from it take their default
                value; keys not present in ``_default_config`` are ignored.
        """
        for item, default_value in self._default_config.items():
            setattr(self, item, config.get(item, default_value))

    def load_data(self, filter_predicted=True):
        """Load all data from pod of type `self.item_type`.

        Args:
            filter_predicted (bool, optional): Remove items from the result that
                already have a ``CategoricalPrediction`` label. Defaults to True.

        Returns:
            List[Item]: Items of type ``self.item_type``.
        """
        data = self.client.search({"type": self.item_type})
        if filter_predicted:
            data = [item for item in data if not get_predictions(item)]
        return data

    def load_model_head(
        self, pipeline: TextClassificationPipeline
    ) -> TextClassificationPipeline:
        """Swap the pipeline's classifier for the head stored in ``self.model_head``.

        The artifact is fetched from W&B anonymously. If no ``*.model`` file is
        found or it cannot be loaded, the pipeline keeps its original head.

        Args:
            pipeline: Pipeline whose ``model.classifier`` is replaced.

        Returns:
            The same pipeline object, with the custom head if loading succeeded.
        """
        if self.model_head is None:
            print("Using default model_head")
            return pipeline

        wandb.login(anonymous="must")
        artifact = wandb.Api().artifact(self.model_head)
        artifact_dir = pathlib.Path(artifact.download())
        try:
            # next() raises StopIteration when the artifact holds no *.model file;
            # that is treated like any other load failure below.
            mh_path = next(artifact_dir.glob("*.model"))
            # NOTE(security): torch.load unpickles the downloaded file, which can
            # execute arbitrary code -- only point `model_head` at trusted artifacts.
            model_head = torch.load(mh_path, map_location=torch.device("cpu"))
            pipeline.model.classifier = model_head
            pipeline.model.eval()
        except Exception as e:
            # Best effort: report why, then fall back to the model's own head.
            print(f"Could not load custom model head {self.model_head}: {e!r}")
            print("Using default model_head")
        return pipeline

    def infer_sentiment(self, data: List[Any], pipeline: TextClassificationPipeline):
        """Classify the sentiment of every item in ``data``.

        Args:
            data: Items to classify; the text is read from ``self.content_field``.
            pipeline: Called once on the list of texts. Each per-text output is
                expected to be a list of ``{"label": ..., "score": ...}`` dicts,
                one per class (as produced with ``return_all_scores=True`` in
                ``run``).

        Returns:
            List[CategoricalPrediction]: One prediction per item, in input order.
            ``labels`` holds the lowercased labels and ``probs`` their scores,
            both JSON-encoded. ``name`` is the highest-scoring label.
        """
        text_data = [getattr(item, self.content_field) for item in data]
        predictions = []
        for scores in pipeline(text_data):
            labels = [elem["label"].lower() for elem in scores]
            probs = [elem["score"] for elem in scores]
            predictions.append(
                CategoricalPrediction(
                    labels=json.dumps(labels),
                    probs=json.dumps(probs),
                    # Prediction = argmax(probs)
                    name=labels[int(np.argmax(probs))],
                )
            )
        return predictions

    def save_to_pod(self, data: List[Item], sentiments: List[Item]):
        """Create the predictions and a ``label`` edge from each item to its prediction.

        ``data`` and ``sentiments`` are paired positionally; extra entries in the
        longer list are ignored by ``zip``.
        """
        edges = [
            Edge(item, sentiment, "label") for item, sentiment in zip(data, sentiments)
        ]
        self.client.bulk_action(create_items=sentiments, create_edges=edges)

    def run(self):
        """Classify every not-yet-predicted item in the pod and store the results."""
        print("Loading data from pod...")
        data = self.load_data(filter_predicted=True)
        if not data:
            print("No data found.")
            print("Run completed.")
            return

        print("Loading model...")
        # NOTE(review): newer transformers releases deprecate `return_all_scores`
        # in favour of `top_k=None`; confirm against the pinned version.
        pipeline = AutoPipeline(
            "sentiment-analysis",
            model=self.model_name,
            tokenizer=self.model_name,
            return_all_scores=True,
        )
        pipeline = self.load_model_head(pipeline)

        print("Inferring sentiment...")
        sentiments = self.infer_sentiment(data, pipeline)

        print(f"Adding {len(sentiments)} predictions to pod...")
        self.save_to_pod(data, sentiments)
        print("Run completed.")

    def add_to_schema(self):
        """Register the item types this plugin reads and writes with the pod schema."""
        self.client.add_to_schema(Message, CategoricalLabel, CategoricalPrediction)