%% Cell type:markdown id: tags:
# Load a dataset and train your model
Once you have labeled your data on the Memri platform, you can use it to train your model in [this Google Colab notebook](https://colab.research.google.com/drive/189JJ2gLHAtxlmzc5XI3HhB9_VE3fT6DT)*.
In this guide you will:
1. Load a labeled dataset from the POD
2. Train a distilRoBERTa text classifier model on a labeled dataset
3. Upload a trained model to use in a plugin for a data app
> * If you are unfamiliar with Google Colab notebooks, have a look at [this quick intro.](https://colab.research.google.com/)
* Make sure to run the cells below one by one, in the correct order, to avoid errors!
* In this guide we help you connect your own personal data from your Memri POD; alternatively, you can use the [Tweet eval emoji](https://huggingface.co/datasets/tweet_eval#source-data) dataset, which is available from 🤗 [Hugging Face](https://huggingface.co/docs/datasets/index.html).
* If you don't wish to use your personal data, or you don't want to spend time training a model, you can simply use our [sentiment-plugin](https://gitlab.memri.io/koenvanderveen/sentiment-plugin/-/packages/6), which uses a pre-trained model from 🤗 Hugging Face. Just paste the plugin repo address at the project set-up step on the Memri platform, and skip this process.
%% Cell type:markdown id: tags:
## Setup
%% Cell type:code id: tags:
```
from IPython.display import clear_output
!pip install pandas transformers torch git+https://gitlab.memri.io/memri/pymemri.git@v0.0.29
clear_output()
print("Installed")
```
%% Cell type:markdown id: tags:
1. Import the libraries needed to train your model
> * Make sure to run the installation step above first to avoid errors!
%% Cell type:code id: tags:
```
import os
import random
import textwrap
import pandas as pd
import torch
import transformers
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers.utils import logging
from pymemri.data.itembase import Edge, Item
from pymemri.data.schema import Dataset, Message, CategoricalLabel
from pymemri.data.oauth import OauthFlow
from pymemri.data.loader import write_model_to_package_registry
from pymemri.pod.client import PodClient
from getpass import getpass
transformers.utils.logging.set_verbosity_error()
os.environ["WANDB_DISABLED"] = "true"
```
%% Cell type:markdown id: tags:
## 1. Load your dataset from the POD
%% Cell type:markdown id: tags:
1. Run the cell
2. Copy your Dataset Name, Login Key and Password Key from your app.memri.io screen, and paste them below as prompted to load your connected dataset from your POD.
%% Cell type:code id: tags:
```
### *Define your pod url here*, this is the one for dev.app.memri.io ###
pod_url = "https://dev.pod.memri.io"
### *Define your dataset here* ###
dataset_name = input("dataset_name:") if "dataset_name" not in locals() else dataset_name
### *Define your login key here* ###
owner_key = getpass("owner key:") if "owner_key" not in locals() else owner_key
### *Define your password key here* ###
database_key = getpass("database_key:") if "database_key" not in locals() else database_key
```
%% Cell type:markdown id: tags:
2. Connect to your POD to load your data
%% Cell type:code id: tags:
```
# Connect to pod
client = PodClient(
    url=pod_url,
    owner_key=owner_key,
    database_key=database_key,
)
client.add_to_schema(CategoricalLabel, Message, Dataset, OauthFlow);
```
%% Cell type:markdown id: tags:
3. Download and inspect the dataset
> * All entries in the dataset can be found via the Dataset.entry edge
%% Cell type:code id: tags:
```
dataset = client.get_dataset(dataset_name)
num_entries = len(dataset.entry)
print(f"number of items in the dataset: {num_entries}")
```
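%% Cell type:markdown id: tags:
> * Optionally, you can peek at the first entry through the same `Dataset.entry` edge. The cell below is a minimal sketch; the exact item type and fields you see depend on how your dataset was labeled.
%% Cell type:code id: tags:
```
# Optional sanity check: look at the first entry of the dataset.
# The fields shown depend on your dataset's schema.
if num_entries > 0:
    first_entry = dataset.entry[0]
    print(first_entry)
```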
%% Cell type:markdown id: tags:
4. Export the dataset to a format compatible with Python and inspect it in a table
> * In pymemri, the `Dataset` class can format your dataset to different datatypes using the `Dataset.to` method; here we will use Pandas.
> * The columns of the dataset are inferred automatically. If you want to use custom columns, you can use the `columns` argument. See the [dataset documentation](https://docs.memri.io/component-architectures/plugins/datasets/) for more info.
%% Cell type:code id: tags:
```
data = dataset.to("pandas")
data.head()
```
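%% Cell type:markdown id: tags:
> * As an illustration of the `columns` argument, the cell below is a minimal sketch. The column names are assumptions and should match the columns of your own dataset; the label column used later in this notebook is `annotation.labelValue`.
%% Cell type:code id: tags:
```
# Minimal sketch of exporting with custom columns.
# The column names below are assumptions; replace them with the columns of your own dataset.
custom_columns = ["annotation.labelValue"]
data_subset = dataset.to("pandas", columns=custom_columns)
data_subset.head()
```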
%% Cell type:markdown id: tags:
## 2. Fine-tune a model
%% Cell type:markdown id: tags:
1. Configure the distilRoBERTa model on your dataset
> The 🤗 Transformers library contains all the code needed for training; you only need to define a torch `Dataset` that holds your data and handles tokenization.
%% Cell type:code id: tags:
```
# Hyperparameters
model_name = "distilroberta-base"
batch_size = 32
learning_rate = 1e-3


class TransformerDataset(torch.utils.data.Dataset):
    def __init__(self, data: pd.DataFrame, tokenizer: transformers.PreTrainedTokenizerBase):
        self.data = data
        self.label2idx, self.idx2label = self.get_label_map()
        self.num_labels = len(self.label2idx)
        self.tokenizer = tokenizer

    def tokenize(self, message, label=None):
        tokenized = self.tokenizer(message, padding="max_length", truncation=True)
        if label is not None:
            tokenized["label"] = self.label2idx[label]
        return tokenized

    def get_label_map(self):
        # Map each unique label to an integer index, and keep the reverse mapping.
        unique_labels = self.data["annotation.labelValue"].unique()
        return {l: i for i, l in enumerate(unique_labels)}, {i: l for i, l in enumerate(unique_labels)}

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Get the row from self.data, and skip the first column (id).
        return self.tokenize(*self.data.iloc[idx][1:])


tokenizer = AutoTokenizer.from_pretrained(model_name)
dataset = TransformerDataset(data, tokenizer)
```
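%% Cell type:markdown id: tags:
> * Optionally, you can sanity-check the dataset wrapper before training. The cell below is a minimal sketch that prints the number of labels and the keys of one tokenized item.
%% Cell type:code id: tags:
```
# Optional sanity check: inspect the number of labels and one tokenized example.
print("number of labels:", dataset.num_labels)
example = dataset[0]
print("keys of a tokenized item:", list(example.keys()))
```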
%% Cell type:markdown id: tags:
2. Train and fine-tune the model
> * The 🤗 Transformers library provides all the code needed to train a RoBERTa model. Read their [tutorial on fine-tuning models](https://huggingface.co/docs/transformers/training)
* We use the `Trainer` class, as it handles all training, monitoring, and integration with [Weights & Biases](https://wandb.ai/site)
%% Cell type:code id: tags:
```
# Load model
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=dataset.num_labels,
    id2label=dataset.idx2label
)

# To increase training speed, we will freeze all layers except the classifier head.
for param in model.base_model.parameters():
    param.requires_grad = False

training_args = transformers.TrainingArguments(
    "twitter-emoji-trainer",
    learning_rate=learning_rate,
    per_device_train_batch_size=batch_size,
    logging_steps=1,
    optim="adamw_torch"
)

trainer = transformers.Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset
)

logging.set_verbosity(40)
trainer.train()
```
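%% Cell type:markdown id: tags:
> * Before uploading, you can optionally try the fine-tuned model on a piece of text. The cell below is a minimal sketch; the example message is made up, so replace it with text that resembles your own data.
%% Cell type:code id: tags:
```
# Optional: run the fine-tuned model on a single example message.
# The example text below is only an illustration; replace it with your own.
example_message = "Had a great day at the beach!"

inputs = tokenizer(example_message, return_tensors="pt", truncation=True)
inputs = {k: v.to(model.device) for k, v in inputs.items()}

with torch.no_grad():
    logits = model(**inputs).logits

predicted_idx = logits.argmax(dim=-1).item()
print("predicted label:", dataset.idx2label[predicted_idx])
```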
%% Cell type:markdown id: tags:
## 3. Upload your model to a data app plugin
%% Cell type:markdown id: tags:
Now that your model is trained, it will be uploaded to your new GitLab project.
1. Run the cell
2. Copy and paste the GitLab project name from your screen on app.memri.io
> * To avoid errors, make sure your GitLab project does not have any full stops in the name/URL
%% Cell type:code id: tags:
```
project_name = input("project name:") if "project_name" not in locals() else project_name
write_model_to_package_registry(model, project_name=project_name, client=client)
```
%% Cell type:markdown id: tags:
That's it! 🎉
You have trained an ML model and made it accessible via the package registry, ready to be used in your data app.
Check out the next step to see how to [build a plugin and deploy a data app](https://docs.memri.io/tutorials/build_a_sentiment_analysis_app/#deploy-your-data-app/).