🐭 Weakly supervised NER with skweak¶
This tutorial will walk you through the process of using Rubrix to improve weak supervision and data programming workflows with the skweak library.
Using skweak and spaCy, we define heuristic labeling functions for the CoNLL 2003 dataset.
We combine those with a pretrained NER model and use an aggregation model from skweak to obtain noisy NER annotations.
We then log the documents to Rubrix and visualize the results via its web app.
With the noisy labels, we train a spaCy NER model.
We then add gazetteer-based labeling functions and a majority-voting aggregation model, review the updated noisy annotations with Rubrix, and retrain the spaCy model.
Instead of a spaCy model, we fine-tune a transformers model with the help of the simpletransformers library.
Introduction¶
Our goal is to show you how to incorporate Rubrix into data programming workflows to programmatically build training data with a human-in-the-loop approach. We will use the skweak library.
What are weak supervision and skweak?¶
Weak supervision is a branch of machine learning in which training labels are obtained more efficiently, at the cost of lower quality, for example from heuristics, gazetteers, or existing models instead of manual annotation. skweak is a library for programmatically building and managing such training datasets without manual labeling.
This tutorial¶
In this tutorial, we bring content from the Quick Start Named-Entity Recognition and the Step-by-step NER tutorials from skweak’s documentation and show you how to extend weak supervision workflows with Rubrix.
We will take records from the CoNLL 2003 dataset and build our own annotations with skweak. Then we will evaluate NER models trained on our annotations on the standard development set of CoNLL 2003.
Setup¶
Rubrix is a free and open-source tool to explore, annotate, and monitor data for NLP projects.
If you are new to Rubrix, check out the ⭐ GitHub repository.
If you have not installed and launched Rubrix yet, check the Setup and Installation guide.
For this tutorial we also need some third-party libraries that can be installed via pip:
[5]:
%pip install skweak spacy simpletransformers -qqq
!python -m spacy download en_core_web_sm -qqq
!python -m spacy download en_core_web_md -qqq
Named Entity Recognition with skweak and Rubrix¶
Rubrix allows you to log and track data for different NLP tasks (such as Token Classification or Text Classification).
In this tutorial, we will use the English portion of the CoNLL 2003 dataset, a standard Named Entity Recognition benchmark.
The dataset¶
In this tutorial we’ll be using skweak’s data programming methods to programmatically build a training set, with the help of Rubrix for analyzing and reviewing the data. We’ll then train a model with this training set.
Although the gold labels for the training set of CoNLL 2003 are already known, we will purposefully ignore them, as our goal in this tutorial is to build our own annotations and see how well they perform on the development set.
We will load the CoNLL 2003 dataset with the help of the datasets library.
[ ]:
from datasets import load_dataset
dataset = load_dataset("conll2003")
Preprocessing¶
Next, we simplify the tag set by replacing the numeric tag IDs with their string names and removing the BIO encoding.
[ ]:
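tag_set = {'O': 0, 'B-PER': 1, 'I-PER': 2, 'B-ORG': 3, 'I-ORG': 4, 'B-LOC': 5, 'I-LOC': 6, 'B-MISC': 7, 'I-MISC': 8}
# Invert the mapping so that we can look up a tag name by its numeric ID
tag_set = { v:k for k,v in tag_set.items() }
def convert_ner_tags(record, tag_set=None):
    # Replace numeric tag IDs with their string names (e.g. 3 -> 'B-ORG')
    record['ner_tags'] = [ tag_set[x] for x in record['ner_tags'] ]
    return record
def strip_BI_tags(record):
    # Drop the BIO prefix ('B-ORG' -> 'ORG'); 'O' has no prefix and is kept as-is.
    # str.lstrip removes a set of characters rather than a prefix, so we split on the first hyphen instead
    record['ner_tags'] = [ x.split('-', 1)[-1] for x in record['ner_tags'] ]
    return record
dataset = dataset\
    .map(convert_ner_tags, fn_kwargs={"tag_set": tag_set})\
    .map(strip_BI_tags)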
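The two preprocessing steps can be checked on a toy tag sequence without loading the dataset. The sketch below reuses the tag IDs from the `tag_set` dictionary above; `id_to_tag` and `strip_bio` are illustrative names, and `B-BRAND` is a hypothetical tag used only to show that `str.lstrip` removes a set of characters rather than a prefix. For the CoNLL tags this happens to give the right result, but not in general.

```python
# Illustrative sketch: map numeric CoNLL tag IDs to names, then drop the BIO prefix.
tag_to_id = {'O': 0, 'B-PER': 1, 'I-PER': 2, 'B-ORG': 3, 'I-ORG': 4,
             'B-LOC': 5, 'I-LOC': 6, 'B-MISC': 7, 'I-MISC': 8}
id_to_tag = {v: k for k, v in tag_to_id.items()}

def strip_bio(tag: str) -> str:
    """Hypothetical helper: remove a leading 'B-' or 'I-'; 'O' is returned unchanged."""
    return tag.split('-', 1)[-1]

ids = [3, 4, 0, 1, 2, 0, 5]
names = [id_to_tag[i] for i in ids]
print(names)                          # ['B-ORG', 'I-ORG', 'O', 'B-PER', 'I-PER', 'O', 'B-LOC']
print([strip_bio(t) for t in names])  # ['ORG', 'ORG', 'O', 'PER', 'PER', 'O', 'LOC']

# lstrip strips any leading 'B' or '-' characters, including the first letter of BRAND
print('B-BRAND'.lstrip('B-'))  # prints: RAND
print(strip_bio('B-BRAND'))    # prints: BRAND
```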
We will now convert the training and validation splits of our dataset into spaCy Doc objects.
spaCy expects a tokenizer to receive a string as input. However, since our dataset is already tokenized, we work around this restriction by using our own tokenizer and wrapping our tokens in a class that inherits from str.
[8]:
import spacy
from spacy.tokens import Doc
from dataclasses import dataclass
@dataclass
class Record(str):
    # A str subclass that also carries the pre-tokenized words of a record
    tokens: list
def custom_tokenizer(text):
    # Build the Doc directly from the existing tokens instead of re-tokenizing the string
    return Doc(nlp.vocab, text.tokens)
nlp = spacy.load("en_core_web_sm", disable=["ner", "lemmatizer"])
nlp.tokenizer = custom_tokenizer
training_set = [ Record(x) for x in dataset["train"]["tokens"] ]
dev_set = [ Record(x) for x in dataset["validation"]["tokens"] ]
train_docs = list(nlp.pipe(training_set))
dev_docs = list(nlp.pipe(dev_set))
The gold labels must also be added to our validation Doc objects, so that we can evaluate our model during training.
[9]:
from spacy.tokens import Span
from itertools import groupby
from dataclasses import dataclass
@dataclass
class IndexedLabel:
    index: int
    label: str
def annotate_labels_to_doc(doc, labels, null_label="O"):
    labels = [ IndexedLabel(idx, item) for idx, item in enumerate(labels) ]
    # Group consecutive tokens that share the same label into one entity span.
    # The key is needed: without it, IndexedLabel objects (index included) never compare equal.
    # Since BIO prefixes were removed, adjacent entities of the same type are merged.
    grouped_labels = [ list(group) for _, group in groupby(labels, key=lambda item: item.label) ]
    span_objects = [ Span(doc, item[0].index, item[-1].index + 1, item[0].label) for item in grouped_labels ]
    # Discard the groups of non-entity ("O") tokens
    span_objects = [ span for span in span_objects if span.label_ != null_label ]
    doc.set_ents(span_objects)
    return doc
dev_labels = [ x for x in dataset["validation"]["ner_tags"] ]
for idx, label_sequence in enumerate(dev_labels):
    dev_docs[idx] = annotate_labels_to_doc(dev_docs[idx], label_sequence)
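The core of `annotate_labels_to_doc` is turning a flat sequence of per-token labels into entity spans. The sketch below shows that step in plain Python, without spaCy; `labels_to_spans` is an illustrative helper that returns `(start, end, label)` tuples with an exclusive `end`, like spaCy token offsets. Here the labels are plain strings, so `groupby` can compare them directly; with wrapper objects such as `IndexedLabel`, a `key` that extracts the label is needed for consecutive items to group.

```python
from itertools import groupby

def labels_to_spans(labels, null_label="O"):
    """Hypothetical helper: group consecutive identical labels into
    (start, end, label) spans, with `end` exclusive; null labels are skipped."""
    spans = []
    position = 0
    for label, group in groupby(labels):
        length = len(list(group))
        if label != null_label:
            spans.append((position, position + length, label))
        position += length
    return spans

tokens = ["EU", "rejects", "German", "call", "to", "boycott", "British", "lamb", "."]
labels = ["ORG", "O", "MISC", "O", "O", "O", "MISC", "O", "O"]
spans = labels_to_spans(labels)
print(spans)  # [(0, 1, 'ORG'), (2, 3, 'MISC'), (6, 7, 'MISC')]
print([(" ".join(tokens[s:e]), label) for s, e, label in spans])
# [('EU', 'ORG'), ('German', 'MISC'), ('British', 'MISC')]
```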
Labeling functions¶
Labelling functions (LFs) are at the core of skweak. They take a Doc as input and return a list of spans with their associated labels.
In this tutorial, we will first define the LFs from the skweak tutorial and then show you how you can use Rubrix to enhance this type of weak-supervision workflow.
Heuristics¶
One simple type of labelling function is a heuristic. For instance, commercial companies can often be recognized by their legal suffix (such as Corp.):
[10]:
import skweak
def company_detector_fun(doc):
    for chunk in doc.noun_chunks:
        # Flag noun chunks whose last token is a legal company suffix (e.g. "Corp.", "Inc")
        if chunk[-1].lower_.rstrip(".") in {'corp', 'inc', 'ltd', 'llc', 'sa', 'ag'}:
            yield chunk.start, chunk.end, "COMPANY"
# We create the labelling function by giving it a name, and a function to apply
company_detector = skweak.heuristics.FunctionAnnotator("company_detector", company_detector_fun)
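A labelling function of this kind is just a generator of `(start, end, label)` token offsets. As a rough, spaCy-free illustration of the same idea, the sketch below scans a plain token list instead of noun chunks: whenever a token is a legal suffix, it extends the span leftwards over capitalised tokens. `company_spans` and its left-extension rule are hypothetical simplifications for illustration, not the heuristic used above.

```python
from typing import Iterator, List, Tuple

LEGAL_SUFFIXES = {"corp", "inc", "ltd", "llc", "sa", "ag"}

def company_spans(tokens: List[str]) -> Iterator[Tuple[int, int, str]]:
    """Hypothetical LF: yield (start, end, 'COMPANY') for runs of capitalised
    tokens that end in a legal suffix such as 'Corp.' or 'AG'."""
    for end_idx, token in enumerate(tokens):
        if token.lower().rstrip(".") in LEGAL_SUFFIXES:
            start = end_idx
            # Extend left while the previous token starts with an uppercase letter
            while start > 0 and tokens[start - 1][:1].isupper():
                start -= 1
            # Require at least one name token before the suffix
            if start < end_idx:
                yield start, end_idx + 1, "COMPANY"

tokens = ["Shares", "in", "Acme", "Widgets", "Corp.", "rose", "."]
print(list(company_spans(tokens)))  # [(2, 5, 'COMPANY')]
```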
We can write another example of heuristics for non-commercial organisations by looking for the occurrence of words that are quite typical of public organisations or NGOs:
[11]:
OTHER_ORG_CUE_WORDS = {"University", "Institute", "College", "Committee", "Party", "Agency",
                       "Union", "Association", "Organization", "Court", "Office", "National"}
def other_org_detector_fun(doc):
    for chunk in doc.noun_chunks:
        # Flag noun chunks that contain at least one organisation cue word
        if any(tok.text in OTHER_ORG_CUE_WORDS for tok in chunk):
            yield chunk.start, chunk.end, "OTHER_ORG"
# We create the labelling function
other_org_detector = skweak.heuristics.FunctionAnnotator("other_org_detector", other_org_detector_fun)
NER models¶
We can also take advantage of machine learning models trained on data from related domains. Here, we will use a spaCy model trained on OntoNotes 5.0 to get more named entities.
[12]:
ner = skweak.spacy.ModelAnnotator("spacy", "en_core_web_sm")
Finally, we run our annotators over the documents.
[13]:
train_docs = list(company_detector.pipe(train_docs))
train_docs = list(other_org_detector.pipe(train_docs))
train_docs = list(ner.pipe(train_docs))
Aggregation¶
Once the labelling functions have been applied, we must then aggregate their results, so that we obtain a single annotation for each document.
This can be done in skweak through a Hidden Markov Model (HMM). Here we use a CustomHMM class as a workaround for tokens with impossible states, as suggested in this issue.
[ ]:
from typing import Dict
import numpy as np
import skweak
class CustomHMM(skweak.aggregation.HMM):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
    def _compute_log_likelihood(self, X: Dict[str, np.ndarray]) -> np.ndarray:
        """Computes the log likelihood for the observed sequence"""
        logsum = np.float32()
        for source in X:
            # Include the weights to the probabilities
            probs = (self.emit_probs[source]  # type:ignore
                     ** self.weights.get(source, 1))
            # We compute the likelihood of each state given the observations
            probs = np.dot(X[source], probs.T)
            # Impossible states have a logprob of -inf
            log_probs = np.ma.log(probs).filled(-np.inf)
            logsum += log_probs
        # We also add a constraint that the probability of a state is zero
        # if no labelling function observes it
        X_all_obs = np.zeros(logsum.shape, dtype=bool)  # type: ignore
        for source in self.emit_counts:
            if source in X:
                if "O" in self.out_labels:
                    X_all_obs += X[source][:, :len(self.out_labels)]
                else:
                    X_all_obs += X[source][:, 1:len(self.out_labels)+1]
        # Workaround: use a large negative log-probability instead of -inf
        # for states that no labelling function observes
        logsum = np.where(X_all_obs, logsum, -100000.0)
        return logsum  # type: ignore
# We define the aggregation model
hmm_model = CustomHMM("hmm", ["COMPANY", "OTHER_ORG"])
# We indicate that "ORG" is an underspecified value, which may
# represent either COMPANY or OTHER_ORG
hmm_model.add_underspecified_label("ORG", ["COMPANY", "OTHER_ORG"])
# And run the estimation
train_docs = hmm_model.fit_and_aggregate(train_docs)
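Why can a large finite penalty behave better than `-inf`? The sketch below is a generic NumPy illustration, not skweak's internal code: if every state of a token has a log-probability of `-inf`, a typical log-space normalisation computes `-inf - (-inf)`, which is `NaN`, whereas a finite penalty such as `-100000.0` keeps the arithmetic well defined. The helper `normalise_log_probs` is hypothetical.

```python
import numpy as np

def normalise_log_probs(log_probs: np.ndarray) -> np.ndarray:
    """Hypothetical helper: turn unnormalised log-probabilities (one row per token)
    into probabilities, subtracting the row maximum for numerical stability."""
    row_max = log_probs.max(axis=1, keepdims=True)
    shifted = np.exp(log_probs - row_max)
    return shifted / shifted.sum(axis=1, keepdims=True)

# Second token: no labelling function supports any state
with_inf = np.array([[0.0, -1.0, -np.inf],
                     [-np.inf, -np.inf, -np.inf]])
with_penalty = np.where(np.isinf(with_inf), -100000.0, with_inf)

with np.errstate(invalid="ignore"):
    print(normalise_log_probs(with_inf))  # second row is all NaN
print(normalise_log_probs(with_penalty))  # second row is uniform (1/3 each)
```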
Visualization with Rubrix¶
We can use Rubrix to visualize the outputs of our aggregation model.
First we define a doc_logger function that will log the predictions produced by all our annotators to Rubrix.
[15]:
from tqdm import tqdm
import rubrix as rb
def doc_logger(texts, docs, rubrix_dataset="conll_2003"):
    records = []
    for idx, doc in enumerate(tqdm(docs, total=len(docs))):
        tokens = [token.text for token in doc]
        if not tokens:
            continue
        if doc.spans:
            # Log one record per annotator (labelling function, NER model or aggregation model)
            for labelling_function, span_list in doc.spans.items():
                entities = [
                    (ent.label_, ent.start_char, ent.end_char)
                    for ent in span_list
                ]
                if entities:
                    records.append(
                        rb.TokenClassificationRecord(
                            text=texts[idx],
                            tokens=tokens,
                            prediction=entities,
                            prediction_agent=labelling_function,
                            metadata={
                                "doc_index": idx
                            }
                        )
                    )
    if records:
        rb.log(records=records, name=rubrix_dataset)
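`doc_logger` passes character offsets (`start_char`, `end_char`) to Rubrix, so they must line up with `texts[idx]`, which is the tokens joined by single spaces. The stdlib sketch below shows how a token-level span maps to character offsets in such a space-joined text; `token_span_to_char_span` is an illustrative helper, not part of spaCy or Rubrix.

```python
from typing import List, Tuple

def token_span_to_char_span(tokens: List[str], start: int, end: int) -> Tuple[int, int]:
    """Hypothetical helper: map the token span [start, end) to character
    offsets in ' '.join(tokens)."""
    # Every token before `start` contributes its length plus one separating space
    char_start = sum(len(tok) + 1 for tok in tokens[:start])
    char_end = char_start + len(" ".join(tokens[start:end]))
    return char_start, char_end

tokens = ["Peter", "Blackburn", "visited", "BRUSSELS", "."]
text = " ".join(tokens)
start, end = token_span_to_char_span(tokens, 0, 2)
print(text[start:end])  # Peter Blackburn
start, end = token_span_to_char_span(tokens, 3, 4)
print(text[start:end])  # BRUSSELS
```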
We choose to log the first 1000 documents of our dataset.
[ ]:
def log_to_rubrix(tokens, docs, limit=None, rubrix_dataset="my_dataset_name"):
    # Rebuild each text by joining its tokens with single spaces
    texts = [ " ".join(x) for x in tokens ]
    # Keep only the first `limit` records (all records if limit is None)
    text_sample = texts[:limit]
    doc_sample = list(docs)[:limit]
    doc_logger(
        text_sample,
        doc_sample,
        rubrix_dataset=rubrix_dataset
    )
log_to_rubrix(dataset["train"]["tokens"], train_docs, limit=1000, rubrix_dataset="conll_2003_hmm")
If we sort the Token Classification records in the conll_2003_hmm dataset by the Metadata.doc_index field, we can see the predictions produced for each record by our NER model, labelling functions, and aggregation model.
Training a spaCy model¶
Preprocessing¶
Before we train our own model, we need to make sure that our training and development sets are using exactly the same tags.
We map the tags in our training set to the standard CoNLL 2003 tags through a replacement dictionary.
[ ]:
def get_doc_labels(docs):
    # Collect every label used by any annotator's spans or by the doc entities
    label_set = set()
    for doc in docs:
        for annotator, spans in doc.spans.items():
            for span in spans:
                label_set.add(span.label_)
        for entity in doc.ents:
            label_set.add(entity.label_)
    return label_set
get_doc_labels(train_docs)
[18]:
replacement_dict = {'CARDINAL': 'MISC',
'COMPANY': 'ORG',
'DATE': 'MISC',
'EVENT': 'MISC',
'FAC': 'MISC',
'GPE': 'LOC',
'LANGUAGE': 'MISC',
'LAW': 'MISC',
'LOC': 'LOC',
'MONEY': 'MISC',
'NORP': 'MISC',
'ORDINAL': 'MISC',
'ORG': 'ORG',
'OTHER_ORG': 'ORG',
'PERCENT': 'MISC',
'PERSON': 'PER',
'PRODUCT': 'MISC',
'QUANTITY': 'MISC',
'TIME': 'MISC',
'WORK_OF_ART': 'MISC'}
[19]:
from spacy.tokens import Span
def annotation_standardiser(doc, replacement_dict):
    for source in doc.spans:
        new_spans = []
        for span in doc.spans[source]:
            new_label = replacement_dict.get(span.label_, None)
            if "\n" in span.text:
                # Skip spans that cross a line break
                continue
            elif new_label:
                # Re-create the span with its CoNLL label
                new_spans.append(
                    Span(doc, span.start, span.end, label=new_label)
                )
            else:
                # Keep spans whose label has no mapping unchanged
                new_spans.append(span)
        doc.spans[source] = new_spans
    return doc
train_docs_standard = []
for doc in train_docs:
    new_doc = annotation_standardiser(doc, replacement_dict=replacement_dict)
    train_docs_standard.append(new_doc)
[20]:
assert not get_doc_labels(train_docs).symmetric_difference(get_doc_labels(dev_docs))
After matching the tags, we can train our own NER model.
We choose to use the labels produced by the aggregation model as our first option, and take the labels produced by the spaCy model trained on OntoNotes 5.0 as a fallback for instances in which the HMM model failed to aggregate any tags.
[20]:
for doc in train_docs_standard:
    spacy_ents = doc.spans.get("spacy", ())
    hmm_ents = doc.spans.get("hmm", ())
    # Prefer the HMM output and fall back to the spaCy model's predictions
    if hmm_ents:
        doc.ents = hmm_ents
    else:
        doc.ents = spacy_ents
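The fallback above amounts to taking the first source, in order of preference, that produced any spans. Because spaCy refuses to set overlapping spans as `doc.ents`, it can also be worth checking for overlaps before assigning. The sketch below expresses both steps over plain `(start, end, label)` tuples; `choose_entities`, `has_overlap` and `SpanTuple` are hypothetical names, and the source names mirror the `hmm` and `spacy` keys used above.

```python
from typing import Dict, List, Sequence, Tuple

SpanTuple = Tuple[int, int, str]  # (start token, exclusive end token, label)

def choose_entities(spans_by_source: Dict[str, List[SpanTuple]],
                    preference: Sequence[str] = ("hmm", "spacy")) -> List[SpanTuple]:
    """Hypothetical helper: return the spans of the first preferred source that has any."""
    for source in preference:
        spans = spans_by_source.get(source, [])
        if spans:
            return spans
    return []

def has_overlap(spans: List[SpanTuple]) -> bool:
    """Hypothetical helper: True if any two spans share at least one token."""
    ordered = sorted(spans)
    # After sorting by start, checking neighbouring spans is sufficient
    return any(prev[1] > curr[0] for prev, curr in zip(ordered, ordered[1:]))

doc_spans = {"hmm": [], "spacy": [(0, 2, "PER"), (3, 4, "LOC")]}
chosen = choose_entities(doc_spans)
print(chosen)               # [(0, 2, 'PER'), (3, 4, 'LOC')]
print(has_overlap(chosen))  # False
```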
We use the docbin_writer function from the skweak library to save our documents for training and evaluation.
[32]:
from skweak.utils import docbin_writer
docbin_writer(train_docs_standard, "/tmp/train.spacy")
docbin_writer(dev_docs, "/tmp/dev.spacy")
Write to /tmp/train.spacy...done
Write to /tmp/dev.spacy...done
Training¶
As shown below, after 200 training steps our spaCy NER model reaches an entity-level F1-score of about 21% on the development set.
[21]:
!spacy init config - --lang en --pipeline ner --optimize accuracy | \
spacy train - \
--training.max_steps 200 \
--paths.train /tmp/train.spacy \
--paths.dev /tmp/dev.spacy \
--initialize.vectors en_core_web_md \
--output /tmp/model
✔ Created output directory: /tmp/model
ℹ Saving to output directory: /tmp/model
ℹ Using CPU
ℹ To switch to GPU 0, use the option: --gpu-id 0
=========================== Initializing pipeline ===========================
[2022-01-03 10:31:34,283] [INFO] Set up nlp object from config
[2022-01-03 10:31:34,300] [INFO] Pipeline: ['tok2vec', 'ner']
[2022-01-03 10:31:34,305] [INFO] Created vocabulary
[2022-01-03 10:31:35,954] [INFO] Added vectors: en_core_web_md
[2022-01-03 10:31:36,121] [INFO] Finished initializing nlp object
[2022-01-03 10:32:33,750] [INFO] Initialized pipeline components: ['tok2vec', 'ner']
✔ Initialized pipeline
============================= Training pipeline =============================
ℹ Pipeline: ['tok2vec', 'ner']
ℹ Initial learn rate: 0.001
E # LOSS TOK2VEC LOSS NER ENTS_F ENTS_P ENTS_R SCORE
--- ------ ------------ -------- ------ ------ ------ ------
0 0 0.00 43.00 0.00 0.00 0.00 0.00
0 200 45.20 3765.66 21.21 23.23 19.52 0.21
✔ Saved pipeline to output directory
/tmp/model/model-last
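The `ENTS_F` column is the harmonic mean of entity precision (`ENTS_P`) and recall (`ENTS_R`), F1 = 2PR / (P + R). As a quick sanity check on the log above, here is that formula in Python; `f1_score` is just an illustrative helper.

```python
def f1_score(precision: float, recall: float) -> float:
    """Hypothetical helper: harmonic mean of precision and recall (0 if both are 0)."""
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)

# ENTS_P and ENTS_R from the spaCy training log (in percent)
print(round(f1_score(23.23, 19.52), 2))  # 21.21, matching ENTS_F
```

The same formula also reproduces the ENTS_F of the later gazetteer run and the f1_score reported by simpletransformers from their respective precision and recall.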
Add Gazetteers¶
In addition to heuristics, we can also exploit labelling functions built from gazetteers. These search for occurrences of entries from a list, often extracted from a knowledge base.
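Conceptually, a gazetteer annotator stores its entries in a token-level trie and looks for the longest entry starting at each position. The following sketch implements that idea from scratch in plain Python; it is not skweak's implementation, and `build_trie`, `find_matches` and the toy entries are hypothetical. The `case_sensitive` flag mirrors the cased/uncased distinction used in the gazetteers below.

```python
from typing import Dict, List, Tuple

END = "__label__"  # key marking that a complete entry ends at this trie node

def build_trie(entries: Dict[Tuple[str, ...], str], case_sensitive: bool = True) -> dict:
    """Hypothetical helper: insert each token sequence into a nested-dict trie."""
    trie: dict = {}
    for tokens, label in entries.items():
        node = trie
        for tok in tokens:
            node = node.setdefault(tok if case_sensitive else tok.lower(), {})
        node[END] = label
    return trie

def find_matches(tokens: List[str], trie: dict,
                 case_sensitive: bool = True) -> List[Tuple[int, int, str]]:
    """Hypothetical helper: greedy left-to-right longest match,
    returning non-overlapping (start, end, label) spans."""
    matches = []
    i = 0
    while i < len(tokens):
        node, best = trie, None
        for j in range(i, len(tokens)):
            key = tokens[j] if case_sensitive else tokens[j].lower()
            if key not in node:
                break
            node = node[key]
            if END in node:
                best = (i, j + 1, node[END])  # remember the longest entry so far
        if best:
            matches.append(best)
            i = best[1]  # continue after the matched entry
        else:
            i += 1
    return matches

entries = {("New", "York"): "LOC", ("New", "York", "Times"): "ORG", ("Paris",): "LOC"}
sentence = ["The", "New", "York", "Times", "reported", "from", "paris", "."]
print(find_matches(sentence, build_trie(entries)))  # [(1, 4, 'ORG')]
print(find_matches(sentence, build_trie(entries, case_sensitive=False), case_sensitive=False))
# [(1, 4, 'ORG'), (6, 7, 'LOC')]
```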
Wikipedia¶
The Wikipedia gazetteer is extracted from the NECKar dataset. It is limited to Wikidata objects that have a text description.
[ ]:
tries = skweak.gazetteers.extract_json_data("./data/skweak/wikidata_small_tokenised.json.gz")
wikismall_gazetteer_cased = skweak.gazetteers.GazetteerAnnotator("wikismall_cased_gazetteer", tries)
wikismall_gazetteer_uncased = skweak.gazetteers.GazetteerAnnotator("wikismall_uncased_gazetteer", tries, case_sensitive=False)
Crunchbase¶
The Crunchbase gazetteer is extracted from the Open Data Map from Crunchbase, which contains lists of both organisations and (business) persons.
[ ]:
tries = skweak.gazetteers.extract_json_data("./data/skweak/crunchbase_companies.json.gz")
crunchbase_gazetteer = skweak.gazetteers.GazetteerAnnotator("crunchbase_gazetteer", tries)
Geonames¶
The geonames database contains a large list of locations, including both geopolitical entities and “natural” locations.
[ ]:
tries = skweak.gazetteers.extract_json_data("./data/skweak/geonames.json", spacy_model="en_core_web_sm")
geonames_gazetteer_cased = skweak.gazetteers.GazetteerAnnotator("geo_cased_gazetteer", tries)
geonames_gazetteer_uncased = skweak.gazetteers.GazetteerAnnotator("geo_uncased_gazetteer", tries, case_sensitive=False)
DBPedia¶
This gazetteer uses DBpedia to extract a list of products and brands, both labelled as products.
[ ]:
tries = skweak.gazetteers.extract_json_data("./data/skweak/products.json", spacy_model="en_core_web_sm")
products_gazetteer_cased = skweak.gazetteers.GazetteerAnnotator("products_cased_gazetteer", tries)
products_gazetteer_uncased = skweak.gazetteers.GazetteerAnnotator("products_uncased_gazetteer", tries, case_sensitive=False)
We combine the case-sensitive gazetteers into a single annotator through the CombinedAnnotator class.
[25]:
from skweak.base import CombinedAnnotator
gazetteers = [wikismall_gazetteer_cased, crunchbase_gazetteer, geonames_gazetteer_cased, products_gazetteer_cased]
combined_gazetteer = CombinedAnnotator()
for gazetteer in gazetteers:
    combined_gazetteer.add_annotator(gazetteer)
train_docs = list(combined_gazetteer.pipe(train_docs))
Aggregation¶
Besides using an HMM, we can also aggregate the annotations of our documents using majority voting.
We map our tags to the CoNLL format and then apply the MajorityVoter aggregation model.
[26]:
train_docs_standard_v2 = []
for doc in train_docs:
    new_doc = annotation_standardiser(doc, replacement_dict=replacement_dict)
    train_docs_standard_v2.append(new_doc)
[27]:
mv = skweak.aggregation.MajorityVoter("mv", ["LOC", "MISC", "ORG", "PER"])
mv.add_underspecified_label("ENT", {"LOC", "MISC", "ORG", "PER"})
[28]:
train_docs_standard_v2 = list(mv.pipe(train_docs_standard_v2))
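At its simplest, majority voting assigns each token the label proposed by the most sources. The sketch below shows that token-level idea on plain label lists; it is a simplified stand-in for skweak's `MajorityVoter`, which also deals with underspecified labels such as `ENT` as configured above. Treating `O` votes as abstentions and breaking ties by first appearance are assumptions of this sketch, and `majority_vote` is a hypothetical name.

```python
from collections import Counter
from typing import Dict, List

def majority_vote(votes_by_source: Dict[str, List[str]], null_label: str = "O") -> List[str]:
    """Hypothetical helper: per token, pick the most frequent entity label.

    Sources that predict `null_label` are treated as abstaining; a token gets
    `null_label` only when no source proposes an entity label. Ties go to the
    label seen first (Counter.most_common keeps insertion order for ties).
    """
    sequences = list(votes_by_source.values())
    if not sequences:
        return []
    result = []
    for i in range(len(sequences[0])):
        counts = Counter(seq[i] for seq in sequences if seq[i] != null_label)
        result.append(counts.most_common(1)[0][0] if counts else null_label)
    return result

votes = {
    "spacy":     ["ORG", "O", "MISC", "O", "LOC"],
    "wikismall": ["ORG", "O", "O",    "O", "LOC"],
    "geonames":  ["LOC", "O", "O",    "O", "LOC"],
}
print(majority_vote(votes))  # ['ORG', 'O', 'MISC', 'O', 'LOC']
```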
Visualization with Rubrix¶
We can visualize the annotations produced by our gazetteers with Rubrix.
Among all our gazetteers, only wikismall_gazetteer_cased was able to capture entities from the training data.
[ ]:
log_to_rubrix(dataset["train"]["tokens"], train_docs_standard_v2, limit=1000, rubrix_dataset="conll_2003_gazetteers")
Training a spaCy model¶
We choose to use the labels produced by the aggregation model as our first option, and take the labels produced by the spaCy model trained on OntoNotes 5.0 as a fallback for instances in which the MajorityVoter failed to aggregate any tags.
[30]:
for doc in train_docs_standard_v2:
    spacy_ents = doc.spans.get("spacy", ())
    mv_ents = doc.spans.get("mv", ())
    # Prefer the majority-vote output and fall back to the spaCy model's predictions
    if mv_ents:
        doc.ents = mv_ents
    else:
        doc.ents = spacy_ents
We use the docbin_writer function from the skweak library to save our documents for training.
Our development set was already saved as dev.spacy in our previous training iteration.
[ ]:
docbin_writer(train_docs_standard_v2, "/tmp/train_v2.spacy")
As shown below, after adding gazetteers and using a majority voter as our aggregation model, our trained NER model reaches an F1-score of about 22%, roughly one percentage point above our previous result.
For the sake of brevity, we did not present all the labelling functions available in the skweak library in this tutorial. Ideally, we would stack several labelling functions and iterate between annotation and training until we reach the desired results. Please refer to the Step-by-step NER tutorial and the official skweak documentation for a full overview of what is possible to achieve with the library.
[34]:
!spacy init config - --lang en --pipeline ner --optimize accuracy | \
spacy train - \
--training.max_steps 200 \
--paths.train /tmp/train_v2.spacy \
--paths.dev /tmp/dev.spacy \
--initialize.vectors en_core_web_md \
--output /tmp/model
ℹ Saving to output directory: /tmp/model
ℹ Using CPU
=========================== Initializing pipeline ===========================
[2022-01-03 16:06:22,664] [INFO] Set up nlp object from config
[2022-01-03 16:06:22,674] [INFO] Pipeline: ['tok2vec', 'ner']
[2022-01-03 16:06:22,678] [INFO] Created vocabulary
[2022-01-03 16:06:23,774] [INFO] Added vectors: en_core_web_md
[2022-01-03 16:06:24,127] [INFO] Finished initializing nlp object
[2022-01-03 16:07:29,953] [INFO] Initialized pipeline components: ['tok2vec', 'ner']
✔ Initialized pipeline
============================= Training pipeline =============================
ℹ Pipeline: ['tok2vec', 'ner']
ℹ Initial learn rate: 0.001
E # LOSS TOK2VEC LOSS NER ENTS_F ENTS_P ENTS_R SCORE
--- ------ ------------ -------- ------ ------ ------ ------
0 0 0.00 43.00 0.00 0.00 0.00 0.00
0 200 46.47 4005.87 22.49 23.64 21.45 0.22
✔ Saved pipeline to output directory
/tmp/model/model-last
Simpletransformers¶
Rather than training our NER models in spaCy, we can also fine-tune pre-trained transformers on the annotations produced with skweak.
Here we use simpletransformers, a library built on top of the transformers library.
Preprocessing¶
First we have to convert our spaCy Doc objects into dataframes that can be used with the simpletransformers library.
[48]:
import pandas as pd
def get_training_data_from_docs(docs):
    training_df = []
    for doc_idx, doc in enumerate(docs):
        tokens = [ x.text for x in doc ]
        # Start with every token labelled 'O', then label the tokens inside each entity
        label_array = ['O'] * len(tokens)
        for ent in doc.ents:
            for token_idx in range(ent.start, ent.end):
                label_array[token_idx] = ent.label_
        training_df_rows = [ [doc_idx, tokens[idx], label_array[idx]] for idx in range(len(tokens)) ]
        training_df.extend(training_df_rows)
    training_df = pd.DataFrame(training_df, columns=["sentence_id", "words", "labels"])
    return training_df
def get_evaluation_data_from_dataset(dataset, tokens_field="tokens", tags_field="ner_tags"):
    tokens = dataset[tokens_field]
    tags = dataset[tags_field]
    index_array = []
    for idx, item in enumerate(tokens):
        index_array.append([idx] * len(item))
    test_df_sents = [ list(zip(index_array[idx], tokens[idx], tags[idx]))
                      for idx, item in enumerate(tokens) ]
    eval_df = []
    for sentence in test_df_sents:
        eval_df.extend(sentence)
    eval_df = pd.DataFrame(eval_df, columns=["sentence_id", "words", "labels"])
    return eval_df
train_df = get_training_data_from_docs(train_docs_standard_v2)
eval_df = get_evaluation_data_from_dataset(dataset['validation'])
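Both helpers produce the long, one-row-per-token layout that simpletransformers' `NERModel` expects, with `sentence_id`, `words` and `labels` columns. The pandas-free sketch below builds the same rows from a token list and `(start, end, label)` spans, which can make the layout easier to inspect; `spans_to_rows` is a hypothetical helper.

```python
from typing import List, Tuple

def spans_to_rows(sentence_id: int, tokens: List[str],
                  spans: List[Tuple[int, int, str]]) -> List[Tuple[int, str, str]]:
    """Hypothetical helper: expand entity spans into one
    (sentence_id, word, label) row per token; uncovered tokens get 'O'."""
    labels = ["O"] * len(tokens)
    for start, end, label in spans:
        for idx in range(start, end):
            labels[idx] = label
    return [(sentence_id, word, label) for word, label in zip(tokens, labels)]

rows = spans_to_rows(0, ["Peter", "Blackburn", "visited", "BRUSSELS"], [(0, 2, "PER"), (3, 4, "LOC")])
for row in rows:
    print(row)
# (0, 'Peter', 'PER')
# (0, 'Blackburn', 'PER')
# (0, 'visited', 'O')
# (0, 'BRUSSELS', 'LOC')
```

Passing such rows to `pd.DataFrame(rows, columns=["sentence_id", "words", "labels"])` gives the same table layout as `train_df` and `eval_df`.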
Training¶
Here we fine-tune a distilbert model according to the instructions in the simpletransformers documentation.
[ ]:
# Configure the model
from simpletransformers.ner import NERModel, NERArgs
model_args = NERArgs()
model_args.train_batch_size = 16
model_args.evaluate_during_training = True
custom_labels = [ "O", "PER", "ORG", "LOC", "MISC" ]
model = NERModel(
"distilbert", "distilbert-base-cased", args=model_args, use_cuda=True, labels=custom_labels
)
# Train the model
model.train_model(train_df, eval_data=eval_df)
# Evaluate the model
result, model_outputs, preds_list = model.eval_model(eval_df)
After fine-tuning a distilbert model, our F1-score on the development set rises to about 52%.
[53]:
# Print the results
result
[53]:
{'eval_loss': 0.677833871929666,
'f1_score': 0.5242952373303349,
'precision': 0.4152671755725191,
'recall': 0.7109562186887388}