Commit 7af0a218 authored by Eelco van der Wel's avatar Eelco van der Wel :speech_balloon:
Browse files

merge dev

parents 93389ed2 4a91bec6
Pipeline #11433 passed with stage
in 2 minutes and 41 seconds
Showing with 576 additions and 19 deletions
+576 -19
__version__ = "0.0.30"
__version__ = "0.0.32"
......@@ -296,6 +296,8 @@ class MessageChannel(Item):
name: Optional[str] = None
topic: Optional[str] = None
service: Optional[str] = None
isMock: Optional[bool] = None
sourceProject: Optional[str] = None
# Edges
photo: List["Photo"] = []
......@@ -386,6 +388,9 @@ class Plugin(Item):
# Edges
view: List["CVUStoredDefinition"] = []
templateSettings: List["TemplateSettings"] = []
project: List["Project"] = []
run: List["PluginRun"] = []
class PluginRun(Item):
......@@ -406,6 +411,7 @@ class PluginRun(Item):
plugin: List["Plugin"] = []
view: List["CVUStoredDefinition"] = []
account: List["Account"] = []
trigger: List["Trigger"] = []
class Post(Item):
......@@ -415,6 +421,7 @@ class Post(Item):
postDate: Optional[datetime] = None
postType: Optional[str] = None
isMock: Optional[bool] = None
sourceProject: Optional[str] = None
# Edges
author: List["Account"] = []
......@@ -452,6 +459,16 @@ class SuggestedMerge(Item):
mergeFrom: List["Person"] = []
class TemplateSettings(Item):
# Properties
templateName: Optional[str] = None
templateId: Optional[int] = None
dataSource: Optional[str] = None
# Edges
labelOption: List["LabelOption"] = []
class Trigger(Item):
# Properties
action: Optional[str] = None
......
......@@ -6,19 +6,18 @@ from typing import Any, List, Optional, Tuple
import numpy as np
from PIL import Image
from ._central_schema import File, Item
from ._central_schema import File, Photo
DEFAULT_ENCODING = "PNG"
class Photo(Item):
class Photo(Photo):
data: Optional[bytes] = None
height: Optional[int] = None
width: Optional[int] = None
channels: Optional[int] = None
encoding: Optional[str] = None
mode: Optional[str] = None
file: List[File] = []
def show(self):
raise NotImplementedError()
......
......@@ -7,6 +7,8 @@ from pathlib import Path
from fastcore.script import Param, call_parse
from loguru import logger
from pymemri import __version__ as pymemri_version
from ..data.basic import write_json
from ..data.schema import PluginRun
from ..pod.client import DEFAULT_POD_ADDRESS, DEFAULT_POD_KEY_PATH, PodClient
......@@ -18,6 +20,7 @@ from .constants import (
POD_PLUGIN_DNS_ENV,
POD_TARGET_ITEM_ENV,
)
from .oauth_handler import run_twitter_oauth_flow
from .pluginbase import (
PluginError,
create_run_expanded,
......@@ -28,6 +31,8 @@ from .pluginbase import (
)
from .states import RUN_FAILED
logger.info(f"Pymemri version: {pymemri_version}")
def _parse_env():
env = os.environ
......@@ -173,3 +178,34 @@ def simulate_run_plugin_from_frontend(
print("*Check the pod log/console for debug output.*")
return run
@call_parse
def simulate_oauth1_flow(
pod_full_address: Param("The pod full address", str) = DEFAULT_POD_ADDRESS,
port: Param("Port to listen on", int) = 3667,
host: Param("Host to listen on", str) = "localhost",
callback_url: Param("Callback url", str) = None,
database_key: Param("Database key of the pod", str) = None,
owner_key: Param("Owner key of the pod", str) = None,
metadata: Param("metadata file for the PluginRun", str) = None,
):
if database_key is None:
database_key = read_pod_key("database_key")
if owner_key is None:
owner_key = read_pod_key("owner_key")
if metadata is None:
raise ValueError("Missing metadata file")
else:
run = parse_metadata(metadata)
params = [pod_full_address, database_key, owner_key]
if None in params:
raise ValueError("Missing Pod credentials")
print(f"pod_full_address={pod_full_address}\nowner_key={owner_key}\n")
client = PodClient(url=pod_full_address, database_key=database_key, owner_key=owner_key)
if run.pluginName == "TwitterPlugin":
run_twitter_oauth_flow(client=client, host=host, port=port, callback_url=callback_url)
else:
raise ValueError("Unsupported plugin")
import http.server
import socketserver
import urllib
from urllib.parse import parse_qs, urlsplit
from pymemri.data.schema import OauthFlow
from pymemri.pod.client import PodClient
def get_request_handler(
client: PodClient, oauth_token_secret: str
) -> http.server.BaseHTTPRequestHandler:
"""
This is a factory function that returns a request handler class.
The returned class will have a reference to the client and oauth_token_secret
variables that are passed to this function.
This is needed because the request handler class is instantiated by the
TCPServer class, and we need to pass the client and oauth_token_secret
variables to the request handler class.
"""
class MyHttpRequestHandler(http.server.SimpleHTTPRequestHandler):
def do_GET(self):
params = urlsplit(self.path)
if params.path == "/oauth":
args = parse_qs(params.query)
oauth_verifier = args.get("oauth_verifier")[0]
oauth_token = args.get("oauth_token")[0]
response = client.get_oauth1_access_token(
oauth_token=oauth_token,
oauth_verifier=oauth_verifier,
oauth_token_secret=oauth_token_secret,
)
access_token = response["oauth_token"]
access_token_secret = response["oauth_token_secret"]
item = OauthFlow(
service="twitter",
accessToken=access_token,
accessTokenSecret=access_token_secret,
)
client.create(item)
self.send_response(200)
self.send_header("Content-type", "text/html")
self.end_headers()
self.wfile.write(bytes("Authenticated, succesfully created oauth item", "utf-8"))
return MyHttpRequestHandler
def run_twitter_oauth_flow(
*,
client: PodClient,
callback_url: str,
host: str = "localhost",
port: int = 3667,
) -> None:
callback_url = callback_url or f"http://{host}:{port}/oauth?state=twitter"
response = client.get_oauth1_request_token("twitter", callback_url)
oauth_token_secret = response["oauth_token_secret"]
queryParameters = {"oauth_token": response["oauth_token"]}
encoded = urllib.parse.urlencode(queryParameters)
print(f"*** \n\nGo to https://api.twitter.com/oauth/authorize?{encoded} \n\n***\n\n")
socketserver.TCPServer.allow_reuse_address = True
my_server = socketserver.TCPServer(
(host, port), get_request_handler(client, oauth_token_secret)
)
client.add_to_schema(OauthFlow)
my_server.handle_request()
import abc
import importlib
import inspect
import json
import warnings
from abc import ABCMeta
......@@ -7,10 +8,10 @@ from abc import ABCMeta
from loguru import logger
from pymemri.data.basic import read_json
from pymemri.webserver.public_api import ENDPOINT_METADATA, register_endpoint
from ..data.basic import write_json
from ..data.schema import Account, PluginRun
from ..pod.client import PodClient
from ..webserver.webserver import WebServer
from .authenticators.credentials import PLUGIN_DIR
from .listeners import get_abort_plugin_listener, get_pod_restart_listener
......@@ -36,12 +37,16 @@ class PluginBase(metaclass=ABCMeta):
self.client = client
self._status_listeners = []
self._config_dict = kwargs
self._daemon = False
if pluginRun is None:
self._webserver = WebServer(8080)
else:
self._webserver = WebServer(pluginRun.webserverPort or 8080)
self.endpoints = self._get_endpoint_methods()
self._register_api_endpoints()
self.set_run_status(RUN_INITIALIZED)
def set_run_status(self, status):
......@@ -50,6 +55,8 @@ class PluginBase(metaclass=ABCMeta):
self.pluginRun.status = status
self.client.update_item(self.pluginRun)
self._status = status
def set_progress(self, progress):
if self.pluginRun and self.client:
self.pluginRun.progress = progress
......@@ -67,13 +74,15 @@ class PluginBase(metaclass=ABCMeta):
for listener in self._status_listeners:
listener.stop()
self._webserver.shutdown()
def _run(self):
self.set_run_status(RUN_STARTED)
self.setup()
self.run()
if self._webserver.is_running():
if self.daemon:
self.set_run_status(RUN_DAEMON)
else:
self.teardown()
......@@ -97,6 +106,73 @@ class PluginBase(metaclass=ABCMeta):
schema.extend(schema_cls.pod_schema())
return schema
def _register_api_endpoints(self):
"""Collect decorated methods and add them to the webserver routes"""
for (path, method), endpoint in self.endpoints.items():
self._webserver.app.add_api_route(path=path, endpoint=endpoint, methods=[method])
def _get_endpoint_methods(self):
"""Collect decorated methods, bind them with `self`, and store in `endpoints`"""
endpoints = {}
for _method_name, method in inspect.getmembers(self, predicate=inspect.ismethod):
if hasattr(method, ENDPOINT_METADATA):
metadata = method.__endpoint_metadata__
if metadata in endpoints:
raise RuntimeError(
f"endpoint {metadata[0]} with method {metadata[1]} is already registered"
)
endpoints[metadata] = method
return endpoints
@register_endpoint("/v1/health", "GET")
def health_endpoint(self):
"""Returns current state of the plugin"""
return self._status
@register_endpoint("/v1/api", "GET")
def get_public_api(self):
"""Returns exposed functions of the plugin"""
def get_friendly_annotation_name(kv):
"""Annotation can be done by class, like: int, list, str
or by alias from typing module, like List, Sequence, Tuple.
"""
k, v = kv
if hasattr(v, "__name__"):
# For classes, use __name__ that returns
# more concise value
return (k, v.__name__)
else:
# The typing aliases do not provide __name__ attribute, so use
# __repr__ implementation.
return (k, str(v))
resp = {
func_name: {
"method": method,
"args": dict(
map(
get_friendly_annotation_name,
inspect.getfullargspec(func).annotations.items(),
)
),
}
for ((func_name, method), func) in self.endpoints.items()
}
return resp
@property
def daemon(self) -> bool:
return self._daemon
@daemon.setter
def daemon(self, daemon: bool):
"""Setting to True will not close the plugin after calling run(), default if False"""
self._daemon = daemon
class PluginError(Exception):
"""Generic class for plugin errors. This error is raised when a plugin raises an unexpected exception."""
......
......@@ -19,6 +19,8 @@ class TriggerPluginBase(PluginBase):
# Pass a closure to the fastapi route
self._webserver.app.add_api_route("/v1/item/trigger", self.do_trigger, methods=["POST"])
self.daemon = True
def do_trigger(self, req: TriggerReq):
"""Handle trigger request for given item. Item must be present already in the POD.
Operation is offloaded to a dedicated thread, the POD is notified about the status
......@@ -27,12 +29,12 @@ class TriggerPluginBase(PluginBase):
def thread_fn(req: TriggerReq):
try:
self.trigger(req)
self.client.send_trigger_status(req.item_id, req.trigger_id, "OK")
self.client.send_trigger_status(req.item_ids, req.trigger_id, "OK")
except Exception as e:
msg = f"Error while handling the trigger for item {req}, reason {e}"
msg = f"Error while handling the trigger {req.trigger_id}, reason {e}"
logger.error(msg)
self.client.send_trigger_status(req.item_id, req.trigger_id, msg)
self.client.send_trigger_status(req.item_ids, req.trigger_id, msg)
threading.Thread(target=thread_fn, args=(req,)).start()
......
......@@ -206,8 +206,8 @@ class PodAPI:
payload = {"to": to, "subject": subject, "body": body}
return self.post("send_email", payload)
def send_trigger_status(self, item_id: str, trigger_id: str, status: str) -> Any:
payload = {"item_id": item_id, "trigger_id": trigger_id, "status": status}
def send_trigger_status(self, item_ids: List[str], trigger_id: str, status: str) -> Any:
payload = {"item_ids": item_ids, "trigger_id": trigger_id, "status": status}
return self.post("trigger/status", payload)
def oauth2get_access_token(self, platform: str) -> Any:
......@@ -215,3 +215,18 @@ class PodAPI:
def plugin_status(self, plugins: List[str]) -> Any:
return self.post("plugin/status", {"plugins": plugins}).json()
def oauth1_request_token(self, platform: str, callback_url: str) -> Any:
return self.post(
"oauth1_request_token", {"service": platform, "callbackUrl": callback_url}
).json()
def oauth1_access_token(self, *, oauth_token, oauth_token_secret, oauth_verifier) -> Any:
return self.post(
"oauth1_access_token",
{
"oauthVerifier": oauth_verifier,
"oauthToken": oauth_token,
"oauthTokenSecret": oauth_token_secret,
},
).json()
......@@ -284,7 +284,7 @@ class PodClient:
delete_items_batch,
)
except PodError as e:
logger.error("could not complete bulk action, aborting")
logger.error(f"could not complete bulk action {e}, aborting")
return False
logger.info(f"Completed Bulk action, written {n} items/edges")
......@@ -328,7 +328,10 @@ class PodClient:
item = self._get_item_with_properties(id)
edges = self.get_edges(id)
for e in edges:
item.add_edge(e["name"], e["item"])
if e["name"] in item.edges:
item.add_edge(e["name"], e["item"])
else:
logger.debug(f"Could not add edge {e['name']}: Edge is not defined on Item.")
return item
def get_edges(self, id):
......@@ -591,3 +594,13 @@ class PodClient:
except PodError as e:
logger.error(e)
return None
def get_oauth1_request_token(self, platform, callback_url):
return self.api.oauth1_request_token(platform, callback_url)
def get_oauth1_access_token(self, oauth_token, oauth_verifier, oauth_token_secret):
return self.api.oauth1_access_token(
oauth_token=oauth_token,
oauth_verifier=oauth_verifier,
oauth_token_secret=oauth_token_secret,
)
from typing import List
from pydantic import BaseModel
class TriggerReq(BaseModel):
item_id: str
item_ids: List[str]
trigger_id: str
from inspect import getfullargspec
ENDPOINT_METADATA = "__endpoint_metadata__"
def register_endpoint(endpoint_name: str, method: str):
"""Registers `fn` under endpoint `endpoint_name` into the webserver as public API.
The `method` parameter specifies HTTP operation: GET, POST, DELETE, etc.
The `fn` needs to pass the validation, for example cannot contain *args, **kwargs,
all arguments needs to have type annotation."""
def _add_endpoint_metadata_attr(fn):
args = getfullargspec(fn)
# *args makes fastapi crash upon calling the handler
if args.varargs:
raise RuntimeError(f"passed function cannot have *args")
# Allowing **kwargs:
# - leaks python detail to the API
# - is treated as string in query parameter: "{"key": value, ...}", instead of {"key": value}
if args.varkw:
raise RuntimeError(f"passed function cannot have **kwargs")
# Requiring typing enforces api clarity
missing_annotations = []
for arg in args.args:
if arg != "self" and arg not in args.annotations:
missing_annotations.append(arg)
if missing_annotations:
raise RuntimeError(f"arguments: {missing_annotations} does not have type annotation")
metadata = (endpoint_name, method)
# Prevent nested decoration
if hasattr(fn, ENDPOINT_METADATA):
raise RuntimeError(
f"Trying to register function {fn.__name__} to {metadata} but is already used in: {fn.__endpoint_metadata__}"
)
fn.__endpoint_metadata__ = metadata
return fn
return _add_endpoint_metadata_attr
......@@ -31,14 +31,15 @@ class WebServer:
return self._app
def run(self):
"""Starts the webserver, only if any route was registered.
"""Starts the webserver, if not done already.
Call returns immediately, server itself is offloaded to
the thread."""
the thread.
"""
# Bare application has two endpoints registered:
# /openapi.json and /docs
DEFAULT_ROUTES = 2
if len(self.app.routes) > DEFAULT_ROUTES:
# Rest is registered by the plugin owner
if not self.is_running():
config = Config(app=self.app, host="0.0.0.0", port=self._port, workers=1)
self._uvicorn = Server(config=config)
......
......@@ -51,3 +51,4 @@ console_scripts =
simulate_enter_credentials = pymemri.plugin.authenticators.password:simulate_enter_credentials
plugin_from_template = pymemri.template.formatter:plugin_from_template
create_plugin_config = pymemri.template.config:create_plugin_config
simulate_oauth1_flow = pymemri.plugin.cli:simulate_oauth1_flow
......@@ -42,7 +42,7 @@ def test_run_from_id(client):
account = Account(identifier="login", secret="password")
run.add_edge("account", account)
assert client.create(run)
client.create(run)
assert client.create(account)
assert client.create_edge(run.get_edges("account")[0])
......
import typing
from datetime import datetime
from time import sleep
from typing import Optional
import pytest
import requests
from pydantic import BaseModel
from pymemri.data.schema import PluginRun
from pymemri.plugin.pluginbase import PluginBase
from pymemri.plugin.states import RUN_INITIALIZED
from pymemri.pod.client import PodClient
from pymemri.webserver.public_api import register_endpoint
class MockPlugin(PluginBase):
def __init__(self):
super().__init__(
client=PodClient(), pluginRun=PluginRun(containerImage="", webserverPort=8080)
)
self.base_endpoint = "http://127.0.0.1:8080"
def run():
pass
def test_endpoints_validation():
with pytest.raises(RuntimeError) as exinfo:
class Plugin(MockPlugin):
@register_endpoint(endpoint_name="/v1/api1", method="POST")
def api1(self, a: int, x, b: list):
pass
assert "arguments: ['x'] does not have type annotation" in str(exinfo.value)
with pytest.raises(RuntimeError) as exinfo:
class Plugin(MockPlugin):
@register_endpoint(endpoint_name="/v1/api1", method="POST")
def api1(self, *args):
pass
assert "passed function cannot have *args" in str(exinfo.value)
with pytest.raises(RuntimeError) as exinfo:
class Plugin(MockPlugin):
@register_endpoint(endpoint_name="/v1/api1", method="POST")
def api1(self, i: int, **kwargs):
pass
assert "passed function cannot have **kwargs" in str(exinfo.value)
def test_cannot_register_function_twice_under_the_same_endpoint_and_method():
class Plugin(MockPlugin):
@register_endpoint(endpoint_name="/v1/api1", method="POST")
def api1(self):
pass
@register_endpoint(endpoint_name="/v1/api1", method="POST")
def api2(self, x: int):
pass
with pytest.raises(RuntimeError) as exinfo:
p = Plugin()
assert "endpoint /v1/api1 with method POST is already registered" in str(exinfo.value)
# It is possible however to have the same endpoint, but different method
class Plugin(MockPlugin):
@register_endpoint(endpoint_name="/v1/api1", method="POST")
def api1(self):
pass
@register_endpoint(endpoint_name="/v1/api1", method="GET")
def api2(self, x: int):
pass
def test_cannot_do_nested_decoration():
with pytest.raises(RuntimeError) as exinfo:
class Plugin(MockPlugin):
@register_endpoint(endpoint_name="/v1/api1", method="POST")
@register_endpoint(endpoint_name="/v1/api1", method="GET")
def api1(self):
pass
assert (
"register function api1 to ('/v1/api1', 'POST') but is already used in: ('/v1/api1', 'GET')"
in str(exinfo.value)
)
def test_can_register_different_endpoints():
# Declaring a class shall not raise an error
class Plugin(MockPlugin):
@register_endpoint("/api1", method="POST")
def function_with_no_args(self):
pass
@register_endpoint("/api2", method="POST")
def function_with_typed_args(self, a: int, b: str, c: float):
pass
@register_endpoint("/api3", method="POST")
def function_with_args_and_return_type(self, a: int, b: str, c: float) -> str:
pass
@register_endpoint("/api4", method="POST")
def function_with_return_type(self) -> str:
pass
@register_endpoint("/api5", method="POST")
def function_with_defaults(self, a: str, b: datetime = None):
pass
@register_endpoint("/api6", method="POST")
def function_with_complex_types(
self, l: typing.List[typing.Optional[int]], t: tuple, s: typing.AnyStr
):
pass
@register_endpoint("/api7", method="POST")
def function_with_boolean(self, flag: bool):
pass
# Instantiating shall not raise either
_plugin = Plugin()
# Integration tests
def _wait_for_webapi(endpoint, attempts=30):
while attempts > 0:
try:
requests.get(f"{endpoint}/v1/api")
return
except Exception as e:
print(f"waiting for server {e}")
attempts -= 1
sleep(1)
raise RuntimeError(f"Cannot connect to plugin's {endpoint} webapi")
def test_webserver_reports_base_api():
# Base class defines v1/health, and v1/api endpoints
class Plugin(MockPlugin):
pass
p = Plugin()
p.setup()
_wait_for_webapi(p.base_endpoint)
resp = requests.get(f"{p.base_endpoint}/v1/api")
assert resp.status_code == 200
assert {
"/v1/api": {"method": "GET", "args": {}},
"/v1/health": {"method": "GET", "args": {}},
} == resp.json()
resp = requests.get(f"{p.base_endpoint}/v1/health")
assert resp.status_code == 200
assert resp.json() == RUN_INITIALIZED
# Shuting down takes some time due to the listeners
p.teardown()
def test_different_types_in_endpoint():
class ReqModel(BaseModel):
x: int
# that allows to set bool to null, or discard from the request body completely
y: Optional[bool]
# pydantic will try to convert to datetime:
# int or float, assumed as Unix time,
# str, following formats work:
# YYYY-MM-DD[T]HH:MM[:SS[.ffffff]][Z or [±]HH[:]MM]]]
# int or float as a string (assumed as Unix time)
d: datetime
class Plugin(MockPlugin):
@register_endpoint(endpoint_name="/v1/echo", method="POST")
def echo(
self,
a: typing.List[typing.AnyStr],
b: dict,
x: int,
req: ReqModel,
t: tuple = (1, 2, 3),
):
return {"a": a, "b": b, "x": x, "req": req, "t": t}
p = Plugin()
p.setup()
_wait_for_webapi(p.base_endpoint)
# Valid call
req = {
"a": ["string", "string"],
"b": {},
"req": {"x": 0, "y": True, "d": "2022-10-24T19:38:09.718Z"},
"t": [1, 2, 3],
}
resp = requests.post(f"{p.base_endpoint}/v1/echo?x=13", json=req)
assert resp.status_code == 200
assert {
"a": ["string", "string"],
"b": {},
"req": {"x": 0, "y": True, "d": "2022-10-24T19:38:09.718000+00:00"},
"t": [1, 2, 3],
"x": 13,
} == resp.json()
# Valid call, removed optional 'y' field
req = {
"a": ["string", "string"],
"b": {},
"req": {
"x": 0,
# "y": True,
"d": "2022-10-24T19:38:09.718Z",
},
"t": [1, 2, 3],
}
resp = requests.post(f"{p.base_endpoint}/v1/echo?x=13", json=req)
assert resp.status_code == 200
assert {
"a": ["string", "string"],
"b": {},
"req": {"x": 0, "y": None, "d": "2022-10-24T19:38:09.718000+00:00"},
"t": [1, 2, 3],
"x": 13,
} == resp.json()
# Invalid method 'get'
resp = requests.get(f"{p.base_endpoint}/v1/echo")
assert resp.status_code == 405
# Invalid request, missing required 'b'
req = {
"a": ["string", "string"],
"req": {"x": 0, "y": True, "d": "2022-10-24T19:38:09.718Z"},
"t": [1, 2, 3],
}
resp = requests.post(f"{p.base_endpoint}/v1/echo?x=13", json=req)
assert resp.status_code == 422
assert "field required" in resp.json()["detail"][0]["msg"]
# Invalid request, violating schema, by introducing non-string element in the List[AnyStr]
req = {
"a": ["string", "string", None],
"b": {},
"req": {"x": 0, "y": True, "d": "2022-10-24T19:38:09.718Z"},
"t": [1, 2, 3],
}
resp = requests.post(f"{p.base_endpoint}/v1/echo?x=13", json=req)
assert resp.status_code == 422
assert "none is not an allowed value" in resp.json()["detail"][0]["msg"]
# Shuting down takes some time due to the listeners
p.teardown()
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment