PyTorch Transformer

Reason is the light and the light of life.

Jerry Su Jul 22, 2020 2 mins

PyTorch Transformer

The Illustrated Transformer

import math
import torch
import torch.nn as nn
from torch.nn import TransformerEncoder, TransformerEncoderLayer, TransformerDecoder, TransformerDecoderLayer
from layers.encoder import CustomEmbedding
import layers


class PositionalEncoding(nn.Module):
    """
    Args:
        d_model: the number of expected features in the input (required).
        dropout: the dropout value (default=0.1).
    """
    def __init__(self, d_model, dropout=0.1, max_len=512):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
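        # Fill 'pe' with sinusoidal encodings following "Attention Is All You Need":
        #   PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
        #   PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))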
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
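        # 'pe' is stored as [max_len, 1, d_model]; the singleton dim broadcasts over the batch in forward().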
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: [seq_len, batch_size, d_model]; add the encoding for each position, then apply dropout.
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)
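
# Shape sanity check (illustrative, not from the original post): PositionalEncoding expects
# sequence-first input and returns a tensor of the same shape.
#
#   pos = PositionalEncoding(d_model=512)
#   x = torch.zeros(20, 4, 512)          # [seq_len, batch, d_model]
#   assert pos(x).shape == (20, 4, 512)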


class TransformerModel(nn.Module):
    """
    Args:
        vocab_size: # the size of vocabulary.
        embed_size: embedding dimension, the number of expected features in the input
        nhead: the number of heads in the multiheadattention models.
        dim_feedforward: the dimension of the feedforward network model in nn.TransformerEncoder.
        nlayers: the number of sub-decoder-layers in the decoder (default=6).
        dropout: the dropout value (default=0.1).
    """
    def __init__(self,  config):
        super(TransformerModel, self).__init__()
        self.config = config
        self.pos_encoder = PositionalEncoding(config.embed_size, config.dropout, config.max_len)
        self.pos_decoder = PositionalEncoding(config.embed_size, config.dropout, config.max_len)
        self.src_embedding = CustomEmbedding(config.vocab_size, config.embed_size)
        self.tgt_embedding = nn.Embedding(config.vocab_size, config.embed_size)
        # encode images
        self.image_encoder = layers.ImageEncoder(config.embed_size)
        # encoder
        encoder_layers = TransformerEncoderLayer(config.embed_size, config.nhead, config.dim_feedforward, config.dropout)
        self.transformer_encoder = TransformerEncoder(encoder_layers, config.nlayers)
        # decoder
        decoder_layers = TransformerDecoderLayer(config.embed_size, config.nhead, config.dim_feedforward, config.dropout)
        self.transformer_decoder = TransformerDecoder(decoder_layers, config.nlayers)
        self.linear = nn.Linear(config.embed_size, config.vocab_size)

        self.init_weights()

    def _attn_padding_mask(self, seq):
        """Build a key-padding mask from a [batch_size, seq_len] tensor of token ids.

        Token id 0 is the PAD token; PAD positions are marked True, which is the
        convention expected by the *_key_padding_mask arguments of the nn.Transformer modules.
        """
        return seq.eq(0)  # [batch_size, seq_len], True means the position is ignored by attention
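        # Illustrative example (not from the original post):
        #   src  = [[5, 7, 0, 0]]
        #   mask = [[False, False, True, True]]
        # so attention ignores the two trailing PAD positions.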

    def _sequence_mask(self, seq):
        """Build a square causal (subsequent) mask for the decoder self-attention.

        Each target position may only attend to itself and earlier positions, so tokens
        at future positions are hidden during training. Allowed positions are 0.0 and
        masked positions are -inf, the additive-mask convention used by nn.Transformer.
        """
        seq_len = seq.size(1)
        mask = (torch.triu(torch.ones(seq_len, seq_len)) == 1).transpose(0, 1).to(seq.device)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask
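        # Illustrative example (not from the original post): for seq_len = 3 the mask is
        #   [[0., -inf, -inf],
        #    [0.,   0., -inf],
        #    [0.,   0.,   0.]]
        # so position i can only attend to positions j <= i.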

    def init_weights(self):
        initrange = 0.1
        self.src_embedding.embedding.weight.data.uniform_(-initrange, initrange)
        self.tgt_embedding.weight.data.uniform_(-initrange, initrange)

    def forward(self, src,
                input_sentence_length,
                input_conversation_length,
                tgt,
                input_images,
                input_images_length,
                input_image_indexes):
        # encode images
        img_encoder_outputs = self.image_encoder(input_images)

        # encoder
        src_padding_mask = self._attn_padding_mask(src)
        src_embed = self.src_embedding(src,
                                       img_encoder_outputs,
                                       input_images_length,
                                       input_image_indexes) * math.sqrt(self.config.embed_size)
        # nn.TransformerEncoder expects [src_len, batch, embed] and PositionalEncoding adds
        # encodings along dim 0, so transpose to sequence-first before encoding positions.
        src_embed = self.pos_encoder(src_embed.transpose(0, 1))
        memory = self.transformer_encoder(src=src_embed, src_key_padding_mask=src_padding_mask)

        # decoder
        tgt_padding_mask = self._attn_padding_mask(tgt)
        tgt_sequence_mask = self._sequence_mask(tgt)
        tgt_embed = self.tgt_embedding(tgt) * math.sqrt(self.config.embed_size)
        tgt_embed = self.pos_decoder(tgt_embed.transpose(0, 1))  # [tgt_len, batch, embed]
        output = self.transformer_decoder(tgt=tgt_embed,
                                          memory=memory,
                                          tgt_mask=tgt_sequence_mask,
                                          tgt_key_padding_mask=tgt_padding_mask,
                                          memory_key_padding_mask=src_padding_mask)

        output = self.linear(output)  # [tgt_len, batch, vocab_size]

        return output.transpose(0, 1).contiguous()  # back to [batch, tgt_len, vocab_size]
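
To make the shape conventions above concrete, here is a minimal sketch of the same encoder-decoder flow using only the plain nn.Transformer modules. It reuses the imports and the PositionalEncoding class defined above, leaves out the project-specific image encoder and CustomEmbedding, and every hyperparameter value below is made up for illustration.

def demo_transformer_shapes():
    # Illustrative hyperparameters, not taken from the original config.
    vocab_size, embed_size, nhead, dim_feedforward, nlayers, dropout = 1000, 512, 8, 2048, 6, 0.1

    embedding = nn.Embedding(vocab_size, embed_size)
    pos = PositionalEncoding(embed_size, dropout)
    encoder = TransformerEncoder(TransformerEncoderLayer(embed_size, nhead, dim_feedforward, dropout), nlayers)
    decoder = TransformerDecoder(TransformerDecoderLayer(embed_size, nhead, dim_feedforward, dropout), nlayers)
    linear = nn.Linear(embed_size, vocab_size)

    src = torch.randint(1, vocab_size, (2, 10))  # [batch, src_len]; id 0 is reserved for PAD
    tgt = torch.randint(1, vocab_size, (2, 7))   # [batch, tgt_len]
    src_padding_mask = src.eq(0)                 # [batch, src_len], True = PAD
    tgt_padding_mask = tgt.eq(0)                 # [batch, tgt_len]
    tgt_len = tgt.size(1)
    # Causal mask, equivalent to _sequence_mask above: -inf above the diagonal, 0.0 elsewhere.
    tgt_mask = torch.triu(torch.full((tgt_len, tgt_len), float('-inf')), diagonal=1)

    # Encoder: embed, scale by sqrt(d_model), move to sequence-first, add positions.
    src_embed = pos(embedding(src).transpose(0, 1) * math.sqrt(embed_size))  # [src_len, batch, embed]
    memory = encoder(src_embed, src_key_padding_mask=src_padding_mask)       # [src_len, batch, embed]

    # Decoder: same layout, plus the causal mask and the encoder memory.
    tgt_embed = pos(embedding(tgt).transpose(0, 1) * math.sqrt(embed_size))  # [tgt_len, batch, embed]
    out = decoder(tgt_embed, memory,
                  tgt_mask=tgt_mask,
                  tgt_key_padding_mask=tgt_padding_mask,
                  memory_key_padding_mask=src_padding_mask)                  # [tgt_len, batch, embed]

    return linear(out).transpose(0, 1)  # [batch, tgt_len, vocab_size]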
