Neural Network API
The neural network module libmultilabel.nn contains three submodules. libmultilabel.nn.networks is a collection of classes that define the neural networks. The other two, libmultilabel.nn.data_utils and libmultilabel.nn.nn_utils, are utilities for processing data and training a neural network model.
libmultilabel.nn.data_utils
libmultilabel.nn.nn_utils
- libmultilabel.nn.nn_utils.init_device(use_cpu=False)
Initialize the device: CPU if use_cpu is True, otherwise GPU.
- Parameters
use_cpu (bool, optional) – Whether to use CPU or not. Defaults to False.
- Returns
One of cuda or cpu.
- Return type
torch.device
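A minimal usage sketch for init_device, assuming LibMultiLabel and PyTorch are installed. The guarded import and the variable name device are illustrative assumptions, added so the snippet still runs where the library is missing.

```python
# Sketch: choose the compute device via init_device.
# The guarded import is an assumption made so the snippet degrades gracefully
# when LibMultiLabel is not installed; it is not part of the library.
try:
    from libmultilabel.nn.nn_utils import init_device
except ImportError:
    init_device = None

device = None
if init_device is not None:
    # use_cpu=True forces the CPU; the default (False) selects the GPU.
    device = init_device(use_cpu=True)
    print(device)  # a torch.device
```

The returned torch.device can then be used wherever PyTorch expects a device, for example when moving tensors with .to(device).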
- libmultilabel.nn.nn_utils.init_model(model_name, network_config, classes, word_dict=None, embed_vecs=None, init_weight=None, log_path=None, learning_rate=0.0001, optimizer='adam', momentum=0.9, weight_decay=0, lr_scheduler=None, scheduler_config=None, val_metric=None, metric_threshold=0.5, monitor_metrics=None, multiclass=False, loss_function='binary_cross_entropy_with_logits', silent=False, save_k_predictions=0)
Initialize a Model instance for setting up and training a neural network.
- Parameters
model_name (str) – Model to be used such as KimCNN.
network_config (dict) – Configuration for defining the network.
classes (list) – List of class names.
word_dict (torchtext.vocab.Vocab, optional) – A vocab object for word tokenizer to map tokens to indices. Defaults to None.
embed_vecs (torch.Tensor, optional) – The pre-trained word vectors of shape (vocab_size, embed_dim). Defaults to None.
init_weight (str, optional) – Name of a weight initialization method from torch.nn.init, without the trailing underscore. For example, use kaiming_uniform for torch.nn.init.kaiming_uniform_. Defaults to None.
log_path (str, optional) – Path to a directory holding the log files and models. Defaults to None.
learning_rate (float, optional) – Learning rate for optimizer. Defaults to 0.0001.
optimizer (str, optional) – Optimizer name (i.e., sgd, adam, or adamw). Defaults to ‘adam’.
momentum (float, optional) – Momentum factor for SGD only. Defaults to 0.9.
weight_decay (int, optional) – Weight decay factor. Defaults to 0.
lr_scheduler (str, optional) – Name of the learning rate scheduler. Defaults to None.
scheduler_config (dict, optional) – The configuration for learning rate scheduler. Defaults to None.
val_metric (str, optional) – The metric to select the best model for testing. Used by some of the schedulers. Defaults to None.
metric_threshold (float, optional) – The decision value threshold over which a label is predicted as positive. Defaults to 0.5.
monitor_metrics (list, optional) – Metrics to monitor while validating. Defaults to None.
multiclass (bool, optional) – Enable multiclass mode. Defaults to False.
silent (bool, optional) – Enable silent mode. Defaults to False.
loss_function (str, optional) – Loss function name (i.e., binary_cross_entropy_with_logits, cross_entropy). Defaults to ‘binary_cross_entropy_with_logits’.
save_k_predictions (int, optional) – Save top k predictions on test set. Defaults to 0.
- Returns
A Model instance, which implements MultiLabelModel, for initializing and training a neural network.
- Return type
Model
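The parameters above can be gathered into one dictionary before calling init_model. The sketch below is illustrative only: build_init_model_kwargs is a hypothetical helper, not part of LibMultiLabel, and the class names and empty network_config are placeholders. It checks optimizer and loss_function against the values listed in the parameter descriptions.

```python
# Documented choices for the two string options of init_model.
OPTIMIZERS = {"sgd", "adam", "adamw"}
LOSS_FUNCTIONS = {"binary_cross_entropy_with_logits", "cross_entropy"}


def build_init_model_kwargs(model_name, network_config, classes, **overrides):
    """Hypothetical helper: collect keyword arguments for init_model.

    Only the values listed in the parameter descriptions are accepted for
    optimizer and loss_function; everything else is passed through as given.
    """
    kwargs = {
        "model_name": model_name,
        "network_config": network_config,
        "classes": classes,
        "learning_rate": 0.0001,
        "optimizer": "adam",
        "loss_function": "binary_cross_entropy_with_logits",
        "metric_threshold": 0.5,
    }
    kwargs.update(overrides)
    if kwargs["optimizer"] not in OPTIMIZERS:
        raise ValueError(f"unknown optimizer: {kwargs['optimizer']!r}")
    if kwargs["loss_function"] not in LOSS_FUNCTIONS:
        raise ValueError(f"unknown loss function: {kwargs['loss_function']!r}")
    return kwargs


model_kwargs = build_init_model_kwargs(
    model_name="KimCNN",
    network_config={},               # placeholder; real settings depend on the network
    classes=["sports", "politics"],  # made-up label names
    optimizer="adamw",
    val_metric="P@1",
    monitor_metrics=["P@1"],
)
print(sorted(model_kwargs))
```

With the library installed, the dictionary could be passed as init_model(**model_kwargs) once network_config is filled in; a word-based model such as KimCNN would likely also need word_dict and embed_vecs.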
- libmultilabel.nn.nn_utils.init_trainer(checkpoint_dir, epochs=10000, patience=5, early_stopping_metric='P@1', val_metric='P@1', silent=False, use_cpu=False, limit_train_batches=1.0, limit_val_batches=1.0, limit_test_batches=1.0, save_checkpoints=True)
Initialize a PyTorch Lightning trainer.
- Parameters
checkpoint_dir (str) – Directory for saving models and logs.
epochs (int) – Number of epochs to train. Defaults to 10000.
patience (int) – Number of epochs to wait for improvement before early stopping. Defaults to 5.
early_stopping_metric (str) – The metric to monitor for early stopping. Defaults to ‘P@1’.
val_metric (str) – The metric to select the best model for testing. Defaults to ‘P@1’.
silent (bool) – Enable silent mode. Defaults to False.
use_cpu (bool) – Disable CUDA. Defaults to False.
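limit_train_batches (Union[int, float]) – Amount of the training dataset to use: a float is a fraction (1.0 means all batches) and an int is a number of batches, following the PyTorch Lightning Trainer argument of the same name. Defaults to 1.0.
limit_val_batches (Union[int, float]) – Amount of the validation dataset to use, interpreted the same way as limit_train_batches. Defaults to 1.0.
limit_test_batches (Union[int, float]) – Amount of the test dataset to use, interpreted the same way as limit_train_batches. Defaults to 1.0.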
save_checkpoints (bool) – Whether to save the last and the best checkpoint or not. Defaults to True.
- Returns
A PyTorch Lightning trainer.
- Return type
lightning.trainer
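A short sketch of how the init_trainer arguments might be assembled for a quick CPU-only run. The helper make_trainer, the directory runs/demo, and the chosen values are assumptions for illustration. The import is deferred into the helper so the configuration can be inspected without LibMultiLabel installed.

```python
# Parameter names taken from the init_trainer signature above.
DOCUMENTED_PARAMS = {
    "checkpoint_dir", "epochs", "patience", "early_stopping_metric",
    "val_metric", "silent", "use_cpu", "limit_train_batches",
    "limit_val_batches", "limit_test_batches", "save_checkpoints",
}

trainer_kwargs = {
    "checkpoint_dir": "runs/demo",    # hypothetical output directory
    "epochs": 50,
    "patience": 5,                    # stop after 5 epochs without improvement
    "early_stopping_metric": "P@1",
    "val_metric": "P@1",
    "use_cpu": True,                  # disable CUDA
    "limit_train_batches": 0.1,       # use only part of the training set
    "save_checkpoints": False,
}

# Guard against typos in the keyword names before anything is constructed.
unknown = set(trainer_kwargs) - DOCUMENTED_PARAMS


def make_trainer(kwargs):
    """Hypothetical wrapper: build the trainer (requires LibMultiLabel)."""
    from libmultilabel.nn.nn_utils import init_trainer  # imported lazily
    return init_trainer(**kwargs)
```

Calling make_trainer(trainer_kwargs) would then return the trainer, which exposes the usual PyTorch Lightning fit and test methods.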