Weak supervision¶
This guide gives you a brief introduction to weak supervision with Rubrix.
Rubrix currently supports weak supervision for multi-class text classification use cases, but we’ll be adding support for multilabel text classification and token classification (e.g., Named Entity Recognition) soon.
Rubrix weak supervision in a nutshell¶
The recommended workflow for weak supervision is:
1. Log an unlabelled dataset into Rubrix.
2. Use the Annotate mode for hand- and/or bulk-labelling a test set. This test set is key to measuring the quality and performance of your rules.
3. Use the Define Rules mode for testing and defining rules. Rules are defined with search queries (using the Elasticsearch query string DSL).
4. Use the Python client for reading rules, defining additional rules if needed, and training a label model (for building a training set) or a downstream model (for building an end classifier).
The next sections cover the main components of this workflow. If you want to jump into a practical tutorial, check the news classification tutorial.
Weak labeling using the UI¶
Since version 0.8.0, you can find and define rules directly in the UI. The Define rules mode (../reference/webapp/define_rules.md) is found below the Annotate mode (../reference/webapp/annotate_records.md) on the right sidebar.
The video below shows how you can interactively find and save rules with the UI. For a full example, check the Weak supervision tutorial.
Weak supervision from Python¶
Doing weak supervision with Rubrix should be straightforward. In keeping with the spirit of other parts of the library, you can use virtually any weak supervision library or method, such as Snorkel or FlyingSquid.
Rubrix weak supervision support is built around two basic abstractions:
Rule¶
A rule encodes a heuristic for labeling a record.
Heuristics can be defined using Elasticsearch queries:
plz = Rule(query="plz OR please", label="SPAM")
or with Python functions (similar to Snorkel’s labeling functions, which you can use as well):
from typing import Optional
import rubrix as rb
def contains_http(record: rb.TextClassificationRecord) -> Optional[str]:
    if "http" in record.inputs["text"]:
        return "SPAM"
    return None  # abstain
Besides textual features, Python labeling functions can exploit metadata features:
def author_channel(record: rb.TextClassificationRecord) -> Optional[str]:
    # the word channel appears in the comment author name
    if "channel" in record.metadata["author"]:
        return "SPAM"
A rule should return either a string value, that is, a weak label, or None in case of abstention.
Weak Labels¶
WeakLabels objects bundle and apply a set of rules to the records of a Rubrix dataset. Applying a rule to a record means assigning a weak label or abstaining.
This abstraction provides you with the building blocks for training and testing weak supervision “denoising”, “label” or even “end” models:
from rubrix.labeling.text_classification import WeakLabels
rules = [contains_http, author_channel]
weak_labels = WeakLabels(
    rules=rules,
    dataset="weak_supervision_yt"
)
# returns a summary of the applied rules
weak_labels.summary()
More information about these abstractions can be found in the Python Labeling module docs.
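To make "applying rules" concrete, here is a minimal, self-contained sketch of what bundling rules into a weak label matrix amounts to: one row per record, one column per rule. Rubrix builds this for you; the sketch is not its implementation. The `Record` dataclass is a hypothetical stand-in for `rb.TextClassificationRecord`, and encoding abstention as `-1` is an assumption borrowed from Snorkel's convention.

```python
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Optional

@dataclass
class Record:
    """Hypothetical stand-in for rb.TextClassificationRecord (only the fields used here)."""
    inputs: Dict[str, str]
    metadata: Dict[str, str] = field(default_factory=dict)

# a labeling function returns a label string, or None to abstain
LabelingFunction = Callable[[Record], Optional[str]]
ABSTAIN = -1  # assumed integer code for abstention

def build_matrix(records: List[Record], rules: List[LabelingFunction], labels: List[str]) -> List[List[int]]:
    """One row per record, one column per rule; each cell holds a label index or ABSTAIN."""
    label2int = {label: i for i, label in enumerate(labels)}
    matrix = []
    for record in records:
        row = []
        for rule in rules:
            weak_label = rule(record)
            row.append(ABSTAIN if weak_label is None else label2int[weak_label])
        matrix.append(row)
    return matrix

def contains_http(record: Record) -> Optional[str]:
    return "SPAM" if "http" in record.inputs["text"] else None

def author_channel(record: Record) -> Optional[str]:
    return "SPAM" if "channel" in record.metadata.get("author", "") else None

records = [
    Record(inputs={"text": "visit http://example.com"}, metadata={"author": "musicchannel42"}),
    Record(inputs={"text": "great song"}, metadata={"author": "alice"}),
]
matrix = build_matrix(records, [contains_http, author_channel], labels=["HAM", "SPAM"])
print(matrix)  # [[1, 1], [-1, -1]]
```

Both rules fire on the first record (SPAM is index 1), and both abstain on the second.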
Built-in label models¶
To make things even easier for you, we provide wrapper classes around the most common label models that directly consume a WeakLabels object. This makes working with those models a breeze. Take a look at the list of built-in models in the labeling module docs.
Detailed Workflow¶
A typical workflow to use weak supervision is:
1. Create a Rubrix dataset with your raw dataset. If you actually have some labelled data, you can log it into the same dataset.
2. Define a set of weak labeling rules with the Define rules mode in the UI.
3. Create a WeakLabels object and apply the rules. You can load the rules from your dataset and add additional rules and labeling functions using Python. Typically, you’ll iterate between this step and step 2.
4. Once you are satisfied with your weak labels, use the matrix of the WeakLabels instance with your library/method of choice to build a training set or even train a downstream text classification model.
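As the simplest possible baseline for the last step, a weak label matrix can be turned into labels by majority vote. The sketch assumes an integer matrix with one row per record, one column per rule and `-1` for abstention (an assumption mirroring Snorkel's convention); the label models used later in this guide instead learn how much to trust each rule.

```python
from collections import Counter
from typing import List, Optional

ABSTAIN = -1  # assumed abstention code

def majority_vote(matrix: List[List[int]]) -> List[Optional[int]]:
    """Most frequent non-abstaining vote per row; None on ties or when every rule abstains."""
    predictions: List[Optional[int]] = []
    for row in matrix:
        votes = Counter(v for v in row if v != ABSTAIN)
        if not votes:
            predictions.append(None)
            continue
        ranked = votes.most_common()
        if len(ranked) > 1 and ranked[0][1] == ranked[1][1]:
            predictions.append(None)  # tie: no clear winner
        else:
            predictions.append(ranked[0][0])
    return predictions

matrix = [
    [1, 1, -1],    # two votes for label 1
    [0, -1, 0],    # two votes for label 0
    [1, 0, -1],    # tie
    [-1, -1, -1],  # every rule abstains
]
print(majority_vote(matrix))  # [1, 0, None, None]
```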
This guide shows you an end-to-end example using Snorkel, FlyingSquid and Weasel. Let’s get started!
Example dataset¶
We’ll be using a well-known dataset for weak supervision examples, the YouTube Spam Collection dataset, which is a binary classification task for detecting spam comments in YouTube videos.
[4]:
import pandas as pd
# load data
train_df = pd.read_csv('../tutorials/data/yt_comments_train.csv')
test_df = pd.read_csv('../tutorials/data/yt_comments_test.csv')
# preview data
train_df.head()
[4]:
|   | Unnamed: 0 | author | date | text | label | video |
|---|---|---|---|---|---|---|
| 0 | 0 | Alessandro leite | 2014-11-05T22:21:36 | pls http://www10.vakinha.com.br/VaquinhaE.aspx... | -1.0 | 1 |
| 1 | 1 | Salim Tayara | 2014-11-02T14:33:30 | if your like drones, plz subscribe to Kamal Ta... | -1.0 | 1 |
| 2 | 2 | Phuc Ly | 2014-01-20T15:27:47 | go here to check the views :3 | -1.0 | 1 |
| 3 | 3 | DropShotSk8r | 2014-01-19T04:27:18 | Came here to check the views, goodbye. | -1.0 | 1 |
| 4 | 4 | css403 | 2014-11-07T14:25:48 | i am 2,126,492,636 viewer :D | -1.0 | 1 |
1. Create a Rubrix dataset with unlabelled data and test data¶
Let’s log the train (unlabelled) and the test (labelled) datasets into Rubrix.
[ ]:
import rubrix as rb
# build records from the train dataset
records = [
    rb.TextClassificationRecord(
        inputs=row.text,
        metadata={"video": row.video, "author": row.author}
    )
    for _, row in train_df.iterrows()
]
# build records from the test dataset with annotation
labels = ["HAM", "SPAM"]
records += [
    rb.TextClassificationRecord(
        inputs=row.text,
        annotation=labels[row.label],
        metadata={"video": row.video, "author": row.author}
    )
    for _, row in test_df.iterrows()
]
# log records to Rubrix
rb.log(records, name="weak_supervision_yt")
After this step, you have a fully browsable dataset available at http://localhost:6900/weak_supervision_yt
(or the base URL where your Rubrix instance is hosted).
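The `labels[row.label]` lookup above relies on the test labels being the integers 0 and 1. pandas may read the label column as floats (the training preview shows -1.0 for unlabelled rows), and a float cannot index a Python list. A defensive mapping, sketched under the assumption that 0 means HAM, 1 means SPAM and -1 (or a missing value) means unlabelled, could look like this; `to_annotation` is a hypothetical helper, not part of Rubrix.

```python
import math
from typing import Optional, Union

LABELS = ["HAM", "SPAM"]  # index 0 -> HAM, index 1 -> SPAM, as in the labels list above

def to_annotation(raw_label: Union[int, float, None]) -> Optional[str]:
    """Map a numeric CSV label to a label string; -1, NaN or None mean 'unlabelled'."""
    if raw_label is None or (isinstance(raw_label, float) and math.isnan(raw_label)):
        return None
    index = int(raw_label)  # handles 1 and 1.0 alike
    if index == -1:
        return None
    return LABELS[index]

print([to_annotation(x) for x in (-1.0, 0, 1.0, float("nan"))])  # [None, 'HAM', 'SPAM', None]
```

In the record-building code you would then pass `annotation=to_annotation(row.label)`.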
2. Defining rules¶
Let’s now define some of the rules proposed in the tutorial Snorkel Intro Tutorial: Data Labeling. Most of these rules can be defined directly in the UI using Elasticsearch’s query string DSL.
Rules can also be defined programmatically as shown below. Depending on your use case and team structure you can mix and match both interfaces (UI or Python).
Here are some programmatic rules:
[ ]:
from rubrix.labeling.text_classification import Rule, WeakLabels
# rules defined as Elasticsearch queries
check_out = Rule(query="check out", label="SPAM")
plz = Rule(query="plz OR please", label="SPAM")
subscribe = Rule(query="subscribe", label="SPAM")
my = Rule(query="my", label="SPAM")
song = Rule(query="song", label="HAM")
love = Rule(query="love", label="HAM")
You can also define plain Python labeling functions:
[ ]:
import re
# rules defined as Python labeling functions
def contains_http(record: rb.TextClassificationRecord):
    if "http" in record.inputs["text"]:
        return "SPAM"
def short_comment(record: rb.TextClassificationRecord):
    return "HAM" if len(record.inputs["text"].split()) < 5 else None
def regex_check_out(record: rb.TextClassificationRecord):
    return "SPAM" if re.search(r"check.*out", record.inputs["text"], flags=re.I) else None
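Because labeling functions are plain Python, you can sanity-check them on a few hand-written comments before applying them to the whole dataset. The sketch below copies the three functions (written as conditional expressions with the same behavior) and feeds them a hypothetical stand-in record built with `SimpleNamespace`, which only mimics the `inputs["text"]` access of `rb.TextClassificationRecord`. The sample texts are made up.

```python
import re
from types import SimpleNamespace

def make_record(text: str) -> SimpleNamespace:
    """Hypothetical stand-in exposing the same inputs["text"] access as rb.TextClassificationRecord."""
    return SimpleNamespace(inputs={"text": text})

def contains_http(record):
    return "SPAM" if "http" in record.inputs["text"] else None

def short_comment(record):
    return "HAM" if len(record.inputs["text"].split()) < 5 else None

def regex_check_out(record):
    return "SPAM" if re.search(r"check.*out", record.inputs["text"], flags=re.I) else None

labeling_functions = [contains_http, short_comment, regex_check_out]
samples = [
    "Check this out: http://example.com",
    "nice song",
    "I have been listening to this all day",
]
# one list of weak labels (or None for abstention) per sample text
results = {text: [lf(make_record(text)) for lf in labeling_functions] for text in samples}
for text, votes in results.items():
    print(f"{text!r}: {votes}")
```

The first sample is only four words long, so `short_comment` votes HAM while the other two rules vote SPAM: exactly the kind of disagreement the `conflicts` column of `summary()` counts.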
3. Building and analyzing weak labels¶
[ ]:
from rubrix.labeling.text_classification import load_rules
# bundle our rules in a list
rules = [check_out, plz, subscribe, my, song, love, contains_http, short_comment, regex_check_out]
# optionally add the rules defined in the web app UI
rules += load_rules(dataset="weak_supervision_yt")
# apply the rules to a dataset to obtain the weak labels
weak_labels = WeakLabels(
    rules=rules,
    dataset="weak_supervision_yt"
)
[12]:
# show some stats about the rules, see the `summary()` docstring for details
weak_labels.summary()
[12]:
| rule | polarity | coverage | overlaps | conflicts | correct | incorrect | precision |
|---|---|---|---|---|---|---|---|
| check out | {SPAM} | 0.242919 | 0.235839 | 0.029956 | 45 | 0 | 1.000000 |
| plz OR please | {SPAM} | 0.090414 | 0.081155 | 0.019608 | 20 | 0 | 1.000000 |
| subscribe | {SPAM} | 0.106754 | 0.083878 | 0.028867 | 30 | 0 | 1.000000 |
| my | {SPAM} | 0.190632 | 0.166667 | 0.049564 | 41 | 6 | 0.872340 |
| song | {HAM} | 0.132898 | 0.079521 | 0.033769 | 39 | 9 | 0.812500 |
| love | {HAM} | 0.092048 | 0.070261 | 0.031590 | 28 | 7 | 0.800000 |
| contains_http | {SPAM} | 0.106209 | 0.073529 | 0.049564 | 6 | 0 | 1.000000 |
| short_comment | {HAM} | 0.245098 | 0.110566 | 0.064270 | 84 | 8 | 0.913043 |
| regex_check_out | {SPAM} | 0.226580 | 0.226035 | 0.027778 | 45 | 0 | 1.000000 |
| total | {SPAM, HAM} | 0.754902 | 0.448802 | 0.120915 | 338 | 30 | 0.918478 |
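The columns above can be read with Snorkel-style definitions: coverage is the fraction of records where a rule does not abstain, overlaps the fraction where it fires together with at least one other rule, conflicts the fraction where another firing rule disagrees with it, and precision is correct / (correct + incorrect) on annotated records. The precision formula can be checked against the table: 41 / (41 + 6) ≈ 0.872340 for `my`, and 338 / (338 + 30) ≈ 0.918478 for the total row. The sketch below is an approximation, not the source of `summary()`: it assumes `-1` for abstention and counts every row in the coverage, overlaps and conflicts denominators (see the `summary()` docstring for which records Rubrix includes).

```python
from typing import Dict, List, Optional

ABSTAIN = -1  # assumed abstention code

def rule_stats(matrix: List[List[int]], gold: List[Optional[int]], j: int) -> Dict[str, Optional[float]]:
    """Snorkel-style statistics for the rule in column j of a weak label matrix."""
    n = len(matrix)
    covered = overlapping = conflicting = correct = incorrect = 0
    for row, y in zip(matrix, gold):
        vote = row[j]
        if vote == ABSTAIN:
            continue
        covered += 1
        # votes of the other rules that did not abstain on this record
        others = [v for k, v in enumerate(row) if k != j and v != ABSTAIN]
        if others:
            overlapping += 1
        if any(v != vote for v in others):
            conflicting += 1
        if y is not None:  # only annotated records count towards correct/incorrect
            if vote == y:
                correct += 1
            else:
                incorrect += 1
    precision = correct / (correct + incorrect) if correct + incorrect else None
    return {
        "coverage": covered / n,
        "overlaps": overlapping / n,
        "conflicts": conflicting / n,
        "correct": correct,
        "incorrect": incorrect,
        "precision": precision,
    }

matrix = [
    [1, -1],
    [1, 0],
    [-1, 0],
    [1, 1],
]
gold = [1, None, 0, 0]  # None marks an unannotated record
print(rule_stats(matrix, gold, j=0))
```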
4. Using the weak labels¶
At this step you have at least two options:
Use the weak labels for training a “denoising” or label model to build a less noisy training set. Highly popular options for this are Snorkel or FlyingSquid. After this step, you can train a downstream model with the “clean” labels.
Use the weak labels directly with recent “end-to-end” (e.g., Weasel) or joint models (e.g., COSINE).
Let’s see some examples:
Label model with Snorkel¶
Snorkel is by far the most popular option for using weak supervision, and Rubrix provides built-in support for it. Using Snorkel with Rubrix’s WeakLabels is as simple as:
[ ]:
%pip install snorkel -qqq
[ ]:
from rubrix.labeling.text_classification import Snorkel
# we pass our WeakLabels instance to our Snorkel label model
label_model = Snorkel(weak_labels)
# we fit the model
label_model.fit()
# we check its performance
label_model.score()
After fitting your label model, you can quickly explore its predictions before building a training set for training a downstream text classifier.
This step is useful for validation, manual revision, or defining score thresholds for accepting labels from your label model (for example, only considering labels with a score greater than 0.8).
[ ]:
# get your training records with the predictions of the label model
records_for_training = label_model.predict()
# log the records to a new dataset in Rubrix
rb.log(records_for_training, name="snorkel_results")
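To apply a score threshold such as the 0.8 mentioned above, you could keep only the predictions whose top score exceeds it. The sketch assumes each prediction is a list of `(label, score)` pairs, which is how we assume the returned records store it in their `prediction` attribute; inspect a record from `records_for_training` before relying on that. `accept_label` is a hypothetical helper.

```python
from typing import List, Optional, Tuple

Prediction = List[Tuple[str, float]]  # assumed shape: (label, score) pairs

def accept_label(prediction: Prediction, threshold: float = 0.8) -> Optional[str]:
    """Return the top-scoring label if its score is strictly greater than the threshold, else None."""
    if not prediction:
        return None
    label, score = max(prediction, key=lambda pair: pair[1])
    return label if score > threshold else None

predictions = [
    [("SPAM", 0.95), ("HAM", 0.05)],  # confident
    [("HAM", 0.55), ("SPAM", 0.45)],  # too uncertain
    [],                               # no prediction at all
]
accepted = [accept_label(p) for p in predictions]
print(accepted)  # ['SPAM', None, None]
```

With real records, you would call `accept_label(record.prediction or [])` and drop the records that come back as None before training.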
Label model with FlyingSquid¶
FlyingSquid is a powerful method developed by Hazy Research, a research group from Stanford behind ground-breaking work on programmatic data labeling, including Snorkel. FlyingSquid uses a closed-form solution for fitting the label model with great speed gains and similar performance. Just like for Snorkel, Rubrix provides built-in support for FlyingSquid, too.
[ ]:
%pip install flyingsquid pgmpy -qqq
[ ]:
from rubrix.labeling.text_classification import FlyingSquid
# we pass our WeakLabels instance to our FlyingSquid label model
label_model = FlyingSquid(weak_labels)
# we fit the model
label_model.fit()
# we check its performance
label_model.score()
After fitting your label model, you can quickly explore its predictions before building a training set for training a downstream text classifier.
This step is useful for validation, manual revision, or defining score thresholds for accepting labels from your label model (for example, only considering labels with a score greater than 0.8).
[ ]:
# get your training records with the predictions of the label model
records_for_training = label_model.predict()
# log the records to a new dataset in Rubrix
rb.log(records_for_training, name="flyingsquid_results")
Joint Model with Weasel¶
Weasel lets you train downstream models end-to-end, directly using weak labels. In contrast to Snorkel or FlyingSquid, which are two-stage approaches, Weasel is a one-stage method that trains the label model and the end model jointly. For more details, check out the End-to-End Weak Supervision paper presented at NeurIPS 2021.
In this guide we will show you how you can train a Hugging Face transformers model directly with weak labels using Weasel. Since Weasel uses PyTorch Lightning for the training, some basic knowledge of PyTorch is helpful, but not strictly necessary.
First, we need to install the Weasel Python package:
[ ]:
!python -m pip install git+https://github.com/autonlab/weasel#egg=weasel[all]
Before we get started, we need to define some classes that wrap our data and our end model in a way Weasel can work with.
[ ]:
from weasel.datamodules.base_datamodule import AbstractWeaselDataset, AbstractDownstreamDataset
from weasel.models.downstream_models.base_model import DownstreamBaseModel
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from torch.utils.data import DataLoader
import torch
# training data: rows of the weak label matrix L paired with tokenized inputs
class TrainDataset(AbstractWeaselDataset):
    def __init__(self, L, inputs):
        super().__init__(L, None)
        self.inputs = inputs
        if self.L.shape[0] != len(self.inputs):
            raise ValueError("L and inputs have different number of samples")
    def __getitem__(self, item):
        return self.L[item], self.inputs[item]
# test data: tokenized inputs paired with gold labels Y
class TestDataset(AbstractDownstreamDataset):
    def __init__(self, inputs, Y):
        super().__init__(None, Y)
        self.inputs = inputs
        if len(self.Y) != len(self.inputs):
            raise ValueError("inputs and Y have different number of samples")
    def __getitem__(self, item):
        return self.inputs[item], self.Y[item]
# collators: regroup per-record tokenizer outputs into a dict of lists and pad them to a batch
class TrainCollator:
    def __init__(self, tokenizer):
        self._tokenizer = tokenizer
    def __call__(self, batch):
        L = torch.stack([b[0] for b in batch])
        inputs = {key: [b[1][key] for b in batch] for key in batch[0][1]}
        return L, self._tokenizer.pad(inputs, return_tensors="pt")
class TestCollator:
    def __init__(self, tokenizer):
        self._tokenizer = tokenizer
    def __call__(self, batch):
        Y = torch.stack([b[1] for b in batch])
        inputs = {key: [b[0][key] for b in batch] for key in batch[0][0]}
        return self._tokenizer.pad(inputs, return_tensors="pt"), Y
# end model: a Hugging Face sequence classifier that returns raw logits
class TransformersEndModel(DownstreamBaseModel):
    def __init__(self, name: str, num_labels: int = 2):
        super().__init__()
        self.out_dim = num_labels
        self.model = AutoModelForSequenceClassification.from_pretrained(name, num_labels=num_labels)
    def forward(self, kwargs):
        model_output = self.model(**kwargs)
        return model_output["logits"]
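The collators above regroup a list of per-record tokenizer outputs into a dict of lists and hand it to `tokenizer.pad`, which pads every sequence to the longest one in the batch and builds an attention mask. A rough, pure-Python picture of that padding step follows. It rests on simplifying assumptions (right padding, pad id 0, only `input_ids` and `attention_mask`, made-up token ids); real tokenizers also handle other keys, the padding side and tensor conversion. `pad_batch` is a hypothetical helper, not a transformers function.

```python
from typing import Dict, List

def pad_batch(features: Dict[str, List[List[int]]], pad_id: int = 0) -> Dict[str, List[List[int]]]:
    """Right-pad every input_ids sequence to the batch maximum and build the attention mask."""
    max_len = max(len(ids) for ids in features["input_ids"])
    padded: Dict[str, List[List[int]]] = {"input_ids": [], "attention_mask": []}
    for ids in features["input_ids"]:
        n_pad = max_len - len(ids)
        padded["input_ids"].append(ids + [pad_id] * n_pad)
        # 1 marks real tokens, 0 marks padding the model should ignore
        padded["attention_mask"].append([1] * len(ids) + [0] * n_pad)
    return padded

# the kind of dict TrainCollator builds from two tokenized records (token ids are made up)
batch_inputs = {"input_ids": [[101, 7592, 102], [101, 102]]}
print(pad_batch(batch_inputs))
```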
The first step is to obtain our weak labels. For this we use the same rules and dataset as in the examples above (Snorkel and FlyingSquid).
[ ]:
# obtain our weak labels
weak_labels = WeakLabels(
    rules=rules,
    dataset="weak_supervision_yt"
)
In a second step, we instantiate our end model, which in our case will be a pre-trained transformer from the Hugging Face Hub. Here we choose the small ELECTRA model by Google, which shows excellent performance given its moderate number of parameters. Thanks to its small size, you can fine-tune it on your CPU within a reasonable amount of time.
[ ]:
# instantiate our transformers end model
end_model = TransformersEndModel("google/electra-small-discriminator", num_labels=2)
With our end model at hand, we can now instantiate the Weasel model. Apart from the end model, it also includes a neural encoder that tries to estimate latent labels.
[ ]:
from weasel.models import Weasel
# instantiate our weasel end-to-end model
weasel = Weasel(
    end_model=end_model,
    num_LFs=len(weak_labels.rules),
    n_classes=2,
    encoder={'hidden_dims': [32, 10]},
    optim_encoder={'name': 'adam', 'lr': 1e-4},
    optim_end_model={'name': 'adam', 'lr': 5e-5},
)
Afterwards, we wrap our data in torch Datasets and DataLoaders, so that Weasel and PyTorch Lightning can work with it. In this step we also tokenize the data. Here we need to be careful to use the tokenizer that corresponds to our end model.
[ ]:
# tokenizer for our transformers end model
tokenizer = AutoTokenizer.from_pretrained("google/electra-small-discriminator")
# torch data set of our training data
train_ds = TrainDataset(
    L=weak_labels.matrix(has_annotation=False),
    inputs=[tokenizer(rec.inputs["text"], truncation=True)
            for rec in weak_labels.records(has_annotation=False)],
)
# torch data set of our test data
test_ds = TestDataset(
    inputs=[tokenizer(rec.inputs["text"], truncation=True)
            for rec in weak_labels.records(has_annotation=True)],
    Y=weak_labels.annotation(),
)
# torch data loader for our training data
train_loader = DataLoader(
    dataset=train_ds,
    collate_fn=TrainCollator(tokenizer),
    batch_size=8,
)
# torch data loader for our test data
test_loader = DataLoader(
    dataset=test_ds,
    collate_fn=TestCollator(tokenizer),
    batch_size=16,
)
Now we have everything ready to start the training of our Weasel model. For the training process, Weasel relies on the excellent PyTorch Lightning Trainer. It provides tons of options and features to optimize the training process, but the defaults below should give you reasonable results. Keep in mind that you are fine-tuning a full-blown transformer model, albeit a small one.
[ ]:
import pytorch_lightning as pl
# instantiate the pytorch-lightning trainer
trainer = pl.Trainer(
    gpus=0,  # >= 1 to use GPU(s)
    max_epochs=2,
    logger=None,
    callbacks=[pl.callbacks.ModelCheckpoint(monitor="Val/accuracy", mode="max")]
)
# fit the model end-to-end
trainer.fit(
    model=weasel,
    train_dataloaders=train_loader,
    val_dataloaders=test_loader
)
After the training, we can call the Trainer.test method to check the final performance. The model should have achieved an accuracy of around 0.94.
[ ]:
trainer.test(dataloaders=test_loader) # List of test metrics
To use the model for inference, you can either use its predict method:
[ ]:
# Example text for the inference
text = "In my head this is like 2 years ago.. Time FLIES"
# Get predictions for the example text
predicted_probs, predicted_label = weasel.predict(
    tokenizer(text, return_tensors="pt")
)
# Map predicted int to label
weak_labels.int2label[int(predicted_label)] # HAM
Or you can instantiate one of the popular transformers pipelines, passing the end model and the tokenizer directly:
[ ]:
from transformers import pipeline
# modify the id2label mapping of the model
weasel.end_model.model.config.id2label = weak_labels.int2label
# create transformers pipeline
classifier = pipeline("text-classification", model=weasel.end_model.model, tokenizer=tokenizer)
# use pipeline for predictions
classifier(text) # [{'label': 'HAM', 'score': 0.6110987663269043}]