Now that almost everyone has thought about or is actively integrating AI workflows into their projects, some might ask: is this all worth the cost? Many think the current economics of the AI space don't scale and that prices will rise. Still others may not be comfortable sending their data to remote services for processing. Then there are those who want to deploy models in small spaces with limited compute.
Are there ways to deploy small models locally and run them at a lower cost? Yes, with knowledge distillation. Knowledge distillation can get a bad rap due to its questionable use in training some Large Language Models (LLMs), but it's a perfectly valid way to transfer performance from a larger model to a smaller one, especially when both models are yours and/or open.
This article explores progressive distillation, a technique that incrementally transfers knowledge from a series of larger teacher models into a smaller student.
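Structurally, progressive distillation is a chain of ordinary teacher/student runs in which each trained student becomes the teacher for the next, smaller model. Here is a minimal sketch of that pairing logic; progressive_stages and the placeholder model names are hypothetical and not part of txtai.

```python
from typing import List, Tuple


def progressive_stages(models: List[str]) -> List[Tuple[str, str]]:
    """Return (teacher, student) pairs for a progressive distillation chain.

    `models` is ordered from largest to smallest. Each student, once trained,
    serves as the teacher for the next stage.
    """
    if len(models) < 2:
        raise ValueError("need at least one teacher and one student")
    # Pair every model with its immediate successor in the chain
    return list(zip(models[:-1], models[1:]))


# Hypothetical chain, largest to smallest
chain = ["teacher-large", "student-11M", "student-4.4M", "student-250K"]
for teacher, student in progressive_stages(chain):
    print(f"distill {teacher} -> {student}")
```

In this article, each pair maps to one train call, with the previous run's output directory used as the next teacher.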
Install txtai and all dependencies.
pip install txtai[pipeline-train] datasets
The first step is to set up the training pipeline. We'll use txtai's HFTrainer pipeline, which is built on the Hugging Face Trainer, to build a series of models.
The following code defines a train method and a test method, then loads the SST-2 classification dataset.
from datasets import load_dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from txtai.pipeline import HFTrainer, Labels

def train(teacher, student, distillation, **kwargs):
    # Load the student model and tokenizer that will be trained
    trainer = HFTrainer()
    model = AutoModelForSequenceClassification.from_pretrained(student, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(student, trust_remote_code=True)

    # Train the student on SST-2; teacher and distillation configure knowledge distillation
    # (passing None for both trains the student directly, without a teacher)
    return trainer(
        (model, tokenizer),
        ds["train"],
        columns=("sentence", "label"),
        maxlength=maxlength,
        teacher=teacher,
        distillation=distillation,
        **kwargs
    )

def test(model):
    # Accuracy on the SST-2 validation split: compare the top predicted label to the gold label
    labels = Labels(model, dynamic=False, trust_remote_code=True)
    results = [row["label"] == labels(row["sentence"], max_length=maxlength)[0][0] for row in ds["validation"]]
    print(sum(results) / len(ds["validation"]))

# Load the GLUE SST-2 sentiment dataset and set the maximum sequence length
ds = load_dataset("nyu-mll/glue", "sst2")
maxlength = 128
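The train method hands a teacher and a distillation setting to HFTrainer, but the article doesn't spell out which loss HFTrainer computes or what the two values in (1.0, 1.0) mean. For intuition only, here is a minimal pure-Python sketch of the classic Hinton et al. distillation loss, which blends a temperature-softened KL term against the teacher with ordinary cross-entropy on the gold label. This is a swapped-in textbook formulation, not txtai's implementation; the function names, alpha and temperature are assumptions for illustration.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float], temperature: float = 1.0) -> List[float]:
    # Work in log space (with the max subtracted) so large logits don't overflow
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    log_total = m + math.log(sum(math.exp(x - m) for x in scaled))
    return [x - log_total for x in scaled]


def distillation_loss(student_logits: Sequence[float], teacher_logits: Sequence[float],
                      label: int, alpha: float = 0.5, temperature: float = 2.0) -> float:
    """Hinton-style distillation loss for one example (illustrative, not txtai's).

    alpha weights the soft (teacher) term and 1 - alpha weights the hard-label
    cross-entropy. Both defaults are assumptions chosen for the example.
    """
    # Soft term: KL(teacher || student) on temperature-softened distributions
    log_t = log_softmax(teacher_logits, temperature)
    log_s = log_softmax(student_logits, temperature)
    kl = sum(math.exp(lt) * (lt - ls) for lt, ls in zip(log_t, log_s))
    # Multiply by T^2 to keep the soft term's gradient scale comparable across temperatures
    soft = kl * temperature ** 2
    # Hard term: standard cross-entropy against the gold label at T = 1
    hard = -log_softmax(student_logits)[label]
    return alpha * soft + (1 - alpha) * hard


# A student that agrees with the teacher vs. one that disagrees (hypothetical logits)
agree = distillation_loss([2.0, -1.0], [2.0, -1.0], label=0)
disagree = distillation_loss([-1.0, 2.0], [2.0, -1.0], label=0)
print(agree < disagree)  # True
```

The soft term is what lets a student learn from the teacher's full output distribution rather than only the hard labels.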
We're going to build a bert-hash-femto classifier, an extremely small 250K parameter model. This model was pretrained using the same recipe as BERT. The paper Well-Read Students Learn Better: On the Importance of Pre-training Compact Models showed that compact models benefit more from knowledge distillation when they are pretrained first. This article is also a good source for more information on the topic.
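The next series of steps distills the assemblyai/bert-large-uncased-sst2 teacher into google/bert_uncased_L-4_H-256_A-4 (bert-mini), then distills that bert-mini model into google/bert_uncased_L-2_H-128_A-2 (bert-tiny), and finally distills bert-tiny into neuml/bert-hash-femto. Each trained student becomes the teacher for the next step.
This path goes from an 11M parameter model -> 4.4M parameter model -> 250K parameter model.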
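The 11M and 4.4M figures follow from standard BERT configuration arithmetic. The sketch below counts the parameters of a BERT encoder plus pooler, assuming the uncased WordPiece vocabulary of 30,522 tokens, 512 positions, 2 token types and a feed-forward size of 4x the hidden size; it excludes the small classification head. bert_params is a hypothetical helper, and it does not cover neuml/bert-hash-femto, which loads as a different model class.

```python
def bert_params(layers: int, hidden: int, vocab: int = 30522,
                max_positions: int = 512, type_vocab: int = 2) -> int:
    """Approximate parameter count of a standard BERT encoder plus pooler.

    Assumes intermediate size = 4 * hidden and excludes any task head.
    """
    ln = 2 * hidden  # LayerNorm weight + bias
    # Token, position and token-type embedding tables, plus their LayerNorm
    embeddings = (vocab + max_positions + type_vocab) * hidden + ln
    # Q, K, V and output projections (weights + biases)
    attention = 4 * (hidden * hidden + hidden)
    # Feed-forward: hidden -> 4*hidden -> hidden (weights + biases)
    ffn = (hidden * 4 * hidden + 4 * hidden) + (4 * hidden * hidden + hidden)
    layer = attention + ln + ffn + ln
    pooler = hidden * hidden + hidden
    return embeddings + layers * layer + pooler


mini = bert_params(layers=4, hidden=256)  # google/bert_uncased_L-4_H-256_A-4, ~11.2M
tiny = bert_params(layers=2, hidden=128)  # google/bert_uncased_L-2_H-128_A-2, ~4.4M
print(f"mini: {mini:,}  tiny: {tiny:,}")
print(f"embedding share (tiny): {(30522 + 512 + 2) * 128 / tiny:.0%}")
```

Under these assumptions, roughly 90% of the 4.4M model's parameters sit in its embedding tables, so the transformer layers themselves are quite small.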
test(train(
"assemblyai/bert-large-uncased-sst2",
"google/bert_uncased_L-4_H-256_A-4",
(1.0, 1.0),
learning_rate=1e-4,
num_train_epochs=5,
per_device_train_batch_size=32,
output_dir="bert-mini-sst2"
))
[transformers] BertForSequenceClassification LOAD REPORT from: google/bert_uncased_L-4_H-256_A-4
Key | Status |
-------------------------------------------+------------+-
cls.predictions.transform.dense.weight | UNEXPECTED |
cls.predictions.decoder.bias | UNEXPECTED |
cls.predictions.transform.LayerNorm.bias | UNEXPECTED |
cls.predictions.decoder.weight | UNEXPECTED |
cls.seq_relationship.bias | UNEXPECTED |
cls.seq_relationship.weight | UNEXPECTED |
cls.predictions.transform.dense.bias | UNEXPECTED |
cls.predictions.bias | UNEXPECTED |
cls.predictions.transform.LayerNorm.weight | UNEXPECTED |
classifier.weight | MISSING |
classifier.bias | MISSING |
Notes:
- UNEXPECTED: can be ignored when from different task/architecture; not ok if you expect identical arch.
- MISSING: those params were newly initialized because missing from the checkpoint. Consider training on your downstream task.
0.875
test(train(
"bert-mini-sst2",
"google/bert_uncased_L-2_H-128_A-2",
(1.0, 1.0),
learning_rate=1e-4,
num_train_epochs=5,
per_device_train_batch_size=32,
output_dir="bert-tiny-sst2"
))
[transformers] BertForSequenceClassification LOAD REPORT from: google/bert_uncased_L-2_H-128_A-2
Key | Status |
-------------------------------------------+------------+-
cls.predictions.transform.dense.weight | UNEXPECTED |
cls.predictions.transform.LayerNorm.bias | UNEXPECTED |
cls.seq_relationship.weight | UNEXPECTED |
cls.seq_relationship.bias | UNEXPECTED |
cls.predictions.transform.dense.bias | UNEXPECTED |
cls.predictions.bias | UNEXPECTED |
cls.predictions.transform.LayerNorm.weight | UNEXPECTED |
classifier.weight | MISSING |
classifier.bias | MISSING |
Notes:
- UNEXPECTED: can be ignored when from different task/architecture; not ok if you expect identical arch.
- MISSING: those params were newly initialized because missing from the checkpoint. Consider training on your downstream task.
0.8325688073394495
test(train(
"bert-tiny-sst2",
"neuml/bert-hash-femto",
(1.0, 1.0),
learning_rate=3e-4,
num_train_epochs=5,
per_device_train_batch_size=32,
output_dir="bert-femto-sst2"
))
[transformers] BertHashForSequenceClassification LOAD REPORT from: neuml/bert-hash-femto
Key | Status |
-------------------------+---------+-
bert.pooler.dense.bias | MISSING |
bert.pooler.dense.weight | MISSING |
classifier.bias | MISSING |
classifier.weight | MISSING |
Notes:
- MISSING: those params were newly initialized because missing from the checkpoint. Consider training on your downstream task.
0.8084862385321101
Let's look at the performance. The 11M parameter model scored an accuracy of 0.8750 on the SST-2 dev set, the 4.4M parameter model scored 0.8326 and the 250K parameter femto model scored 0.8085. Each score is lower than the last, which is expected since each successive model has less capacity. Let's train a femto model directly, without distillation, to compare.
test(train(
None,
"neuml/bert-hash-femto",
None,
learning_rate=3e-4,
num_train_epochs=5,
per_device_train_batch_size=32,
output_dir="bert-femto-sst2-direct"
))
[transformers] BertHashForSequenceClassification LOAD REPORT from: neuml/bert-hash-femto
Key | Status |
-------------------------+---------+-
bert.pooler.dense.bias | MISSING |
bert.pooler.dense.weight | MISSING |
classifier.bias | MISSING |
classifier.weight | MISSING |
Notes:
- MISSING: those params were newly initialized because missing from the checkpoint. Consider training on your downstream task.
0.801605504587156
This scored 0.8016, about 0.7 points below the progressively distilled femto model (0.8085). Keep in mind that the femto model has only 250K parameters, so progressive distillation gave it a meaningful accuracy boost within its limited capacity.
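Because test divides the number of correct predictions by the size of the validation split, each printed score maps back to a whole number of examples. Assuming the GLUE SST-2 validation split has 872 rows, this sketch converts the four scores into counts; the dictionary keys are descriptive labels, not model ids.

```python
# SST-2 validation accuracies printed by the runs above
scores = {
    "mini (11M, distilled)": 0.875,
    "tiny (4.4M, distilled)": 0.8325688073394495,
    "femto (250K, progressive)": 0.8084862385321101,
    "femto (250K, direct)": 0.801605504587156,
}
VALIDATION_SIZE = 872  # assumed number of rows in the GLUE SST-2 validation split

# Accuracy * split size recovers the number of correctly classified sentences
correct = {name: round(acc * VALIDATION_SIZE) for name, acc in scores.items()}
for name, n in correct.items():
    print(f"{name}: {n}/{VALIDATION_SIZE}")

gain = correct["femto (250K, progressive)"] - correct["femto (250K, direct)"]
retained = scores["femto (250K, progressive)"] / scores["mini (11M, distilled)"]
print(f"progressive gain: {gain} examples; femto retains {retained:.1%} of mini's accuracy")
```

Under that assumption, progressive distillation gets 6 more validation sentences right than direct training, and the femto model keeps about 92% of the 11M model's accuracy with roughly 1/44 of its parameters.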
This example showed how progressive distillation can boost overall model performance, especially for tiny models. Incrementally compressing knowledge through a series of smaller models enabled the final model to reach higher accuracy than the same model trained directly on the dataset without distillation. Put progressive distillation in the toolkit when working with tiny models!