import logging
import os
import lightning as L
import torch
from lightning.pytorch import seed_everything
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from ..nn import networks
from ..nn.model import Model


def init_device(use_cpu=False):
    """Initialize the device: GPU if `use_cpu` is False and CUDA is available, otherwise CPU.

    Args:
        use_cpu (bool, optional): Whether to force the use of the CPU. Defaults to False.

    Returns:
        torch.device: One of cuda or cpu.
    """
    if not use_cpu and torch.cuda.is_available():
        # Set the debug environment variable CUBLAS_WORKSPACE_CONFIG to ":16:8"
        # (may limit overall performance) or ":4096:8" (increases the library
        # footprint in GPU memory by approximately 24 MiB).
        # https://docs.nvidia.com/cuda/cublas/index.html
        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
        # https://github.com/pytorch/pytorch/issues/11201
        torch.multiprocessing.set_sharing_strategy("file_system")
    logging.info(f"Using device: {device}")
    return device
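
# Usage sketch (illustrative, not part of this module): `init_device` is
# typically called once before building the model so that tensors and the
# network can be moved to the selected device, e.g.:
#
#     device = init_device()          # cuda if available, else cpu
#     network = network.to(device)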


def init_model(
    model_name,
    network_config,
    classes,
    word_dict=None,
    embed_vecs=None,
    init_weight=None,
    log_path=None,
    learning_rate=0.0001,
    optimizer="adam",
    momentum=0.9,
    weight_decay=0,
    lr_scheduler=None,
    scheduler_config=None,
    val_metric=None,
    metric_threshold=0.5,
    monitor_metrics=None,
    multiclass=False,
    loss_function="binary_cross_entropy_with_logits",
    silent=False,
    save_k_predictions=0,
):
"""Initialize a `Model` class for initializing and training a neural network.
Args:
model_name (str): Model to be used such as KimCNN.
network_config (dict): Configuration for defining the network.
classes (list): List of class names.
word_dict (torchtext.vocab.Vocab, optional): A vocab object for word tokenizer to
map tokens to indices. Defaults to None.
embed_vecs (torch.Tensor, optional): The pre-trained word vectors of shape
(vocab_size, embed_dim). Defaults to None.
init_weight (str): Weight initialization method from `torch.nn.init`.
For example, the `init_weight` of `torch.nn.init.kaiming_uniform_`
is `kaiming_uniform`. Defaults to None.
log_path (str): Path to a directory holding the log files and models.
learning_rate (float, optional): Learning rate for optimizer. Defaults to 0.0001.
optimizer (str, optional): Optimizer name (i.e., sgd, adam, or adamw). Defaults to 'adam'.
momentum (float, optional): Momentum factor for SGD only. Defaults to 0.9.
weight_decay (int, optional): Weight decay factor. Defaults to 0.
lr_scheduler (str, optional): Name of the learning rate scheduler. Defaults to None.
scheduler_config (dict, optional): The configuration for learning rate scheduler. Defaults to None.
val_metric (str, optional): The metric to select the best model for testing. Used by some of the schedulers. Defaults to None.
metric_threshold (float, optional): The decision value threshold over which a label is predicted as positive. Defaults to 0.5.
monitor_metrics (list, optional): Metrics to monitor while validating. Defaults to None.
multiclass (bool, optional): Enable multiclass mode. Defaults to False.
silent (bool, optional): Enable silent mode. Defaults to False.
loss_function (str, optional): Loss function name (i.e., binary_cross_entropy_with_logits,
cross_entropy). Defaults to 'binary_cross_entropy_with_logits'.
save_k_predictions (int, optional): Save top k predictions on test set. Defaults to 0.
Returns:
Model: A class that implements `MultiLabelModel` for initializing and training a neural network.
"""
    try:
        network = getattr(networks, model_name)(
            embed_vecs=embed_vecs, num_classes=len(classes), **dict(network_config)
        )
    except Exception as e:
        raise AttributeError(f"Failed to initialize {model_name}.") from e

    if init_weight is not None:
        init_weight = networks.get_init_weight_func(init_weight=init_weight)
        network.apply(init_weight)
    model = Model(
        classes=classes,
        word_dict=word_dict,
        network=network,
        log_path=log_path,
        learning_rate=learning_rate,
        optimizer=optimizer,
        momentum=momentum,
        weight_decay=weight_decay,
        lr_scheduler=lr_scheduler,
        scheduler_config=scheduler_config,
        val_metric=val_metric,
        metric_threshold=metric_threshold,
        monitor_metrics=monitor_metrics,
        multiclass=multiclass,
        loss_function=loss_function,
        silent=silent,
        save_k_predictions=save_k_predictions,
    )
    return model
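
# Usage sketch (illustrative; the `network_config` keys below are assumptions
# for a KimCNN-style network, not defaults defined in this module):
#
#     model = init_model(
#         model_name="KimCNN",
#         network_config={"embed_dropout": 0.2, "filter_sizes": [2, 4, 8], "num_filter_per_size": 128},
#         classes=["label_a", "label_b"],
#         embed_vecs=embed_vecs,  # pre-trained vectors of shape (vocab_size, embed_dim)
#         monitor_metrics=["P@1", "P@5"],
#     )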


def init_trainer(
    checkpoint_dir,
    epochs=10000,
    patience=5,
    early_stopping_metric="P@1",
    val_metric="P@1",
    silent=False,
    use_cpu=False,
    limit_train_batches=1.0,
    limit_val_batches=1.0,
    limit_test_batches=1.0,
    save_checkpoints=True,
):
"""Initialize a torch lightning trainer.
Args:
checkpoint_dir (str): Directory for saving models and log.
epochs (int): Number of epochs to train. Defaults to 10000.
patience (int): Number of epochs to wait for improvement before early stopping. Defaults to 5.
early_stopping_metric (str): The metric to monitor for early stopping. Defaults to 'P@1'.
val_metric (str): The metric to select the best model for testing. Defaults to 'P@1'.
silent (bool): Enable silent mode. Defaults to False.
use_cpu (bool): Disable CUDA. Defaults to False.
limit_train_batches (Union[int, float]): Percentage of training dataset to use. Defaults to 1.0.
limit_val_batches (Union[int, float]): Percentage of validation dataset to use. Defaults to 1.0.
limit_test_batches (Union[int, float]): Percentage of test dataset to use. Defaults to 1.0.
save_checkpoints (bool): Whether to save the last and the best checkpoint or not. Defaults to True.
Returns:
lightning.trainer: A torch lightning trainer.
"""
    # `mode` is 'min' only when the monitored metric is 'Loss' because the other
    # supported metrics, such as F1 or Precision, are maximized during training.
    # If metrics that should be minimized are supported in the future, a
    # metric-to-mode mapping would be a better practice.
    # Set strict to False to prevent EarlyStopping from crashing the training
    # when no validation data are provided.
    early_stopping_callback = EarlyStopping(
        patience=patience,
        monitor=early_stopping_metric,
        mode="min" if early_stopping_metric == "Loss" else "max",
        strict=False,
    )
    callbacks = [early_stopping_callback]
    if save_checkpoints:
        callbacks += [
            ModelCheckpoint(
                dirpath=checkpoint_dir,
                filename="best_model",
                save_last=True,
                save_top_k=1,
                monitor=val_metric,
                mode="min" if val_metric == "Loss" else "max",
            )
        ]
    trainer = L.Trainer(
        logger=False,
        num_sanity_val_steps=0,
        accelerator="cpu" if use_cpu else "gpu",
        devices="auto" if use_cpu else 1,
        enable_progress_bar=not silent,
        max_epochs=epochs,
        callbacks=callbacks,
        limit_train_batches=limit_train_batches,
        limit_val_batches=limit_val_batches,
        limit_test_batches=limit_test_batches,
        deterministic="warn",
        gradient_clip_val=0.5,
        gradient_clip_algorithm="value",
    )
    return trainer
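
# Usage sketch (illustrative): a trainer that monitors P@1 for both early
# stopping and checkpoint selection; the directory path and dataloaders are
# assumptions, not values defined in this module:
#
#     trainer = init_trainer(checkpoint_dir="runs/example", epochs=50, patience=5)
#     trainer.fit(model, train_dataloader, val_dataloader)
#     trainer.test(model, dataloaders=test_dataloader)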


def set_seed(seed):
    """Set the random seeds for Python, NumPy, and PyTorch.

    Args:
        seed (int): Random seed.
    """
    if seed is not None:
        if seed >= 0:
            seed_everything(seed=seed, workers=True)
        else:
            logging.warning("The random seed should be a non-negative integer.")
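
# Usage sketch (illustrative): seeding before any model or trainer is created
# makes runs reproducible, complementing deterministic="warn" in init_trainer:
#
#     set_seed(1337)
#     device = init_device()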