# plugin.py
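#
# Overview: RSSImporter is a pymemri plugin that imports RSS feeds into a Memri
# Pod, scrapes and stores entries, and enriches them with summaries, labels, and
# a semantic index backed by external services. A minimal usage sketch, assuming
# the pymemri plugin runtime supplies a configured Pod client (the constructor
# kwargs shown here are illustrative, not prescribed by this module):
#
#     importer = RSSImporter(setup_on_start=True, max_entries_on_start=25)
#     importer.update_all_feeds()
#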
from enum import Enum
from functools import cache
import os
import re
import time
import json
from typing import Set
from dataclasses import dataclass
from nltk.tokenize import sent_tokenize
import nltk
from loguru import logger
from more_itertools import chunked
from pymemri.data.schema import File, Photo, CategoricalPrediction
from pymemri.plugin.pluginbase import PluginBase
import requests
from result import Err, Ok, Result
from rss_importer.config import (
    RSS_POD_URL,
    RSS_OWNER_KEY,
    RSS_DATABASE_KEY,
    SEMANTIC_SEARCH_URL,
    MEMRI_BOT_URL,
    SUMMARIZATION_URL,
    SUMMARY_MAX_LENGTH,
    SUMMARY_MIN_LENGTH,
    SUMMARY_SOURCE,
    DEFAULT_FEEDS,
)
from rss_importer.labels import get_tags
from .rss import get_feed, scrape_entries, update_feed
from .schema import RSSEntry, RSSFeed, RSSFeedSummary


class LabelSource(str, Enum):
    ZERO_SHOT = "zero-shot"
    SEMANTIC_SEARCH = "semantic-search"
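

# LabelSource records where a CategoricalPrediction came from; label_rss_entries
# below currently uses LabelSource.SEMANTIC_SEARCH. Label pairs a predicted tag
# with its score, matching the shape returned by semantic_search_request.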
@dataclass
class Label:
    label: str
    score: float


class RSSImporter(PluginBase):
    def __init__(
        self,
        setup_on_start: bool = False,
        max_entries_on_start: int | None = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.add_to_schema()
        if setup_on_start:
            self.setup_feeds(DEFAULT_FEEDS, max_entries=max_entries_on_start)
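
    # run() below is intentionally left as a no-op: feed import, summarization,
    # labeling, and indexing are invoked explicitly through the methods that
    # follow (presumably by an external scheduler or API layer not shown here).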
    def run(self):
        pass

    def _feed_url_in_pod(self, feed_url: str) -> bool:
        """
        Returns True if a feed with the same href exists in the Pod
        """
        query = """
            query {
                RSSFeed(filter: {href: {eq: "$href"}}) {
                    href
                }
            }
        """
        res = self.client.api.graphql(query=query, variables={"href": feed_url})
        return len(res["data"]) > 0

    def setup_feeds(
        self, feed_urls: list[tuple[str, str]], max_entries: int | None = None
    ) -> list[RSSFeed]:
        logger.info("Setting up default feeds...")
        feeds = []
        for feed_url, name in feed_urls:
            feed_or_err = self.setup_feed(
                url=feed_url, max_entries=max_entries, name=name
            )
            match feed_or_err:
                case Ok(feed):
                    feeds.append(feed)
                case Err(e):
                    logger.error(f"Could not setup feed {feed_url}: {e}")
        return feeds

    def setup_feed(
        self, *, url: str, max_entries: int | None = None, name: str | None = None
    ) -> Result[RSSFeed, str]:
        if self._feed_url_in_pod(url):
            return Err("A feed with the same url already exists in pod")
        try:
            feed, entries = get_feed(url)
        except Exception as e:
            logger.exception(e)
            return Err(f"Could not fetch feed: {e}")

        if max_entries is not None:
            entries = entries[:max_entries]
        if name is not None:
            feed.title = name

        logger.debug(f"Importing feed with {len(entries)} entries")
        entries = scrape_entries(entries)
        self.client.bulk_action(
            create_items=[feed] + entries,
            create_edges=[feed.add_edge("entry", entry) for entry in entries],
        )
        return Ok(feed)

    def _get_feed_existing_ids(self, feed: RSSFeed) -> Set[str]:
        if feed.id is None:
            return set()
        query = """
            query {
                RSSFeed(filter: {id: {eq: "$id"}}) {
                    id
                    entry(filter: { published: { gt: $timestamp }}) {
                        externalId
                        published
                    }
                }
            }
        """
        # Use PodAPI to not interfere with pod local DB
        res = self.client.api.graphql(
            query=query, variables={"id": feed.id, "timestamp": time.time()}
        )
        entries_json = res["data"][0].get("entry", [])
        existing_ids = set(
            [entry.get("externalId", None) for entry in entries_json]
        ) - {None}
        return existing_ids

    def update_all_feeds(self):
        feeds = self.client.search(
            {"type": "RSSFeed", "importIsActive": True}, include_edges=False
        )
        if not feeds:
            self.setup_feeds(DEFAULT_FEEDS)
        for feed in feeds:
            self.update_feed_entries(feed)

    def update_feed_entries(self, feed: RSSFeed):
        existing_ids = self._get_feed_existing_ids(feed)
        updated_feed, new_entries = update_feed(feed, existing_ids)
        new_entries = scrape_entries(new_entries)
        self.client.bulk_action(
            create_items=new_entries,
            update_items=[updated_feed],
            create_edges=[feed.add_edge("entry", entry) for entry in new_entries],
        )

    def update_feed(self, feed: RSSFeed):
        self.client.update_item(feed)

    def add_to_schema(self):
        self.client.add_to_schema(
            RSSFeed, RSSEntry, RSSFeedSummary, Photo, File, CategoricalPrediction
        )

    def delete_feed(self, feed_id: str) -> str:
        try:
            feed = self.client.get(feed_id, expanded=False)
        except ValueError:
            raise ValueError(f"Item({feed_id}) has already been deleted")
        if not isinstance(feed, RSSFeed):
            raise ValueError(f"Item({feed_id}) is not an RSSFeed")
        feed.importIsActive = False
        self.client.update_item(feed)
        if feed.id is None:
            raise ValueError("Item has no id")
        return feed.id

    @cache
    def get_feeds(self, include_deleted: bool) -> list[RSSFeed]:
        feeds = self.client.search({"type": "RSSFeed"}, include_edges=False)
        if feeds is None:
            return []
        if include_deleted:
            return feeds
        else:
            return [feed for feed in feeds if feed.importIsActive]

    def summarize_rss_entries(
        self,
        entries: list[RSSEntry],
        min_length: int = SUMMARY_MIN_LENGTH,
        max_length: int = SUMMARY_MAX_LENGTH,
        batch_size: int = 4,
    ) -> list[RSSEntry]:
        logger.info(f"Summarizing {len(entries)} entries in batches of {batch_size}")
        for batch in chunked(entries, batch_size):
            logger.debug(f"Summarizing {len(batch)} entries")
            documents = [entry.plain_text() for entry in batch]
            summaries = summary_request(
                documents,
                min_length=min_length,
                max_length=max_length,
            )
            for entry, summary in zip(batch, summaries):
                if summary:
                    entry.summary = bulletize_summary(summary)
                    entry.summarySource = SUMMARY_SOURCE
            logger.debug(f"Updating {len(batch)} entries")
            self.client.bulk_action(update_items=batch)
        return entries

    def label_rss_entries(
        self,
        entries: list[RSSEntry],
        batch_size: int = 32,
    ) -> list[RSSEntry]:
        label_source = LabelSource.SEMANTIC_SEARCH
        logger.info(f"Labeling {len(entries)} entries in batches of {batch_size}")
        for batch in chunked(entries, batch_size):
            logger.debug(f"Labeling {len(batch)} entries")
            titles = [entry.title or "" for entry in batch]
            labels_of_docs = semantic_search_request(titles)
            create_edges = []
            create_items = []
            for entry, k_labels in zip(batch, labels_of_docs):
                for label in k_labels:
                    category_item = CategoricalPrediction(
                        value=label.label,
                        score=label.score,
                        source=label_source.value,
                    )
                    edge = entry.add_edge("label", category_item)
                    create_edges.append(edge)
                    create_items.append(category_item)
            logger.debug(f"Updating {len(batch)} entries")
            self.client.bulk_action(
                update_items=batch,
                create_edges=create_edges,
                create_items=create_items,
            )
        return entries

    def index_rss_entries(
        self,
        entries: list[RSSEntry],
        batch_size: int = 32,
    ):
        logger.warning(
            f"Indexing RSS entries: {len(entries)} with batch size {batch_size}"
        )
        for batch in chunked(entries, batch_size):
            logger.info(f"Batch: {len(batch)}")
            ids = [e.id for e in batch]
            texts = [e.plain_text() for e in batch]
            try:
                semantic_index_request(ids, texts)
                # update isIndexed flag
                for entry in batch:
                    entry.isIndexed = True
                self.client.bulk_action(update_items=batch)
            except Exception as e:
                logger.error(f"Semantic index request failed: {e}")

    def generate_feed_summary(self, category_text=None):
        index_query = "impactful, affecting a wide range of people, very recent, published today or yesterday."
        instructions = """Generate a 4-item markdown bullet list from distinct news articles:
* Use asterisks as bullets, ensuring valid markdown format.
* Each bullet should be 5-20 words, representing a unique article.
* Bullet items should not contain any markup or formatting.
* No introduction, conclusion, sub-bullets, or nested items.
* Append each with '(source: {Item Id})'.
Ensure the output is a one-level bullet list without additional text.
"""
        if category_text is not None and category_text != "all":
            index_query = f"{index_query} {category_text}"
            instructions += f" Keep the list strictly bounded by {category_text}. Do not include list items that are not closely related to the tags or categories."
        print(
            f"Sending chat request with query: {index_query} and instructions: {instructions}"
        )
        messages = [{"role": "USER", "content": instructions}]
        rss_client_kwargs = {
            "url": RSS_POD_URL,
            "owner_key": RSS_OWNER_KEY,
            "database_key": RSS_DATABASE_KEY,
        }
        try:
            summary = chat_request(
                "rss_feed", messages, rss_client_kwargs, index_query=index_query
            )
        except Exception as e:
            logger.error(f"Error during chat request stream: {e}")
            return None
        # strip out the lines without bullet points (*)
        summary = re.sub(r"^[^*]+", "", summary, flags=re.MULTILINE)
        return summary

    def feed_chat(self, index_name, messages, category_text=None):
        index_query = messages[-1]["content"]
        if category_text is not None and category_text != "all":
            index_query = f"{index_query} {category_text}"
        print(f"Sending chat request with query: {index_query}")
        rss_client_kwargs = {
            "url": RSS_POD_URL,
            "owner_key": RSS_OWNER_KEY,
            "database_key": RSS_DATABASE_KEY,
        }
        try:
            return chat_request_stream(
                index_name, messages, rss_client_kwargs, index_query=index_query
            )
        except Exception as e:
            logger.error(f"Error during chat request stream: {e}")
            return None


def bulletize_summary(summary: str) -> str:
    """Converts a summary to a bulletized list per sentence."""
    nltk.download("punkt", quiet=True)
    sentences = sent_tokenize(summary)
    return "\n".join([f" • {sentence}" for sentence in sentences])


def summary_request(
    documents: list[str], min_length: int, max_length: int
) -> list[str]:
    response = requests.post(
        SUMMARIZATION_URL,
        json={
            "documents": documents,
            "min_length": min_length,
            "max_length": max_length,
        },
        timeout=60,
    )
    response.raise_for_status()
    return response.json()["summaries"]


def chat_request(index_name, messages, pod_client_kwargs, index_query=None):
    data: dict = {
        "index_name": index_name,
        "messages": messages,
        "pod_client_kwargs": pod_client_kwargs,
    }
    if index_query:
        data["index_query"] = index_query
    reply = ""
    for r in requests.post(
        f"{MEMRI_BOT_URL}/v1/memory/chat", json=data, timeout=120, stream=True
    ).iter_lines():
        if r:
            try:
                reply_line = json.loads(r)["choices"][0]["text"]
                reply += reply_line
            except:
                logger.error(f"JSON decoding failed for LLM reply: {r}")
    return reply


def chat_request_stream(index_name, messages, pod_client_kwargs, index_query=None):
    data: dict = {
        "index_name": index_name,
        "messages": messages,
        "pod_client_kwargs": pod_client_kwargs,
    }
    if index_query:
        data["index_query"] = index_query
    for r in requests.post(
        f"{MEMRI_BOT_URL}/v1/memory/chat", json=data, timeout=120, stream=True
    ).iter_lines():
        if r:
            yield r + b"\n"


def semantic_index_request(ids: list[str], documents: list[str]):
    endpoint = os.path.join(SEMANTIC_SEARCH_URL, "index_texts")
    response = requests.post(
        endpoint,
        json={
            "index_name": "rss_feed",
            "ids": ids,
            "texts": documents,
            "pod_client_kwargs": {
                "url": RSS_POD_URL,
                "owner_key": RSS_OWNER_KEY,
                "database_key": RSS_DATABASE_KEY,
            },
            "is_chunked": True,
        },
        timeout=60,
    )
    response.raise_for_status()
    return True


def semantic_search_request(documents: list[str]) -> list[list[Label]]:
    labels = get_tags()
    endpoint = os.path.join(SEMANTIC_SEARCH_URL, "instant_search")
    top_k = 3
    threshold = 0.2
    response = requests.post(