from enum import Enum
from functools import cache
import os
import re
import time
import json
from typing import Set
from dataclasses import dataclass
from nltk.tokenize import sent_tokenize
import nltk
from loguru import logger
from more_itertools import chunked
from pymemri.data.schema import File, Photo, CategoricalPrediction
from pymemri.plugin.pluginbase import PluginBase
import requests
from result import Err, Ok, Result
from rss_importer.config import (
RSS_POD_URL,
RSS_OWNER_KEY,
RSS_DATABASE_KEY,
SEMANTIC_SEARCH_URL,
MEMRI_BOT_URL,
SUMMARIZATION_URL,
SUMMARY_MAX_LENGTH,
SUMMARY_MIN_LENGTH,
SUMMARY_SOURCE,
DEFAULT_FEEDS,
)
from rss_importer.labels import get_tags
from .rss import get_feed, scrape_entries, update_feed
from .schema import RSSEntry, RSSFeed, RSSFeedSummary
class LabelSource(str, Enum):
ZERO_SHOT = "zero-shot"
SEMANTIC_SEARCH = "semantic-search"
@dataclass
class Label:
label: str
score: float
class RSSImporter(PluginBase):
def __init__(
self,
setup_on_start: bool = False,
max_entries_on_start: int | None = None,
**kwargs,
):
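"""Register the plugin's item types in the Pod schema and optionally import the default feeds."""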
super().__init__(**kwargs)
self.add_to_schema()
if setup_on_start:
self.setup_feeds(DEFAULT_FEEDS, max_entries=max_entries_on_start)
def run(self):
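"""No-op: importing, summarizing and labeling are triggered through the methods below."""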
pass
def _feed_url_in_pod(self, feed_url: str) -> bool:
"""
Returns True if a feed with the same href exists in the Pod
"""
query = """
query {
RSSFeed(filter: {href: {eq: "$href"}}) {
href
}
}
"""
res = self.client.api.graphql(query=query, variables={"href": feed_url})
return len(res["data"]) > 0
def setup_feeds(
self, feed_urls: list[tuple[str, str]], max_entries: int | None = None
) -> list[RSSFeed]:
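"""Import each (url, name) pair as a new feed; failures are logged and skipped."""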
logger.info("Setting up default feeds...")
feeds = []
for feed_url, name in feed_urls:
feed_or_err = self.setup_feed(
url=feed_url, max_entries=max_entries, name=name
)
match feed_or_err:
case Ok(feed):
feeds.append(feed)
case Err(e):
logger.error(f"Could not setup feed {feed_url}: {e}")
return feeds
def setup_feed(
self, *, url: str, max_entries: int | None = None, name: str | None = None
) -> Result[RSSFeed, str]:
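"""Fetch a feed, scrape its entries and store both in the Pod; returns Err if the feed already exists or cannot be fetched."""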
if self._feed_url_in_pod(url):
return Err("A feed with the same url already exists in pod")
try:
feed, entries = get_feed(url)
except Exception as e:
logger.exception(e)
return Err(f"Could not fetch feed: {e}")
if max_entries is not None:
entries = entries[:max_entries]
if name is not None:
feed.title = name
logger.debug(f"Importing feed with {len(entries)} entries")
entries = scrape_entries(entries)
self.client.bulk_action(
create_items=[feed] + entries,
create_edges=[feed.add_edge("entry", entry) for entry in entries],
)
return Ok(feed)
def _get_feed_existing_ids(self, feed: RSSFeed) -> Set[str]:
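"""Return the externalIds of entries already stored in the Pod for this feed."""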
if feed.id is None:
return set()
query = """
query {
RSSFeed(filter: {id: {eq: "$id"}}) {
id
entry(filter: { published: { gt: $timestamp }}) {
externalId
published
}
}
}
"""
# Use PodAPI to not interfere with pod local DB
res = self.client.api.graphql(
query=query, variables={"id": feed.id, "timestamp": time.time()}
)
entries_json = res["data"][0].get("entry", []) if res["data"] else []
existing_ids = {
    entry["externalId"]
    for entry in entries_json
    if entry.get("externalId") is not None
}
return existing_ids
def update_all_feeds(self):
feeds = self.client.search(
{"type": "RSSFeed", "importIsActive": True}, include_edges=False
)
if not feeds:
    # No active feeds in the Pod yet: import the defaults and return,
    # since freshly imported feeds do not need an update pass.
    self.setup_feeds(DEFAULT_FEEDS)
    return
for feed in feeds:
self.update_feed_entries(feed)
def update_feed_entries(self, feed: RSSFeed):
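"""Fetch and scrape entries not yet in the Pod and store them along with the updated feed."""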
existing_ids = self._get_feed_existing_ids(feed)
updated_feed, new_entries = update_feed(feed, existing_ids)
new_entries = scrape_entries(new_entries)
self.client.bulk_action(
create_items=new_entries,
update_items=[updated_feed],
create_edges=[feed.add_edge("entry", entry) for entry in new_entries],
)
def update_feed(self, feed: RSSFeed):
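"""Persist changes on the given feed item to the Pod."""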
self.client.update_item(feed)
def add_to_schema(self):
self.client.add_to_schema(
RSSFeed, RSSEntry, RSSFeedSummary, Photo, File, CategoricalPrediction
)
def delete_feed(self, feed_id: str) -> str:
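"""Soft-delete a feed by disabling its import; returns the feed id."""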
try:
feed = self.client.get(feed_id, expanded=False)
except ValueError:
    raise ValueError(f"Item({feed_id}) does not exist or has already been deleted")
if not isinstance(feed, RSSFeed):
raise ValueError(f"Item({feed_id}) is not an RSSFeed")
feed.importIsActive = False
self.client.update_item(feed)
if feed.id is None:
raise ValueError("Item has no id")
return feed.id
@cache
def get_feeds(self, include_deleted: bool) -> list[RSSFeed]:
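"""Return RSSFeeds from the Pod; results are memoized per instance and argument by functools.cache."""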
feeds = self.client.search({"type": "RSSFeed"}, include_edges=False)
if feeds is None:
return []
if include_deleted:
return feeds
else:
return [feed for feed in feeds if feed.importIsActive]
def summarize_rss_entries(
self,
entries: list[RSSEntry],
min_length: int = SUMMARY_MIN_LENGTH,
max_length: int = SUMMARY_MAX_LENGTH,
batch_size: int = 4,
) -> list[RSSEntry]:
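"""Summarize entries in batches via the summarization service and store the bulletized summaries."""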
logger.info(f"Summarizing {len(entries)} entries in batches of {batch_size}")
for batch in chunked(entries, batch_size):
logger.debug(f"Summarizing {len(batch)} entries")
documents = [entry.plain_text() for entry in batch]
summaries = summary_request(
documents,
min_length=min_length,
max_length=max_length,
)
for entry, summary in zip(batch, summaries):
if summary:
entry.summary = bulletize_summary(summary)
entry.summarySource = SUMMARY_SOURCE
logger.debug(f"Updating {len(batch)} entries")
self.client.bulk_action(update_items=batch)
return entries
def label_rss_entries(
self,
entries: list[RSSEntry],
batch_size: int = 32,
) -> list[RSSEntry]:
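"""Label entry titles via semantic search and attach the results as CategoricalPrediction items."""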
label_source = LabelSource.SEMANTIC_SEARCH
logger.info(f"Labeling {len(entries)} entries in batches of {batch_size}")
for batch in chunked(entries, batch_size):
logger.debug(f"Labeling {len(batch)} entries")
titles = [entry.title or "" for entry in batch]
labels_of_docs = semantic_search_request(titles)
create_edges = []
create_items = []
for entry, k_labels in zip(batch, labels_of_docs):
for label in k_labels:
category_item = CategoricalPrediction(
value=label.label,
score=label.score,
source=label_source.value,
)
edge = entry.add_edge("label", category_item)
create_edges.append(edge)
create_items.append(category_item)
logger.debug(f"Updating {len(batch)} entries")
self.client.bulk_action(
update_items=batch,
create_edges=create_edges,
create_items=create_items,
)
return entries
def index_rss_entries(
self,
entries: list[RSSEntry],
batch_size: int = 32,
):
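"""Send entry texts to the semantic search index and mark successfully indexed entries."""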
logger.warning(
f"Indexing RSS entries: {len(entries)} with batch size {batch_size}"
)
for batch in chunked(entries, batch_size):
logger.info(f"Batch: {len(batch)}")
ids = [e.id for e in batch]
texts = [e.plain_text() for e in batch]
try:
semantic_index_request(ids, texts)
# update isIndexed flag
for entry in batch:
entry.isIndexed = True
self.client.bulk_action(update_items=batch)
except Exception as e:
logger.error(f"Semantic index request failed: {e}")
def generate_feed_summary(self, category_text=None):
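"""Ask the chat service for a short markdown bullet summary of recent, impactful entries, optionally restricted to a category."""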
index_query = "impactful, affecting a wide range of people, very recent, published today or yesterday."
instructions = """Generate a 4-item markdown bullet list from distinct news articles:
* Use asterisks as bullets, ensuring valid markdown format.
* Each bullet should be 5-20 words, representing a unique article.
* Bullet items should not contain any markup or formatting.
* No introduction, conclusion, sub-bullets, or nested items.
* Append each with '(source: {Item Id})'.
Ensure the output is a one-level bullet list without additional text.
"""
if category_text is not None and category_text != "all":
index_query = f"{index_query} {category_text}"
instructions += f" Keep the list strictly bounded by {category_text}. Do not include list items that are not closely related to the tags or categories."
logger.debug(
    f"Sending chat request with query: {index_query} and instructions: {instructions}"
)
messages = [{"role": "USER", "content": instructions}]
rss_client_kwargs = {
"url": RSS_POD_URL,
"owner_key": RSS_OWNER_KEY,
"database_key": RSS_DATABASE_KEY,
}
try:
summary = chat_request(
"rss_feed", messages, rss_client_kwargs, index_query=index_query
)
except Exception as e:
logger.error(f"Error during chat request stream: {e}")
return None
# strip out the lines without bullet points (*)
summary = re.sub(r"^[^*]+", "", summary, flags=re.MULTILINE)
return summary
def feed_chat(self, index_name, messages, category_text=None):
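"""Stream a chat reply grounded in the given index, optionally narrowing retrieval by category."""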
index_query = messages[-1]["content"]
if category_text is not None and category_text != "all":
index_query = f"{index_query} {category_text}"
print(f"Sending chat request with query: {index_query}")
rss_client_kwargs = {
"url": RSS_POD_URL,
"owner_key": RSS_OWNER_KEY,
"database_key": RSS_DATABASE_KEY,
}
try:
return chat_request_stream(
index_name, messages, rss_client_kwargs, index_query=index_query
)
except Exception as e:
logger.error(f"Error during chat request stream: {e}")
return None
def bulletize_summary(summary: str) -> str:
"""Converts a summary to a bulletized list per sentence."""
nltk.download("punkt", quiet=True)
sentences = sent_tokenize(summary)
return "\n".join([f" • {sentence}" for sentence in sentences])
def summary_request(
documents: list[str], min_length: int, max_length: int
) -> list[str]:
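"""Request summaries for the given documents from the summarization service."""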
response = requests.post(
SUMMARIZATION_URL,
json={
"documents": documents,
"min_length": min_length,
"max_length": max_length,
},
timeout=60,
)
response.raise_for_status()
return response.json()["summaries"]
def chat_request(index_name, messages, pod_client_kwargs, index_query=None):
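"""Post a chat request to the memri bot and collect the streamed reply into a single string."""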
data: dict = {
"index_name": index_name,
"messages": messages,
"pod_client_kwargs": pod_client_kwargs,
}
if index_query:
data["index_query"] = index_query
reply = ""
for r in requests.post(
f"{MEMRI_BOT_URL}/v1/memory/chat", json=data, timeout=120, stream=True
).iter_lines():
if r:
try:
reply_line = json.loads(r)["choices"][0]["text"]
reply += reply_line
except (json.JSONDecodeError, KeyError, IndexError):
    logger.error(f"JSON decoding failed for LLM reply: {r}")
return reply
def chat_request_stream(index_name, messages, pod_client_kwargs, index_query=None):
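"""Post a chat request to the memri bot and yield the raw streamed reply lines."""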
data: dict = {
"index_name": index_name,
"messages": messages,
"pod_client_kwargs": pod_client_kwargs,
}
if index_query:
data["index_query"] = index_query
for r in requests.post(
f"{MEMRI_BOT_URL}/v1/memory/chat", json=data, timeout=120, stream=True
).iter_lines():
if r:
yield r + b"\n"
def semantic_index_request(ids: list[str], documents: list[str]):
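"""Index the given documents under the 'rss_feed' index of the semantic search service."""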
endpoint = os.path.join(SEMANTIC_SEARCH_URL, "index_texts")
response = requests.post(
endpoint,
json={
"index_name": "rss_feed",
"ids": ids,
"texts": documents,
"pod_client_kwargs": {
"url": RSS_POD_URL,
"owner_key": RSS_OWNER_KEY,
"database_key": RSS_DATABASE_KEY,
},
"is_chunked": True,
},
timeout=60,
)
response.raise_for_status()
return True
def semantic_search_request(documents: list[str]) -> list[list[Label]]:
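"""Match each document against the known tags and return its top-scoring Labels."""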
labels = get_tags()
endpoint = os.path.join(SEMANTIC_SEARCH_URL, "instant_search")
top_k = 3
threshold = 0.2
response = requests.post(