Commit a65c98cf authored by Szymon Zimnowoda's avatar Szymon Zimnowoda
Browse files

Sunset Twitter

parent 1136a49f
Showing with 103 additions and 9446 deletions
+103 -9446
......@@ -578,7 +578,7 @@ POD is responsible for managing oauth2 tokens for platforms (like gitlab.memri.i
"payload": {
"scopes": list, platform specific,
"redirectUri": URL to which browser will redirect after user action, string,
"platform": currently "gitlab" and "twitter" only supported
"platform": currently "gitlab" only supported
}
}
```
......
......@@ -210,7 +210,6 @@ pub struct OauthAccessTokenPayload {
#[serde(rename_all = "camelCase")]
pub enum Platforms {
Gitlab,
Twitter,
}
#[derive(Deserialize, Debug)]
......
......@@ -11,7 +11,6 @@ pub const WEBSERVER_URL: &str = "webserverUrl";
pub const CONTAINER_ID: &str = "containerId";
pub const PLUGIN_ALIAS: &str = "pluginAlias";
pub const CONTAINER_IMAGE: &str = "containerImage";
pub const UNIQUE_ID: &str = "uniqueId";
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
pub struct PluginRunItem {
......@@ -27,9 +26,6 @@ pub struct PluginRunItem {
pub status: Option<String>,
/// Human friendly name for the plugin, used while interacting with it's API
pub plugin_alias: Option<String>,
/// Only one plugin with given id can be present at a time, id might be
/// for example twitter account number.
pub unique_id: Option<String>,
/// Name of the class that runs as a plugin
pub plugin_name: String,
/// Choose where to run the model, on CPU, or GPU
......
......@@ -11,7 +11,7 @@ use crate::{
graphql, internal_error,
plugin_auth_crypto::DatabaseKey,
triggers,
triggers::{add_extra_fields, add_extra_fields_for_twitter},
triggers::add_extra_fields,
v5::{
self,
api_model::{
......@@ -100,7 +100,6 @@ pub async fn create_item_tx(
}
add_extra_fields(&mut item, cli, pod_owner)?;
add_extra_fields_for_twitter(tx, &mut item).await?;
let mut res = v5::internal_api::create_item_tx(
tx,
......
......@@ -215,30 +215,15 @@ struct PlatformConfiguration {
}
lazy_static! {
static ref PLATFORM_CONFIG: HashMap<Platforms, PlatformConfiguration> = HashMap::from([
(
Platforms::Gitlab,
PlatformConfiguration {
client_id_env: "GITLAB_CLIENT_ID",
client_secret_env: "GITLAB_CLIENT_SECRET",
auth_url: AuthUrl::new("https://gitlab.memri.io/oauth/authorize".to_string())
.unwrap(),
token_url: TokenUrl::new("https://gitlab.memri.io/oauth/token".to_string())
.unwrap(),
},
),
(
Platforms::Twitter,
PlatformConfiguration {
client_id_env: "TWITTER_V2_CLIENT_ID",
client_secret_env: "TWITTER_V2_CLIENT_SECRET",
auth_url: AuthUrl::new("https://twitter.com/i/oauth2/authorize".to_string())
.unwrap(),
token_url: TokenUrl::new("https://api.twitter.com/2/oauth2/token".to_string())
.unwrap(),
},
)
]);
static ref PLATFORM_CONFIG: HashMap<Platforms, PlatformConfiguration> = HashMap::from([(
Platforms::Gitlab,
PlatformConfiguration {
client_id_env: "GITLAB_CLIENT_ID",
client_secret_env: "GITLAB_CLIENT_SECRET",
auth_url: AuthUrl::new("https://gitlab.memri.io/oauth/authorize".to_string()).unwrap(),
token_url: TokenUrl::new("https://gitlab.memri.io/oauth/token".to_string()).unwrap(),
},
),]);
}
async fn exchange_refresh_token(oauth2_item: &Oauth2Flow) -> Result<BasicTokenResponse> {
......
......@@ -25,7 +25,7 @@ use crate::v5::api_model::types::ItemId;
use md5;
use serde::Deserialize;
use serde_json::{json, Value};
use std::{collections::HashMap, ffi::OsStr, fmt::Debug, sync::atomic::AtomicU32, time::Duration};
use std::{collections::HashMap, fmt::Debug, sync::atomic::AtomicU32, time::Duration};
use tokio::fs::File;
use tokio::io::AsyncWriteExt;
use tokio::process::{Child, Command};
......@@ -33,7 +33,6 @@ use tracing::{debug, info, warn};
use duct;
const UNIQUE_ID: &str = "unique_id";
// TODO: this is barely maintainable, split to strategy pattern: Docker, K8S, sharing probably one trait
/// Run a plugin, making sure that the correct ENV variables and settings are passed
......@@ -95,10 +94,6 @@ pub async fn run_plugin_container(
.await?;
}
if let Err(e) = k8s_enforce_unique_instance(plugin).await {
warn!("Enforcing K8s container unique instance failed because: {e:?}");
}
run_kubernetes_container(
&target_item_json,
pod_owner,
......@@ -109,9 +104,6 @@ pub async fn run_plugin_container(
)
.await
} else {
if let Err(e) = docker_enforce_unique_instance(plugin).await {
warn!("Enforcing docker container unique instance failed because: {e:?}");
}
run_docker_container(
&target_item_json,
pod_owner,
......@@ -211,10 +203,6 @@ async fn run_docker_container(
args.extend_from_slice(&os_specific_args);
if let Some(ref unique_id) = plugin.unique_id {
args.extend_from_slice(&["--label".to_string(), format!("{UNIQUE_ID}={unique_id}")]);
}
if plugin.execute_on == PluginExecutionPlatform::Gpu {
args.extend_from_slice(&["--gpus=all".to_string()]);
}
......@@ -225,9 +213,8 @@ async fn run_docker_container(
run_any_command("docker", &args, &envs, triggered_by_item_id).await
}
static IMPORTERS_PLUGINS: [&str; 3] = [
static IMPORTERS_PLUGINS: [&str; 2] = [
"gitlab.memri.io:5050/memri/plugins/whatsapp-multi-device",
"gitlab.memri.io:5050/memri/plugins/twitter",
"registry.digitalocean.com/polis/memri-plugins",
];
......@@ -455,7 +442,7 @@ async fn kubernetes_run_vanilla_plugin(
}
});
let mut labels = format!(
let labels = format!(
"app={},type=plugin,owner={:x}",
plugin.container_id,
// Owner key exceeds a limit of number of characters that k8s label can keep.
......@@ -463,10 +450,6 @@ async fn kubernetes_run_vanilla_plugin(
md5::compute(pod_owner)
);
if let Some(ref unique_id) = plugin.unique_id {
labels = format!("{labels},{UNIQUE_ID}={unique_id}");
}
let args: Vec<String> = vec![
"run".to_string(),
plugin.container_id.clone(),
......@@ -840,98 +823,6 @@ async fn check_kubernetes_limits(
}
}
/// Look for containers for given unique_id, kill them if any exist.
/// Currently it's used for mitigating rate limiting twitter api has.
async fn k8s_enforce_unique_instance(plugin: &PluginRunItem) -> Result<()> {
let Some(ref unique_id) = plugin.unique_id else {
return Ok(());
};
info!(
"Going to cleanup k8s containers with unique_id {}...",
unique_id
);
let args = vec![
"delete".to_string(),
"pod".to_string(),
"-l".to_string(),
format!("{UNIQUE_ID}={unique_id}"),
"--ignore-not-found=true".to_string(),
"--now".to_string(),
];
let std_out = run_command("kubectl", args).await?;
debug!("Cleanup result: {std_out}");
Ok(())
}
/// Look for containers for given unique_id, kill them if any exist.
/// Currently it's used for mitigating rate limiting twitter api has.
async fn docker_enforce_unique_instance(plugin: &PluginRunItem) -> Result<()> {
let Some(ref unique_id) = plugin.unique_id else {
return Ok(());
};
info!(
"Going to cleanup docker containers with unique_id {}...",
unique_id
);
// Get containers with given label
let args = vec![
"ps".to_string(),
"--filter".to_string(),
format!("label={UNIQUE_ID}={unique_id}"),
"-q".to_string(),
];
let std_out = run_command("docker", args).await?;
let running_containers: Vec<&str> = std_out.split_terminator('\n').collect();
if !running_containers.is_empty() {
debug!("Running containers with unique_id {unique_id}: {running_containers:?}");
// Stop containers with given label
let args = ["container", "stop", "-t", "1"]
.iter()
.chain(running_containers.iter());
let _ = run_command("docker", args).await?;
}
Ok(())
}
/// Runs command, waits for finish, returns Ok(std_out) or Err(std_err)
async fn run_command<I, S>(command: &str, args: I) -> Result<String>
where
I: IntoIterator<Item = S> + Debug,
S: AsRef<OsStr>,
{
let cmd_with_args = format!("{command} {args:?}");
info!("Running command {cmd_with_args}");
let output = Command::new(command)
.args(args)
.output()
.await
.context_str("failed to start command")?;
if output.status.success() {
let std_out =
std::str::from_utf8(&output.stdout).context_str("failed to retrieve command stdout")?;
Ok(std_out.to_string())
} else {
let std_err =
std::str::from_utf8(&output.stderr).context_str("failed to retrieve command stderr")?;
warn!("Command {cmd_with_args} stderr {std_err}");
Err(internal_error! { "command failed, {std_err}" })
}
}
#[derive(Deserialize, Debug)]
struct K8SPodsInfo {
items: Vec<Item>,
......@@ -1270,7 +1161,6 @@ pub async fn plugin_attach(
container_id: "no-container-id".to_string(),
status: None,
plugin_alias: alias,
unique_id: None,
plugin_name: "no-plugin-name".to_string(),
execute_on: Default::default(),
};
......
......@@ -5,15 +5,13 @@
use std::{collections::HashSet, fmt::Debug};
use crate::{
any_error,
api_model::{CreateItem, Oauth2AccessTokenRequest},
api_model::CreateItem,
async_db_connection::{AsyncConnection, AsyncTx},
bad_request,
command_line_interface::CliOptions,
db_model::{PluginRunItem, CONTAINER_ID, UNIQUE_ID, WEBSERVER_PORT, WEBSERVER_URL},
db_model::{PluginRunItem, CONTAINER_ID, WEBSERVER_PORT, WEBSERVER_URL},
error::{ErrorContext, Result},
internal_api::{self, get_item_tx},
oauth2_api,
plugin_auth_crypto::DatabaseKey,
plugin_run, plugin_trigger,
schema::{SchemaEdge, SchemaItem, ITEM_EDGE_SCHEMA, ITEM_PROPERTY_SCHEMA},
......@@ -21,10 +19,9 @@ use crate::{
};
use crate::async_db_connection::AsyncTx as Tx;
use md5;
use serde::Deserialize;
use serde_json::{json, Value};
use tracing::{debug, error};
use tracing::error;
#[derive(Debug, Eq, PartialEq)]
pub enum SchemaAdditionChange {
......@@ -92,73 +89,6 @@ pub fn add_extra_fields(item: &mut CreateItem, cli: &CliOptions, pod_owner: &str
Ok(())
}
pub async fn add_extra_fields_for_twitter(tx: &Tx, item: &mut CreateItem) -> Result<()> {
debug!("add_extra_fields_for_twitter");
if item._type == "PluginRun" {
debug!("item is PluginRun");
if let Some(plugin_name) = item
.fields
.get("pluginName")
.cloned()
.map(|name| serde_json::from_value::<String>(name).unwrap())
{
debug!("Plugin name is {plugin_name}");
if plugin_name == "TwitterPlugin" {
let payload = Oauth2AccessTokenRequest {
platform: crate::api_model::Platforms::Twitter,
};
let resp = oauth2_api::access_token(tx, payload)
.await
.context_str("while retrieving access token for twitter hack")?;
let twitter_user_id = {
let client = reqwest::Client::new();
let res = client
.get("https://api.twitter.com/2/users/me")
.bearer_auth(resp.access_token)
.send()
.await
.context_str("Failed to verify twitter credentials")?;
debug!("Got twitter response {res:#?}");
if !res.status().is_success() {
return Err(any_error! {
res.status(),
"Failure getting the credentials"
});
}
#[derive(Deserialize, Debug)]
struct TwitterUserResponse {
data: TwitterUserData,
}
#[derive(Deserialize, Debug)]
struct TwitterUserData {
id: String,
}
let res: TwitterUserResponse = res.json().await?;
res.data.id
};
debug!("TT user id: {twitter_user_id}");
let twitter_user_id_md5 = format!("{:x}", md5::compute(twitter_user_id));
debug!("Storing {UNIQUE_ID}={twitter_user_id_md5} md5 hex into PluginRun");
item.fields
.insert(UNIQUE_ID.to_string(), json!(twitter_user_id_md5));
}
}
}
Ok(())
}
#[derive(Default)]
pub struct AfterAction {
pub items_to_delete: HashSet<String>,
......
# from pymemri.data.schema import Tweet, Account
from twitter_v2.schema import Tweet, Account
from pymemri.data.schema.itembase import ItemBase
from pathlib import Path
import json
import random
SEED = None
NUM_TWEETS = 812
NUM_ACCOUNTS = 20
TWEET_FILE = Path(__file__).parent.joinpath("tweets.json")
class TweetGen:
def __init__(self, num_tweets=NUM_TWEETS, num_accounts=NUM_ACCOUNTS, seed=SEED):
random.seed(SEED)
self._tweets = []
self._accounts = []
self._edges = []
self.create_accounts(num_accounts)
self.read_tweets_from_json(num_tweets)
self.connect_tweets_to_authors()
self.create_conversations()
self.dump_bulk_request()
def read_tweets_from_json(self, num_tweets):
with open(TWEET_FILE, "r") as f:
tweets_json = json.load(f)
for id, json_tweet in enumerate(tweets_json):
if id >= num_tweets:
break
json_tweet.update({"externalId": str(id), "conversationId": str(id)})
m_tweet = Tweet.from_json(json_tweet)
self._tweets.append(m_tweet)
def create_accounts(self, num_accounts):
for id in range(num_accounts):
account = Account.from_json({
"id": ItemBase.create_id(),
"externalId": id,
"service": "twitter",
"displayName": f"account{id}",
"handle": f"username{id}",
"description": f"description of account {id}"
})
self._accounts.append(account)
def connect_tweets_to_authors(self):
nr_accounts = len(self._accounts)
indices = [0]
random_idx = random.sample(range(1, len(self._tweets)), nr_accounts - 1)
random_idx.sort()
indices.extend(random_idx)
indices.append(len(self._tweets))
print(f"indices {indices}")
for idx, acc in enumerate(self._accounts):
start = indices[idx]
stop = indices[idx+1]
print(f"tweets from {start} to {stop} belongs to {acc.externalId}")
tweets = self._tweets[start:stop]
for tweet in tweets:
self._edges.append(tweet.add_edge("author", acc))
def create_conversations(self):
conversations = set()
for tweet in self._tweets:
if tweet.conversationId in conversations:
continue
responses = random.randint(0, 10)
replies = random.sample(self._tweets, responses)
for reply in replies:
if reply.conversationId not in conversations:
# creating conversation, mark conversationId as used
conversations.add(tweet.conversationId)
# reply is not yet part of other conversation
reply.conversationId = tweet.conversationId
edge = tweet.add_edge("replies", reply)
self._edges.append(edge)
def get_create_edge_dict(edge):
return {
"_source": edge.source.id,
"_target": edge.target.id,
"_name": edge.name,
}
def dump_bulk_request(self):
items = list(map(lambda acc: acc.to_json(), self._accounts))
items.extend(map(lambda tt: tt.to_json(), self._tweets))
edges = list(map(TweetGen.get_create_edge_dict, self._edges ))
print(f""" "createItems": {json.dumps(items)},""")
print(f""" "createEdges": {json.dumps(edges)}""")
TweetGen()
\ No newline at end of file
This diff is collapsed.
......@@ -89,109 +89,111 @@ class FrontEnd(FastHttpUser):
def start_fake_importer(self):
# NOTE, we want to always start new importers?
pass
# res = self.client.post("/v4/"+self.pod_owner+"/bulk", json={
# "auth": CLIENT_AUTH,
# "payload": {
# "createItems": [
# {
# "type": "PluginRun",
# "id": FAKE_IMPORTER_ID,
# # no retry, enable tcp ka every 60 second
# "containerImage": "gitlab.memri.io:5050/memri/plugins/twitter-v2:fake-tweets-24259636",
# "pluginModule": "twitter_v2",
# "pluginName": "TwitterPlugin",
# "status": "idle",
# "targetItemId": FAKE_IMPORTER_ID
# }
# ]
# }
# }, name=f"{self.pod_owner} {self.__class__.__name__}.start_fake_importer")
res = self.client.post("/v4/"+self.pod_owner+"/bulk", json={
"auth": CLIENT_AUTH,
"payload": {
"createItems": [
{
"type": "PluginRun",
"id": FAKE_IMPORTER_ID,
# no retry, enable tcp ka every 60 second
"containerImage": "gitlab.memri.io:5050/memri/plugins/twitter-v2:fake-tweets-24259636",
"pluginModule": "twitter_v2",
"pluginName": "TwitterPlugin",
"status": "idle",
"targetItemId": FAKE_IMPORTER_ID
}
]
}
}, name=f"{self.pod_owner} {self.__class__.__name__}.start_fake_importer")
# res.raise_for_status()
# # res.raise_for_status()
res = self.client.post("/v4/"+self.pod_owner+"/get_item",
json={
"auth": CLIENT_AUTH,
"payload": FAKE_IMPORTER_ID
}, name=f"{self.pod_owner} {self.__class__.__name__}.get_importer_progress")
# res = self.client.post("/v4/"+self.pod_owner+"/get_item",
# json={
# "auth": CLIENT_AUTH,
# "payload": FAKE_IMPORTER_ID
# }, name=f"{self.pod_owner} {self.__class__.__name__}.get_importer_progress")
progress = res.json()[0].get('progress', 0.0)
while progress < 1.0:
print(f"Importer progress {progress}")
# progress = res.json()[0].get('progress', 0.0)
# while progress < 1.0:
# print(f"Importer progress {progress}")
sleep(1)
res = self.client.post("/v4/"+self.pod_owner+"/get_item",
json={
"auth": CLIENT_AUTH,
"payload": FAKE_IMPORTER_ID
}, name=f"{self.pod_owner} {self.__class__.__name__}.get_importer_progress")
# sleep(1)
# res = self.client.post("/v4/"+self.pod_owner+"/get_item",
# json={
# "auth": CLIENT_AUTH,
# "payload": FAKE_IMPORTER_ID
# }, name=f"{self.pod_owner} {self.__class__.__name__}.get_importer_progress")
progress = res.json()[0].get('progress', 0.0)
# progress = res.json()[0].get('progress', 0.0)
# TODO: add simulation that user leaves, and enters again
def call_classifier(self):
pass
# return
TT_CLASSIFIER = "4c0e3573-98d5-4250-9190-149cd2b27cda"
# TT_CLASSIFIER = "4c0e3573-98d5-4250-9190-149cd2b27cda"
# # TODO: only for local tests, on infra TT classifier is a shared plugin
# # res = self.client.post("/v4/"+self.pod_owner+"/plugin/attach", json={
# # "auth": CLIENT_AUTH,
# # "payload": {
# # "id": TT_CLASSIFIER,
# # "webserverPort": 3000,
# # "webserverUrl": "http://localhost"
# # }
# # }, name=f"{self.pod_owner} {self.__class__.__name__}.call_classifier")
# # res.raise_for_status()
# TODO: only for local tests, on infra TT classifier is a shared plugin
# res = self.client.post("/v4/"+self.pod_owner+"/plugin/attach", json={
# print(f"CALLING classifier")
# res = self.client.post("/v4/"+self.pod_owner+"/plugin/api/call", json={
# "auth": CLIENT_AUTH,
# "payload": {
# "id": TT_CLASSIFIER,
# "webserverPort": 3000,
# "webserverUrl": "http://localhost"
# "method": "POST",
# "endpoint": "/v1/create_topic_model",
# "query": {
# "max_num_tweets": 500,
# "prediction_threshold": 0.7
# },
# "jsonBody": {
# "database_key": key,
# "owner_key": self.pod_owner,
# # "url": "https://65c7-37-30-108-214.eu.ngrok.io"
# }
# }
# }, name=f"{self.pod_owner} {self.__class__.__name__}.call_classifier")
# res.raise_for_status()
print(f"CALLING classifier")
res = self.client.post("/v4/"+self.pod_owner+"/plugin/api/call", json={
"auth": CLIENT_AUTH,
"payload": {
"id": TT_CLASSIFIER,
"method": "POST",
"endpoint": "/v1/create_topic_model",
"query": {
"max_num_tweets": 500,
"prediction_threshold": 0.7
},
"jsonBody": {
"database_key": key,
"owner_key": self.pod_owner,
# "url": "https://65c7-37-30-108-214.eu.ngrok.io"
}
}
}, name=f"{self.pod_owner} {self.__class__.__name__}.call_classifier")
# res.raise_for_status()
print(f"Call classifier result {res.json()}")
# # res.raise_for_status()
# print(f"Call classifier result {res.json()}")
@task(1)
def user_clicks_on_the_tweet(self):
start = timer()
res = self.client.post("/v4/"+self.pod_owner+"/plugin/api/call", json={
"auth": CLIENT_AUTH,
"payload": {
"id": FAKE_IMPORTER_ID,
"method": "POST",
"endpoint": "/v1/import_conversation",
"query": {
"tweet_item_id": 123,
},
"jsonBody": {
}
}
}, name=f"{self.pod_owner} {self.__class__.__name__}.user_clicks_on_the_tweet")
end = timer()
print(f"CALLING importer took {end - start}")
# res.raise_for_status()
pass
# start = timer()
# res = self.client.post("/v4/"+self.pod_owner+"/plugin/api/call", json={
# "auth": CLIENT_AUTH,
# "payload": {
# "id": FAKE_IMPORTER_ID,
# "method": "POST",
# "endpoint": "/v1/import_conversation",
# "query": {
# "tweet_item_id": 123,
# },
# "jsonBody": {
# }
# }
# }, name=f"{self.pod_owner} {self.__class__.__name__}.user_clicks_on_the_tweet")
# end = timer()
# print(f"CALLING importer took {end - start}")
# # res.raise_for_status()
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment