import torch
import torch.nn as nn
import torch.nn.functional as F
from .modules import Embedding, CNNEncoder
class XMLCNN(nn.Module):
    """XML-CNN model for multi-label text classification.

    The input text is embedded, encoded by a CNN with several filter sizes and
    dynamic max-pooling (``CNNEncoder``), flattened, passed through one hidden
    fully-connected layer, and projected to one logit per class.

    Args:
        embed_vecs (torch.Tensor): The pre-trained word vectors of shape (vocab_size, embed_dim).
        num_classes (int): Total number of classes.
        embed_dropout (float): The dropout rate of the word embedding. Defaults to 0.2.
        post_encoder_dropout (float): The dropout rate of the hidden layer output. Defaults to 0.
        filter_sizes (list): Sizes of the convolutional filters. Required in practice:
            despite the ``None`` default, a ``ValueError`` is raised if it is ``None`` or empty.
        hidden_dim (int): Dimension of the hidden layer. Defaults to 512.
        num_filter_per_size (int): The number of filters in convolutional layers in each size. Defaults to 256.
        num_pool (int): Number of pooled outputs kept per filter by dynamic max-pooling. Defaults to 2.
        activation (str): Name of the activation function, looked up first in ``torch`` and then in
            ``torch.nn.functional``. The same name is also passed to ``CNNEncoder``. Defaults to 'relu'.

    Raises:
        ValueError: If ``filter_sizes`` is ``None`` or empty.
        AttributeError: If ``activation`` names neither a ``torch`` nor a
            ``torch.nn.functional`` attribute.
    """

    def __init__(
        self,
        embed_vecs,
        num_classes,
        embed_dropout=0.2,
        post_encoder_dropout=0,
        filter_sizes=None,
        hidden_dim=512,
        num_filter_per_size=256,
        num_pool=2,
        activation="relu",
    ):
        super().__init__()
        # filter_sizes has no usable default: the hidden-layer input size below
        # needs len(filter_sizes), so fail early with a clear message.
        if filter_sizes is None or len(filter_sizes) == 0:
            raise ValueError("filter_sizes must be a non-empty sequence of filter sizes")
        # Resolve before building submodules so a bad name fails here, consistently.
        activation_fn = self._resolve_activation(activation)

        self.embedding = Embedding(embed_vecs, dropout=embed_dropout)
        self.encoder = CNNEncoder(
            embed_vecs.shape[1],
            filter_sizes,
            num_filter_per_size,
            activation,
            num_pool=num_pool,
        )
        # Flattened encoder output: num_pool pooled values for each filter of each size.
        total_output_size = len(filter_sizes) * num_filter_per_size * num_pool
        self.linear1 = nn.Linear(total_output_size, hidden_dim)
        self.activation = activation_fn
        self.post_encoder_dropout = nn.Dropout(post_encoder_dropout)
        self.linear2 = nn.Linear(hidden_dim, num_classes)

    @staticmethod
    def _resolve_activation(name):
        """Return the activation callable called ``name``.

        ``torch`` is searched first, then ``torch.nn.functional``. The fallback is
        looked up lazily, so a name that exists only in ``torch`` does not fail
        merely because ``torch.nn.functional`` lacks it.

        Raises:
            AttributeError: If ``name`` is found in neither namespace.
        """
        func = getattr(torch, name, None)
        if func is None:
            func = getattr(F, name)  # raises AttributeError for unknown names
        return func

    def forward(self, input):
        """Compute class logits for a batch.

        Args:
            input (dict): Batch whose ``"text"`` entry is the tensor fed to the
                embedding layer (presumably token indices).

        Returns:
            dict: ``{"logits": tensor}`` whose last dimension has one entry per class.
        """
        x = self.embedding(input["text"])  # (batch_size, length, embed_dim)
        x = self.encoder(x)  # (batch_size, num_filter, num_pool)
        # flatten (rather than view) also handles an empty batch and non-contiguous input.
        x = torch.flatten(x, start_dim=1)  # (batch_size, num_filter * num_pool)
        x = self.activation(self.linear1(x))
        x = self.post_encoder_dropout(x)
        x = self.linear2(x)
        return {"logits": x}