# plugin.py — RSS importer plugin
# stdlib
import json
import os
import re
import time
from dataclasses import dataclass
from enum import Enum
from functools import cache
from typing import Set

# third-party
import nltk
import requests
from loguru import logger
from more_itertools import chunked
from nltk.tokenize import sent_tokenize
from pymemri.data.schema import CategoricalPrediction, File, Photo
from pymemri.plugin.pluginbase import PluginBase
from result import Err, Ok, Result

# local
from rss_importer.config import (
    RSS_POD_URL,
    RSS_OWNER_KEY,
    RSS_DATABASE_KEY,
    SEMANTIC_SEARCH_URL,
    MEMRI_BOT_URL,
    SUMMARIZATION_URL,
    SUMMARY_MAX_LENGTH,
    SUMMARY_MIN_LENGTH,
    SUMMARY_SOURCE,
    DEFAULT_FEEDS,
)
from rss_importer.labels import get_tags

from .rss import get_feed, scrape_entries, update_feed
from .schema import RSSEntry, RSSFeed, RSSFeedSummary
class LabelSource(str, Enum):
    """Where a predicted label came from.

    Subclasses ``str`` so members compare equal to (and serialize as)
    their plain string values.
    """

    # Produced by a zero-shot classification model.
    ZERO_SHOT = "zero-shot"
    # Produced by semantic-search / nearest-neighbour lookup.
    SEMANTIC_SEARCH = "semantic-search"
@dataclass
class Label:
    """A predicted label paired with its model score."""

    # The label text itself.
    label: str
    # Model confidence — presumably in [0, 1]; confirm against the producer.
    score: float
class RSSImporter(PluginBase):
    def __init__(
        self,
        setup_on_start: bool = False,
        max_entries_on_start: int | None = None,
        **kwargs,
    ):
        """Initialize the importer and register its schema with the Pod.

        Args:
            setup_on_start: When True, immediately import ``DEFAULT_FEEDS``.
            max_entries_on_start: Cap on entries imported per feed during
                start-up setup; ``None`` means no cap.
            **kwargs: Forwarded to ``PluginBase.__init__``.
        """
        # NOTE(review): the original text was missing the closing `):` of the
        # signature — restored here.
        super().__init__(**kwargs)
        self.add_to_schema()
        if setup_on_start:
            self.setup_feeds(DEFAULT_FEEDS, max_entries=max_entries_on_start)
    def run(self):
        """Plugin entry point — intentionally a no-op.

        Feed setup happens in ``__init__`` (when ``setup_on_start`` is set),
        so there is no work left to do here.
        """
        pass
    def _feed_url_in_pod(self, feed_url: str) -> bool:
        """
        Returns True if a feed with the same href exists in the Pod
        """
        query = """
7172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
query { RSSFeed(filter: {href: {eq: "$href"}}) { href } } """ res = self.client.api.graphql(query=query, variables={"href": feed_url}) return len(res["data"]) > 0 def setup_feeds( self, feed_urls: list[tuple[str, str]], max_entries: int | None = None ) -> list[RSSFeed]: logger.info("Setting up default feeds...") feeds = [] for feed_url, name in feed_urls: feed_or_err = self.setup_feed( url=feed_url, max_entries=max_entries, name=name ) match feed_or_err: case Ok(feed): feeds.append(feed) case Err(e): logger.error(f"Could not setup feed {feed_url}: {e}") return feeds def setup_feed( self, *, url: str, max_entries: int | None = None, name: str | None = None ) -> Result[RSSFeed, str]: if self._feed_url_in_pod(url): return Err("A feed with the same url already exists in pod") try: feed, entries = get_feed(url) except Exception as e: logger.exception(e) return Err(f"Could not fetch feed: {e}") if max_entries is not None: entries = entries[:max_entries] if name is not None: feed.title = name logger.debug(f"Importing feed with {len(entries)} entries") entries = scrape_entries(entries) self.client.bulk_action( create_items=[feed] + entries, create_edges=[feed.add_edge("entry", entry) for entry in entries], ) return Ok(feed) def _get_feed_existing_ids(self, feed: RSSFeed) -> Set[str]: if feed.id is None: return set() query = """ query { RSSFeed(filter: {id: {eq: "$id"}}) { id entry(filter: { published: { gt: $timestamp }}) { externalId published } } } """ # Use PodAPI to not interfere with pod local DB res = self.client.api.graphql( query=query, variables={"id": feed.id, "timestamp": time.time()}