from math import floor
import torch
import torch.nn as nn
from torch.nn.init import xavier_uniform_
class CAML(nn.Module):
"""CAML (Convolutional Attention for Multi-Label classification)
Follows the work of Mullenbach et al. [https://aclanthology.org/N18-1100.pdf]
This class is for reproducing the results in the paper.
Use CNNLWAN instead for better modularization.
Args:
embed_vecs (torch.Tensor): The pre-trained word vectors of shape (vocab_size, embed_dim).
num_classes (int): Total number of classes.
        filter_sizes (list): Sizes of convolutional filters. CAML expects exactly one size.
        num_filter_per_size (int): The number of filters per filter size in the convolutional layer. Defaults to 50.
embed_dropout (float): The dropout rate of the word embedding. Defaults to 0.2.
"""
def __init__(
self,
embed_vecs,
num_classes,
filter_sizes=None,
num_filter_per_size=50,
embed_dropout=0.2,
):
super(CAML, self).__init__()
        if not filter_sizes or len(filter_sizes) != 1:
            raise ValueError(f"CAML expects exactly one filter size. Got filter_sizes={filter_sizes}")
filter_size = filter_sizes[0]
self.embedding = nn.Embedding(len(embed_vecs), embed_vecs.shape[1], padding_idx=0)
self.embedding.weight.data = embed_vecs.clone()
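        # Note: padding_idx=0 assumes index 0 is reserved for the padding token,
        # so that row's embedding receives no gradient updates. The pretrained
        # vectors are copied in and remain trainable.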
self.embed_dropout = nn.Dropout(p=embed_dropout)
# Initialize conv layer
self.conv = nn.Conv1d(
embed_vecs.shape[1], num_filter_per_size, kernel_size=filter_size, padding=int(floor(filter_size / 2))
)
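        # With stride 1 and padding = floor(filter_size / 2), an odd filter_size
        # keeps the convolved sequence length equal to the input length
        # (an even filter_size lengthens it by one).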
xavier_uniform_(self.conv.weight)
"""Context vectors for computing attention with
(in_features, out_features) = (num_filter_per_size, num_classes)
"""
self.Q = nn.Linear(num_filter_per_size, num_classes)
xavier_uniform_(self.Q.weight)
# Final layer: create a matrix to use for the #labels binary classifiers
self.output = nn.Linear(num_filter_per_size, num_classes)
xavier_uniform_(self.output.weight)
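        # self.output is never called as a layer; forward() uses its weight and
        # bias directly to form one dot-product binary classifier per label.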
def forward(self, input):
# Get embeddings and apply dropout
x = self.embedding(input["text"]) # (batch_size, length, embed_dim)
x = self.embed_dropout(x)
x = x.transpose(1, 2) # (batch_size, embed_dim, length)
""" Apply convolution and nonlinearity (tanh). The shapes are:
- self.conv(x): (batch_size, num_filte_per_size, length)
- x after transposing the first and the second dimension and applying
the activation function: (batch_size, length, num_filte_per_size)
"""
Z = torch.tanh(self.conv(x).transpose(1, 2))
"""Apply per-label attention. The shapes are:
- Q.weight: (num_classes, num_filte_per_size)
- matrix product of U.weight and x: (batch_size, num_classes, length)
- alpha: (batch_size, num_classes, length)
"""
alpha = torch.softmax(self.Q.weight.matmul(Z.transpose(1, 2)), dim=2)
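        # This is the per-label attention of Mullenbach et al.: in the paper's
        # notation, alpha_l = softmax(H^T u_l), where Z holds the convolved
        # features H and each row of Q.weight is a label's context vector u_l.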
# Document representations are weighted sums using the attention
E = alpha.matmul(Z) # (batch_size, num_classes, num_filter_per_size)
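        # Equivalently, E[b, l] = sum_t alpha[b, l, t] * Z[b, t], the paper's
        # label-specific document vector v_l = sum_t alpha_{l, t} * h_t.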
        # Compute a logit for each label: a per-label dot product between
        # output.weight and E, plus the bias
logits = self.output.weight.mul(E).sum(dim=2).add(self.output.bias) # (batch_size, num_classes)
return {"logits": logits, "attention": alpha}
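

# Minimal usage sketch (illustrative only, not part of the original module).
# It assumes a toy vocabulary of 100 tokens with random 50-dimensional
# embeddings and a batch of 4 already-indexed documents of length 60.
if __name__ == "__main__":
    torch.manual_seed(0)
    embed_vecs = torch.randn(100, 50)  # (vocab_size, embed_dim); index 0 is padding
    model = CAML(embed_vecs, num_classes=10, filter_sizes=[9])
    batch = {"text": torch.randint(1, 100, (4, 60))}  # dummy token ids
    output = model(batch)
    print(output["logits"].shape)     # torch.Size([4, 10])
    print(output["attention"].shape)  # torch.Size([4, 10, 60])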