Source code for libmultilabel.nn.networks.kim_cnn

import torch
import torch.nn as nn

from .modules import Embedding, CNNEncoder


class KimCNN(nn.Module):
    """KimCNN

    Args:
        embed_vecs (torch.Tensor): The pre-trained word vectors of shape (vocab_size, embed_dim).
        num_classes (int): Total number of classes.
        filter_sizes (list): Sizes of the convolutional filters.
        num_filter_per_size (int): The number of filters of each size in the convolutional layers. Defaults to 128.
        embed_dropout (float): The dropout rate of the word embedding. Defaults to 0.2.
        post_encoder_dropout (float): The dropout rate of the encoder output. Defaults to 0.
        activation (str): Activation function to be used. Defaults to 'relu'.
    """

    def __init__(
        self,
        embed_vecs,
        num_classes,
        filter_sizes=None,
        num_filter_per_size=128,
        embed_dropout=0.2,
        post_encoder_dropout=0,
        activation="relu",
    ):
        super(KimCNN, self).__init__()
        self.embedding = Embedding(embed_vecs, dropout=embed_dropout)
        self.encoder = CNNEncoder(
            embed_vecs.shape[1], filter_sizes, num_filter_per_size, activation, post_encoder_dropout, num_pool=1
        )
        conv_output_size = num_filter_per_size * len(filter_sizes)
        self.linear = nn.Linear(conv_output_size, num_classes)

    def forward(self, input):
        x = self.embedding(input["text"])  # (batch_size, length, embed_dim)
        x = self.encoder(x)  # (batch_size, num_filter, 1)
        x = torch.squeeze(x, 2)  # (batch_size, num_filter)
        x = self.linear(x)  # (batch_size, num_classes)
        return {"logits": x}
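
A minimal usage sketch (not part of the module source): the vocabulary size, embedding dimension, filter sizes, and sequence length below are illustrative placeholders, and the import path assumes libmultilabel is installed.

import torch

from libmultilabel.nn.networks.kim_cnn import KimCNN

# Stand-in for pre-trained word vectors of shape (vocab_size, embed_dim).
vocab_size, embed_dim, num_classes = 5000, 300, 10
embed_vecs = torch.randn(vocab_size, embed_dim)

# filter_sizes must be provided; [2, 4, 8] is a hypothetical choice.
model = KimCNN(embed_vecs, num_classes, filter_sizes=[2, 4, 8])

# A batch of 4 token-ID sequences of length 50; the length must be at
# least max(filter_sizes) for every convolution to produce output.
batch = {"text": torch.randint(0, vocab_size, (4, 50))}
output = model(batch)
print(output["logits"].shape)  # torch.Size([4, 10])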