KimCNN Model for Multi-label Classification
This step-by-step example shows how to train and test a KimCNN model via LibMultiLabel.
Import the libraries
Please add the following code to your Python script.
from libmultilabel.nn.data_utils import *
from libmultilabel.nn.nn_utils import *
Setup device
If you need to reproduce the results, please use the function set_seed. For example, using the seed 1337 always gives the same results. To initialize the hardware device, please use init_device to assign the device that you want to use.
set_seed(1337)
device = init_device() # use gpu by default
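If you want to run on the CPU instead (for example, on a machine without a GPU), recent LibMultiLabel versions let init_device take a use_cpu flag; treat the exact argument name as an assumption and check the API of your installed version. A minimal sketch:

# Force CPU execution; the use_cpu flag is assumed from recent LibMultiLabel versions.
device = init_device(use_cpu=True)
print(device)  # expected to show "cpu" here, or "cuda" when a GPU is used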
Load and tokenize data
To run KimCNN, LibMultiLabel tokenizes documents and uses an embedding vector for each word, so tokenize_text=True is set. We choose glove.6B.300d from torchtext as the embedding vectors.
datasets = load_datasets("data/rcv1/train.txt", "data/rcv1/test.txt", tokenize_text=True)
classes = load_or_build_label(datasets)
word_dict, embed_vecs = load_or_build_text_dict(dataset=datasets["train"], embed_file="glove.6B.300d")
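As a quick sanity check, you can inspect what was loaded. The snippet below is a minimal sketch that assumes each split in datasets is list-like, classes is a list of label names, and embed_vecs is a tensor whose second dimension matches the 300-dimensional GloVe vectors:

# Hedged sanity check; the shape and length assumptions are noted in the comments.
print(f"training examples: {len(datasets['train'])}")        # assumes each split is list-like
print(f"number of labels: {len(classes)}")
print(f"embedding matrix shape: {tuple(embed_vecs.shape)}")   # assumed to be (vocab_size, 300)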
Initialize a model
We consider the following settings for the KimCNN model.
model_name = "KimCNN"
network_config = {
    "embed_dropout": 0.2,
    "post_encoder_dropout": 0.2,
    "filter_sizes": [2, 4, 8],
    "num_filter_per_size": 128,
}
learning_rate = 0.0003
model = init_model(
    model_name=model_name,
    network_config=network_config,
    classes=classes,
    word_dict=word_dict,
    embed_vecs=embed_vecs,
    learning_rate=learning_rate,
    monitor_metrics=["Micro-F1", "Macro-F1", "P@1", "P@3", "P@5"],
)
In the above code:

- model_name leads the init_model function to find the corresponding network (KimCNN in this case).
- network_config contains the configurations of the network.
- classes is the label set of the data.
- word_dict and embed_vecs are the word dictionary and embedding vectors built in the previous step.
- init_weight is not specified here, so the default weight initialization is used.
- monitor_metrics includes the metrics you would like to track.
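Because the returned model is a PyTorch module, printing it shows the layers built from network_config, which is a quick way to confirm the filter sizes and dropout settings took effect:

# Print the module tree to verify the KimCNN layers match network_config.
print(model)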
Initialize a trainer
We use the function init_trainer
to initialize a trainer.
trainer = init_trainer(checkpoint_dir="runs/NN-example", epochs=15, val_metric="P@5")
In this example, checkpoint_dir is where we save the best and the last models during training. Furthermore, we set the number of training epochs by epochs=15 and the validation metric by val_metric='P@5'.
Create data loaders
In most cases, the whole data set cannot be loaded at once due to hardware limitations, so a data loader loads a batch of samples at a time.
loaders = dict()
for split in ["train", "val", "test"]:
    loaders[split] = get_dataset_loader(
        data=datasets[split],
        classes=classes,
        device=device,
        max_seq_length=512,
        batch_size=8,
        shuffle=True if split == "train" else False,
        word_dict=word_dict,
    )
This example creates three data loaders, with the batch size set by batch_size=8. Other parameters can be found in the API document.
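To see what a loader actually feeds the model, you can pull one batch. This is a hedged sketch: it assumes each batch is a dictionary of tensors, and the exact keys (such as text or label) depend on the LibMultiLabel version:

# Peek at one training batch; the dictionary layout is an assumption and may differ by version.
batch = next(iter(loaders["train"]))
for key, value in batch.items():
    shape = tuple(value.shape) if hasattr(value, "shape") else type(value)
    print(key, shape)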
Train and test a model
The KimCNN model training process can be started via
trainer.fit(model, loaders["train"], loaders["val"])
After the training process is finished, we can then run the test process by
trainer.test(model, dataloaders=loaders["test"])
The test results should be similar to:
{
    'Macro-F1': 0.48948464335831743,
    'Micro-F1': 0.7769773602485657,
    'P@1': 0.9471677541732788,
    'P@3': 0.7772253751754761,
    'P@5': 0.5449321269989014,
}
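If you want to keep the metrics programmatically instead of reading them from the log, trainer.test follows the usual PyTorch Lightning convention of also returning the results. A minimal sketch, assuming the return value is a list containing one metric dictionary:

import json

# Assumes trainer.test returns a list with one dict of metrics (PyTorch Lightning convention).
test_results = trainer.test(model, dataloaders=loaders["test"])
with open("test_results.json", "w") as f:
    json.dump(test_results[0], f, indent=2)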