# Source code for libmultilabel.nn.networks.labelwise_attention_networks

from abc import ABC, abstractmethod

import torch.nn as nn

from .modules import (
    Embedding,
    GRUEncoder,
    LSTMEncoder,
    CNNEncoder,
    LabelwiseAttention,
    LabelwiseMultiHeadAttention,
    LabelwiseLinearOutput,
    PartialLabelwiseAttention,
    MultilayerLinearOutput,
)


class LabelwiseAttentionNetwork(ABC, nn.Module):
    """Base class for Labelwise Attention Network

    Args:
        embed_vecs (torch.Tensor): The pre-trained word vectors of shape (vocab_size, embed_dim).
        num_classes (int): Total number of classes.
        embed_dropout (float): The dropout rate of the word embedding.
        encoder_dropout (float): The dropout rate of the encoder.
        post_encoder_dropout (float): The dropout rate of the dropout layer after the encoder.
        hidden_dim (int): The output dimension of the encoder.
    """

    def __init__(self, embed_vecs, num_classes, embed_dropout, encoder_dropout, post_encoder_dropout, hidden_dim):
        super(LabelwiseAttentionNetwork, self).__init__()
        self.embedding = Embedding(embed_vecs, dropout=embed_dropout)
        self.encoder = self._get_encoder(embed_vecs.shape[1], encoder_dropout, post_encoder_dropout)
        self.attention = self._get_attention()
        self.output = LabelwiseLinearOutput(hidden_dim, num_classes)

    @abstractmethod
    def forward(self, input):
        raise NotImplementedError

    @abstractmethod
    def _get_encoder(self, input_size, encoder_dropout, post_encoder_dropout):
        raise NotImplementedError

    @abstractmethod
    def _get_attention(self):
        raise NotImplementedError


class RNNLWAN(LabelwiseAttentionNetwork):
    """Base class for RNN Labelwise Attention Network"""

    def forward(self, input):
        # (batch_size, sequence_length, embed_dim)
        x = self.embedding(input["text"])
        # (batch_size, sequence_length, hidden_dim)
        x = self.encoder(x, input["length"])
        x, _ = self.attention(x)  # (batch_size, num_classes, hidden_dim)
        x = self.output(x)  # (batch_size, num_classes)
        return {"logits": x}


class BiGRULWAN(RNNLWAN):
    """BiGRU Labelwise Attention Network

    Args:
        embed_vecs (torch.Tensor): The pre-trained word vectors of shape (vocab_size, embed_dim).
        num_classes (int): Total number of classes.
        rnn_dim (int): The size of bidirectional hidden layers. The hidden size of the GRU network
            is set to rnn_dim//2. Defaults to 512.
        rnn_layers (int): The number of recurrent layers. Defaults to 1.
        embed_dropout (float): The dropout rate of the word embedding. Defaults to 0.2.
        encoder_dropout (float): The dropout rate of the encoder. Defaults to 0.
        post_encoder_dropout (float): The dropout rate of the dropout layer after the encoder. Defaults to 0.
    """

    def __init__(
        self,
        embed_vecs,
        num_classes,
        rnn_dim=512,
        rnn_layers=1,
        embed_dropout=0.2,
        encoder_dropout=0,
        post_encoder_dropout=0,
    ):
        self.num_classes = num_classes
        self.rnn_dim = rnn_dim
        self.rnn_layers = rnn_layers
        super(BiGRULWAN, self).__init__(
            embed_vecs,
            num_classes,
            embed_dropout,
            encoder_dropout,
            post_encoder_dropout,
            rnn_dim,
        )

    def _get_encoder(self, input_size, encoder_dropout, post_encoder_dropout):
        assert self.rnn_dim % 2 == 0, "`rnn_dim` should be even."
        return GRUEncoder(
            input_size,
            self.rnn_dim // 2,
            self.rnn_layers,
            encoder_dropout,
            post_encoder_dropout,
        )

    def _get_attention(self):
        return LabelwiseAttention(self.rnn_dim, self.num_classes)
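

# A minimal usage sketch (a hypothetical helper, not part of the library API).
# It assumes the batch layout implied by RNNLWAN.forward: "text" holds padded
# token ids with padding index 0 and "length" holds the true sequence lengths.
# All shapes and the vocabulary size below are illustrative. BiLSTMLWAN (next)
# takes exactly the same arguments, with an LSTM encoder in place of the GRU.
def _demo_bigru_lwan():
    import torch

    embed_vecs = torch.randn(5000, 300)  # (vocab_size, embed_dim) pre-trained vectors
    model = BiGRULWAN(embed_vecs, num_classes=10)
    text = torch.randint(1, 5000, (4, 50))  # (batch_size, sequence_length)
    lengths = [50, 42, 38, 20]
    for i, n in enumerate(lengths):
        text[i, n:] = 0  # zero-pad beyond each true length
    batch = {"text": text, "length": torch.tensor(lengths)}
    logits = model(batch)["logits"]
    assert logits.shape == (4, 10)  # (batch_size, num_classes)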


class BiLSTMLWAN(RNNLWAN):
    """BiLSTM Labelwise Attention Network

    Args:
        embed_vecs (torch.Tensor): The pre-trained word vectors of shape (vocab_size, embed_dim).
        num_classes (int): Total number of classes.
        rnn_dim (int): The size of bidirectional hidden layers. The hidden size of the LSTM network
            is set to rnn_dim//2. Defaults to 512.
        rnn_layers (int): The number of recurrent layers. Defaults to 1.
        embed_dropout (float): The dropout rate of the word embedding. Defaults to 0.2.
        encoder_dropout (float): The dropout rate of the encoder. Defaults to 0.
        post_encoder_dropout (float): The dropout rate of the dropout layer after the encoder. Defaults to 0.
    """

    def __init__(
        self,
        embed_vecs,
        num_classes,
        rnn_dim=512,
        rnn_layers=1,
        embed_dropout=0.2,
        encoder_dropout=0,
        post_encoder_dropout=0,
    ):
        self.num_classes = num_classes
        self.rnn_dim = rnn_dim
        self.rnn_layers = rnn_layers
        super(BiLSTMLWAN, self).__init__(
            embed_vecs,
            num_classes,
            embed_dropout,
            encoder_dropout,
            post_encoder_dropout,
            rnn_dim,
        )

    def _get_encoder(self, input_size, encoder_dropout, post_encoder_dropout):
        assert self.rnn_dim % 2 == 0, "`rnn_dim` should be even."
        return LSTMEncoder(input_size, self.rnn_dim // 2, self.rnn_layers, encoder_dropout, post_encoder_dropout)

    def _get_attention(self):
        return LabelwiseAttention(self.rnn_dim, self.num_classes)


class BiLSTMLWMHAN(LabelwiseAttentionNetwork):
    """BiLSTM Labelwise Multihead Attention Network

    Args:
        embed_vecs (torch.Tensor): The pre-trained word vectors of shape (vocab_size, embed_dim).
        num_classes (int): Total number of classes.
        rnn_dim (int): The size of bidirectional hidden layers. The hidden size of the LSTM network
            is set to rnn_dim//2. Defaults to 512.
        rnn_layers (int): The number of recurrent layers. Defaults to 1.
        embed_dropout (float): The dropout rate of the word embedding. Defaults to 0.2.
        encoder_dropout (float): The dropout rate of the encoder. Defaults to 0.
        post_encoder_dropout (float): The dropout rate of the dropout layer after the encoder. Defaults to 0.
        num_heads (int): The number of parallel attention heads. Defaults to 8.
        labelwise_attention_dropout (float): The dropout rate for the attention. Defaults to 0.
    """

    def __init__(
        self,
        embed_vecs,
        num_classes,
        rnn_dim=512,
        rnn_layers=1,
        embed_dropout=0.2,
        encoder_dropout=0,
        post_encoder_dropout=0,
        num_heads=8,
        labelwise_attention_dropout=0,
    ):
        self.num_classes = num_classes
        self.rnn_dim = rnn_dim
        self.rnn_layers = rnn_layers
        self.num_heads = num_heads
        self.labelwise_attention_dropout = labelwise_attention_dropout
        super(BiLSTMLWMHAN, self).__init__(
            embed_vecs,
            num_classes,
            embed_dropout,
            encoder_dropout,
            post_encoder_dropout,
            rnn_dim,
        )

    def _get_encoder(self, input_size, encoder_dropout, post_encoder_dropout):
        assert self.rnn_dim % 2 == 0, "`rnn_dim` should be even."
        return LSTMEncoder(input_size, self.rnn_dim // 2, self.rnn_layers, encoder_dropout, post_encoder_dropout)

    def _get_attention(self):
        return LabelwiseMultiHeadAttention(
            self.rnn_dim, self.num_classes, self.num_heads, self.labelwise_attention_dropout
        )

    def forward(self, input):
        # (batch_size, sequence_length, embed_dim)
        x = self.embedding(input["text"])
        # (batch_size, sequence_length, hidden_dim)
        x = self.encoder(x, input["length"])
        # (batch_size, num_classes, hidden_dim)
        x, _ = self.attention(x, attention_mask=input["text"] == 0)
        x = self.output(x)  # (batch_size, num_classes)
        return {"logits": x}
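

# A usage sketch for the multi-head variant (a hypothetical helper, not part of
# the library API). The difference from RNNLWAN.forward is the attention mask:
# positions whose token id equals the padding index 0 are excluded from every
# attention head, so "text" must be zero-padded exactly up to the longest
# sequence in the batch for the mask width to match the encoder output.
def _demo_bilstm_lwmhan():
    import torch

    embed_vecs = torch.randn(5000, 300)
    model = BiLSTMLWMHAN(embed_vecs, num_classes=10, num_heads=8)
    text = torch.randint(1, 5000, (4, 50))
    lengths = [50, 42, 38, 20]
    for i, n in enumerate(lengths):
        text[i, n:] = 0  # zeros mark padding; forward() builds the mask from them
    out = model({"text": text, "length": torch.tensor(lengths)})
    assert out["logits"].shape == (4, 10)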


class CNNLWAN(LabelwiseAttentionNetwork):
    """CNN Labelwise Attention Network

    Args:
        embed_vecs (torch.Tensor): The pre-trained word vectors of shape (vocab_size, embed_dim).
        num_classes (int): Total number of classes.
        filter_sizes (list): Size of convolutional filters.
        num_filter_per_size (int): The number of filters in convolutional layers in each size. Defaults to 50.
        embed_dropout (float): The dropout rate of the word embedding. Defaults to 0.2.
        post_encoder_dropout (float): The dropout rate of the encoder output. Defaults to 0.
        activation (str): Activation function to be used. Defaults to 'tanh'.
    """

    def __init__(
        self,
        embed_vecs,
        num_classes,
        filter_sizes=None,
        num_filter_per_size=50,
        embed_dropout=0.2,
        post_encoder_dropout=0,
        activation="tanh",
    ):
        self.num_classes = num_classes
        self.filter_sizes = filter_sizes
        self.num_filter_per_size = num_filter_per_size
        self.activation = activation
        self.hidden_dim = num_filter_per_size * len(filter_sizes)
        # encoder dropout is unused for CNN, we pass 0 to satisfy LabelwiseAttentionNetwork API
        super(CNNLWAN, self).__init__(embed_vecs, num_classes, embed_dropout, 0, post_encoder_dropout, self.hidden_dim)

    def _get_encoder(self, input_size, encoder_dropout, post_encoder_dropout):
        # encoder dropout is unused for CNN, we accept it to satisfy LabelwiseAttentionNetwork API
        return CNNEncoder(
            input_size,
            self.filter_sizes,
            self.num_filter_per_size,
            self.activation,
            post_encoder_dropout,
            channel_last=True,
        )

    def _get_attention(self):
        return LabelwiseAttention(self.hidden_dim, self.num_classes)

    def forward(self, input):
        # (batch_size, sequence_length, embed_dim)
        x = self.embedding(input["text"])
        x = self.encoder(x)  # (batch_size, sequence_length, hidden_dim)
        x, _ = self.attention(x)  # (batch_size, num_classes, hidden_dim)
        x = self.output(x)  # (batch_size, num_classes)
        return {"logits": x}
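

# A usage sketch for CNNLWAN (a hypothetical helper, not part of the library
# API). Note that `filter_sizes` must be supplied: the default of None would
# fail at len(filter_sizes) when computing hidden_dim. The CNN encoder does not
# pack sequences, so no "length" entry is required in the batch.
def _demo_cnn_lwan():
    import torch

    embed_vecs = torch.randn(5000, 300)
    model = CNNLWAN(embed_vecs, num_classes=10, filter_sizes=[2, 4, 8])
    batch = {"text": torch.randint(1, 5000, (4, 50))}
    # hidden_dim = num_filter_per_size * len(filter_sizes) = 50 * 3 = 150
    assert model(batch)["logits"].shape == (4, 10)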


class AttentionXML_0(nn.Module):
    def __init__(
        self,
        embed_vecs,
        num_classes: int,
        rnn_dim: int,
        linear_size: list,
        freeze_embed_training: bool = False,
        rnn_layers: int = 1,
        embed_dropout: float = 0.2,
        encoder_dropout: float = 0,
        post_encoder_dropout: float = 0.5,
    ):
        super().__init__()
        self.embedding = Embedding(embed_vecs, freeze=freeze_embed_training, dropout=embed_dropout)
        self.encoder = LSTMEncoder(embed_vecs.shape[1], rnn_dim // 2, rnn_layers, encoder_dropout, post_encoder_dropout)
        self.attention = LabelwiseAttention(rnn_dim, num_classes)
        self.output = MultilayerLinearOutput([rnn_dim] + linear_size, 1)

    def forward(self, inputs):
        x = inputs["text"]
        # the index of padding is 0
        masks = x != 0
        lengths = masks.sum(dim=1)
        masks = masks[:, : lengths.max()]
        x = self.embedding(x)[:, : lengths.max()]  # batch_size, length, embedding_size
        x = self.encoder(x, lengths)  # batch_size, length, hidden_size
        x, _ = self.attention(x)  # batch_size, num_classes, hidden_size
        x = self.output(x)  # batch_size, num_classes
        return {"logits": x}
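

# A usage sketch for the level-0 AttentionXML model (a hypothetical helper, not
# part of the library API). Unlike the classes above, lengths are derived inside
# forward() from the padding index 0, so the batch only needs "text". In the
# AttentionXML scheme, num_classes at this level corresponds to label clusters
# rather than individual labels.
def _demo_attentionxml_0():
    import torch

    embed_vecs = torch.randn(5000, 300)
    model = AttentionXML_0(embed_vecs, num_classes=64, rnn_dim=512, linear_size=[256])
    text = torch.randint(1, 5000, (4, 50))
    text[:, 30:] = 0  # forward() infers lengths and truncates to the longest sequence
    assert model({"text": text})["logits"].shape == (4, 64)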


class AttentionXML_1(nn.Module):
    def __init__(
        self,
        embed_vecs,
        num_classes: int,
        rnn_dim: int,
        linear_size: list,
        freeze_embed_training: bool = False,
        rnn_layers: int = 1,
        embed_dropout: float = 0.2,
        encoder_dropout: float = 0,
        post_encoder_dropout: float = 0.5,
    ):
        super().__init__()
        self.embedding = Embedding(embed_vecs, freeze=freeze_embed_training, dropout=embed_dropout)
        self.encoder = LSTMEncoder(embed_vecs.shape[1], rnn_dim // 2, rnn_layers, encoder_dropout, post_encoder_dropout)
        self.attention = PartialLabelwiseAttention(rnn_dim, num_classes)
        self.output = MultilayerLinearOutput([rnn_dim] + linear_size, 1)

    def forward(self, inputs):
        x = inputs["text"]
        labels_selected = inputs["labels_selected"]
        # the index of padding is 0
        masks = x != 0
        lengths = masks.sum(dim=1)
        masks = masks[:, : lengths.max()]
        x = self.embedding(x)[:, : lengths.max()]  # batch_size, length, embedding_size
        x = self.encoder(x, lengths)  # batch_size, length, hidden_size
        x, _ = self.attention(x, labels_selected)  # batch_size, sample_size, hidden_size
        x = self.output(x)  # batch_size, sample_size
        return {"logits": x}
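

# A usage sketch for the level-1 model (a hypothetical helper, not part of the
# library API). Besides "text", each batch carries "labels_selected": the
# candidate label indices shortlisted by the level-0 model. The exact dtype and
# shape are assumptions; a LongTensor of shape (batch_size, sample_size) matches
# the comments in forward(), and one logit is produced per candidate label.
def _demo_attentionxml_1():
    import torch

    embed_vecs = torch.randn(5000, 300)
    model = AttentionXML_1(embed_vecs, num_classes=1000, rnn_dim=512, linear_size=[256])
    text = torch.randint(1, 5000, (4, 50))
    text[:, 30:] = 0
    labels_selected = torch.randint(0, 1000, (4, 64))  # (batch_size, sample_size)
    out = model({"text": text, "labels_selected": labels_selected})
    assert out["logits"].shape == (4, 64)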