# NOTE: the lines that originally preceded the imports were file-viewer
# artifacts ("An error occurred while loading the file", author line
# "Alp Deniz Ogut authored 43156e6a") and were not part of the module.
from enum import Enum
from functools import cache
import os
import re
import time
import json
from typing import Set
from dataclasses import dataclass
from nltk.tokenize import sent_tokenize
import nltk
from loguru import logger
from more_itertools import chunked
from pymemri.data.schema import File, Photo, CategoricalPrediction
from pymemri.plugin.pluginbase import PluginBase
import requests
from result import Err, Ok, Result
from rss_importer.config import (
RSS_POD_URL,
RSS_OWNER_KEY,
RSS_DATABASE_KEY,
SEMANTIC_SEARCH_URL,
MEMRI_BOT_URL,
SUMMARIZATION_URL,
SUMMARY_MAX_LENGTH,
SUMMARY_MIN_LENGTH,
SUMMARY_SOURCE,
DEFAULT_FEEDS,
)
from rss_importer.labels import get_tags
from .rss import get_feed, scrape_entries, update_feed
from .schema import RSSEntry, RSSFeed, RSSFeedSummary
class LabelSource(str, Enum):
    """Enumerates the supported origins of a classification label.

    Mixes in ``str`` so members compare equal to their string values and
    serialize transparently (e.g. into JSON payloads).
    """

    ZERO_SHOT = "zero-shot"
    SEMANTIC_SEARCH = "semantic-search"
@dataclass
class Label:
    """A single predicted label paired with its score.

    Attributes are documented with inline comments per module convention.
    """

    label: str  # the label text itself
    score: float  # presumably a model confidence / similarity score — confirm with producer
class RSSImporter(PluginBase):
    def __init__(
        self,
        setup_on_start: bool = False,
        max_entries_on_start: int | None = None,
        **kwargs,
    ):
        """Initialize the RSS importer plugin.

        Args:
            setup_on_start: when True, immediately import ``DEFAULT_FEEDS``.
            max_entries_on_start: optional cap on entries imported per feed
                during startup; ``None`` means no limit.
            **kwargs: forwarded unchanged to ``PluginBase.__init__``.
        """
        super().__init__(**kwargs)
        # add_to_schema is inherited from the plugin base class; presumably it
        # registers this plugin's item types with the pod — confirm in pymemri.
        self.add_to_schema()
        if setup_on_start:
            self.setup_feeds(DEFAULT_FEEDS, max_entries=max_entries_on_start)
    def run(self):
        """Intentional no-op: feed setup is triggered from ``__init__`` when
        ``setup_on_start`` is True, not from the plugin run loop."""
        pass
def _feed_url_in_pod(self, feed_url: str) -> bool:
"""
Returns True if a feed with the same href exists in the Pod
"""
query = """
7172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
query {
RSSFeed(filter: {href: {eq: "$href"}}) {
href
}
}
"""
res = self.client.api.graphql(query=query, variables={"href": feed_url})
return len(res["data"]) > 0
def setup_feeds(
self, feed_urls: list[tuple[str, str]], max_entries: int | None = None
) -> list[RSSFeed]:
logger.info("Setting up default feeds...")
feeds = []
for feed_url, name in feed_urls:
feed_or_err = self.setup_feed(
url=feed_url, max_entries=max_entries, name=name
)
match feed_or_err:
case Ok(feed):
feeds.append(feed)
case Err(e):
logger.error(f"Could not setup feed {feed_url}: {e}")
return feeds
def setup_feed(
self, *, url: str, max_entries: int | None = None, name: str | None = None
) -> Result[RSSFeed, str]:
if self._feed_url_in_pod(url):
return Err("A feed with the same url already exists in pod")
try:
feed, entries = get_feed(url)
except Exception as e:
logger.exception(e)
return Err(f"Could not fetch feed: {e}")
if max_entries is not None:
entries = entries[:max_entries]
if name is not None:
feed.title = name
logger.debug(f"Importing feed with {len(entries)} entries")
entries = scrape_entries(entries)
self.client.bulk_action(
create_items=[feed] + entries,
create_edges=[feed.add_edge("entry", entry) for entry in entries],
)
return Ok(feed)
def _get_feed_existing_ids(self, feed: RSSFeed) -> Set[str]:
if feed.id is None:
return set()
query = """
query {
RSSFeed(filter: {id: {eq: "$id"}}) {
id
entry(filter: { published: { gt: $timestamp }}) {
externalId
published
}
}
}
"""
# Use PodAPI to not interfere with pod local DB
res = self.client.api.graphql(
query=query, variables={"id": feed.id, "timestamp": time.time()}