.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/plot_KimCNN_quickstart.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_plot_KimCNN_quickstart.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_plot_KimCNN_quickstart.py:

KimCNN Model for Multi-label Classification
===========================================

This step-by-step example shows how to train and test a KimCNN model via LibMultiLabel.

Import the libraries
----------------------------

Please add the following code to your Python 3 script.

.. GENERATED FROM PYTHON SOURCE LINES 13-17

.. code-block:: Python

    from libmultilabel.nn.data_utils import *
    from libmultilabel.nn.nn_utils import *

.. GENERATED FROM PYTHON SOURCE LINES 18-24

Setup device
--------------------

To reproduce the results, use the function ``set_seed``. For example, you will
always get the same results when the seed is fixed to ``1337``. To initialize a
hardware device, use ``init_device`` to assign the device you want to use.

.. GENERATED FROM PYTHON SOURCE LINES 24-28

.. code-block:: Python

    set_seed(1337)
    device = init_device()  # use gpu by default

.. GENERATED FROM PYTHON SOURCE LINES 29-36

Load and tokenize data
------------------------------------------

To run KimCNN, LibMultiLabel tokenizes documents and uses an embedding vector
for each word. Thus, ``tokenize_text=True`` is set.
We choose ``glove.6B.300d`` from torchtext as embedding vectors.

.. GENERATED FROM PYTHON SOURCE LINES 36-41

.. code-block:: Python

    datasets = load_datasets("data/rcv1/train.txt", "data/rcv1/test.txt", tokenize_text=True)
    classes = load_or_build_label(datasets)
    word_dict, embed_vecs = load_or_build_text_dict(dataset=datasets["train"], embed_file="glove.6B.300d")

.. GENERATED FROM PYTHON SOURCE LINES 42-46

Initialize a model
--------------------------

We consider the following settings for the KimCNN model.

.. GENERATED FROM PYTHON SOURCE LINES 46-65

.. code-block:: Python

    model_name = "KimCNN"
    network_config = {
        "embed_dropout": 0.2,
        "post_encoder_dropout": 0.2,
        "filter_sizes": [2, 4, 8],
        "num_filter_per_size": 128,
    }
    learning_rate = 0.0003
    model = init_model(
        model_name=model_name,
        network_config=network_config,
        classes=classes,
        word_dict=word_dict,
        embed_vecs=embed_vecs,
        learning_rate=learning_rate,
        monitor_metrics=["Micro-F1", "Macro-F1", "P@1", "P@3", "P@5"],
    )

.. GENERATED FROM PYTHON SOURCE LINES 66-77

* ``model_name`` tells the ``init_model`` function which network model to use.
* ``network_config`` contains the configuration of the network model.
* ``classes`` is the label set of the data.
* ``word_dict`` and ``embed_vecs``, built in the previous step, map each
  tokenized word to its embedding vector.
* ``monitor_metrics`` includes the metrics you would like to track.

Initialize a trainer
----------------------------

We use the function ``init_trainer`` to initialize a trainer.

.. GENERATED FROM PYTHON SOURCE LINES 77-80

.. code-block:: Python

    trainer = init_trainer(checkpoint_dir="runs/NN-example", epochs=15, val_metric="P@5")

.. GENERATED FROM PYTHON SOURCE LINES 81-88

In this example, ``checkpoint_dir`` is where the best and the last models are
saved during training. Furthermore, we set the number of training epochs by
``epochs=15`` and the validation metric by ``val_metric='P@5'``.
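If you want to reuse the saved model later, the checkpoint can be reloaded
through the underlying PyTorch Lightning interface. The sketch below is a
minimal, hypothetical example rather than part of the original quickstart: it
assumes that ``init_trainer`` returns a regular Lightning trainer and that the
model class used by ``init_model`` is exposed as
``libmultilabel.nn.model.Model``; verify both against your installed
LibMultiLabel version.

.. code-block:: Python

    # A minimal sketch (not part of the original example): reload the best
    # checkpoint written under checkpoint_dir once training has finished.
    # Assumption: libmultilabel.nn.model.Model is the LightningModule that
    # init_model builds; check the import path for your version.
    from libmultilabel.nn.model import Model

    # A standard Lightning trainer records where the best checkpoint lives.
    best_path = trainer.checkpoint_callback.best_model_path
    loaded_model = Model.load_from_checkpoint(best_path)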
Create data loaders
---------------------------

In most cases, the full dataset cannot be loaded at once due to hardware
limitations, so a data loader loads a batch of samples at a time.

.. GENERATED FROM PYTHON SOURCE LINES 88-101

.. code-block:: Python

    loaders = dict()
    for split in ["train", "val", "test"]:
        loaders[split] = get_dataset_loader(
            data=datasets[split],
            classes=classes,
            device=device,
            max_seq_length=512,
            batch_size=8,
            shuffle=True if split == "train" else False,
            word_dict=word_dict,
        )

.. GENERATED FROM PYTHON SOURCE LINES 102-108

This example creates three loaders (train, validation, and test), and the
batch size is set by ``batch_size=8``. Other parameters are documented
`here <../api/nn.html#libmultilabel.nn.data_utils.get_dataset_loader>`_.

Train and test a model
------------------------------

The KimCNN model training process can be started via

.. GENERATED FROM PYTHON SOURCE LINES 108-111

.. code-block:: Python

    trainer.fit(model, loaders["train"], loaders["val"])

.. GENERATED FROM PYTHON SOURCE LINES 112-113

After the training process is finished, we can then run the test process by

.. GENERATED FROM PYTHON SOURCE LINES 113-116

.. code-block:: Python

    trainer.test(model, dataloaders=loaders["test"])

.. GENERATED FROM PYTHON SOURCE LINES 117-126

The test results should be similar to::

    {
        'Macro-F1': 0.48948464335831743,
        'Micro-F1': 0.7769773602485657,
        'P@1': 0.9471677541732788,
        'P@3': 0.7772253751754761,
        'P@5': 0.5449321269989014,
    }

.. _sphx_glr_download_auto_examples_plot_KimCNN_quickstart.py:

.. only:: html

    .. container:: sphx-glr-footer sphx-glr-footer-example

        .. container:: sphx-glr-download sphx-glr-download-jupyter

            :download:`Download Jupyter notebook: plot_KimCNN_quickstart.ipynb <plot_KimCNN_quickstart.ipynb>`

        .. container:: sphx-glr-download sphx-glr-download-python

            :download:`Download Python source code: plot_KimCNN_quickstart.py <plot_KimCNN_quickstart.py>`

.. only:: html

    .. rst-class:: sphx-glr-signature

        `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_