The transformer is a neural network architecture that was proposed in the paper "Attention Is All You Need" by Vaswani et al. in 2017. It is a powerful architecture that has led to many state-of-the-art results in the last few years. It can be used to generate text (GPT-3), create beautiful images from text (Imagen, DALL-E 2, Stable Diffusion), compose music (Music Transformer), transcribe speech (Whisper), predict protein structures (AlphaFold), teach cars to drive themselves (Tesla FSD) or even learn many different tasks at once (Gato). There are probably many amazing results I forgot to name, and many more to come.

Large language models (LLMs) are going to have a large impact. We will probably interact with these models a lot in the future to brainstorm, boost our creativity, make sense of large amounts of information and maybe even understand ourselves better. GPT-3 is already helping me with coding and writing (also for this blogpost), and these models are going to get better really fast. That is why I want to understand how they work.

In this blogpost I will show you how to build a text generation model from scratch using the transformer architecture. I will show the coding process and try to keep each step as simple as possible. The aim of this post is not to build the best text generation model, but to make each step of building one as clear as possible.

Overview


For translation models, a combination of encoder and decoder layers is used: the encoder reads the input sentence and the decoder generates the output sentence. For text generation models we don't need the encoder, because the input and the output are the same sequence. We can use the decoder alone to generate text.

The input of the model is a sequence of integers, representing the text generated so far. Each integer corresponds to a token in a vocabulary. A token can be a character, a word or anything you might want to represent.
The output of the model is a probability distribution over the next token. We can sample from this distribution to get the next token.
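
To make this concrete, here is a tiny example (with a made-up probability distribution over a vocabulary of 6 tokens, not the model we will build) of how sampling the next token works:

import torch

# A made-up probability distribution over a vocabulary of 6 tokens
next_token_probabilities = torch.tensor([0.05, 0.05, 0.6, 0.1, 0.1, 0.1])

# Sample one token index from the distribution
next_token = torch.multinomial(next_token_probabilities, num_samples=1)
print(next_token.item())  # usually 2, the most likely token, but any index can be drawn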

This is an overview of the model. In the next sections I will zoom in on all parts of this diagram.

Tokenizing the text

We want to train our model on a text dataset. But before we can use the text in our model it has to be tokenized.

Tokenizing a text means converting it into a sequence of tokens. A token is an atomic unit of text. A token can be a character, a word or anything you might want to represent. GPT-3 uses combinations of characters that frequently occur together as tokens.

There are different ways to tokenize a text. The most common way is to use a vocabulary. A vocabulary is a set of all possible tokens. Each token in the text is represented by an integer, called an index. This integer is the index of the token in the vocabulary.

For example, if our text is "I am a cat" and our vocabulary is ["I", "am", "a", "cat", "dog", "mouse"] (with index 0 reserved for a padding token), then our text would be represented as [1, 2, 3, 4, 0, 0] when padded to a length of 6. The 0's are called padding. Padding is used to make all texts the same length.

We create a very simple tokenizer that can encode the characters a-z, the numbers 0-9, a period and a whitespace character. The padding token is mapped to 0.

Our dictionary is very simple, and contains only 39 tokens. For comparison, the dictionary that OpenAI uses for GPT-3 contains 50257 tokens.

class Tokenizer:

    def __init__(self):
        self.dictionary = {}
        self.reverse_dictionary = {}

        # Add the padding token
        self.__add_to_dict('<pad>')

        # Add characters and numbers to the dictionary
        for i in range(10):
            self.__add_to_dict(str(i))
        for i in range(26):
            self.__add_to_dict(chr(ord('a') + i))

        # Add space and punctuation to the dictionary
        self.__add_to_dict('.')
        self.__add_to_dict(' ')

    def __add_to_dict(self, character):
        if character not in self.dictionary:
            self.dictionary[character] = len(self.dictionary)
            self.reverse_dictionary[self.dictionary[character]] = character

    def tokenize(self, text):
        return [self.dictionary[c] for c in text]

    def character_to_token(self, character):
        return self.dictionary[character]

    def token_to_character(self, token):
        return self.reverse_dictionary[token]

    def size(self):
        return len(self.dictionary)
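
A quick example of how this tokenizer can be used; the indices follow from the order in which the characters are added to the dictionary above:

tokenizer = Tokenizer()
tokens = tokenizer.tokenize('cats rule')
print(tokens)  # [13, 11, 30, 29, 38, 28, 31, 22, 15]
print(''.join(tokenizer.token_to_character(token) for token in tokens))  # cats rule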

Dataset


Then we have to define a dataset to train on. I am going to define some sort of "hello world" dataset for text generation.

We create and tokenize the (mini) training dataset as follows. To make sure that all data is taken into account the training data is padded on the left.

# Create the training data
training_data = '. '.join([
    'cats rule the world',
    'dogs are the best',
    'elephants have long trunks',
    'monkeys like bananas',
    'pandas eat bamboo',
    'tigers are dangerous',
    'zebras have stripes',
    'lions are the kings of the savannah',
    'giraffes have long necks',
    'hippos are big and scary',
    'rhinos have horns',
    'penguins live in the arctic',
    'polar bears are white'
])

# Tokenize the training data
tokenized_training_data = tokenizer.tokenize(training_data)

# Add padding to the left, to make sure all parts of the sequence are being trained
for _ in range(max_sequence_length):
    # Prepend padding tokens
    tokenized_training_data.insert(0, tokenizer.character_to_token('<pad>'))

Input embedding

The embedding layer converts tokens to vector representations. The idea is to convert the token to a vector such that similar tokens are close together in the vector space. If words are used as tokens, the model will learn, for example, that cats and dogs are often used in similar contexts, so they would probably be close in vector space.

The nn.Embedding layer performs the embedding step. The weights of the embedding layer are learned during the training process.

class TokenEmbedding(torch.nn.Module):
    """
    Pytorch module that converts tokens into embeddings.

    Input dimension is: (batch_size, sequence_length)
    Output dimension is: (batch_size, sequence_length, embedding_dimension)
    """

    def __init__(
            self,
            embedding_dimension,
            number_of_tokens
    ):
        super().__init__()
        self.embedding_layer = torch.nn.Embedding(
            num_embeddings=number_of_tokens,
            embedding_dim=embedding_dimension
        )

    def forward(self, x):
        return self.embedding_layer(x)

number_of_tokens indicates how many different tokens can be used in the input. This would be the number of tokens in our dictionary.

embedding_dimension denotes the size of the embedding. A larger embedding means that more information can be encoded in the vector, but it also means that the model will take longer to train.
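
As a quick sanity check, here is what the shapes look like when we run a small batch of random token indices (using our dictionary size of 39 and an embedding dimension of 256) through the embedding layer:

import torch

token_embedding = TokenEmbedding(embedding_dimension=256, number_of_tokens=39)
tokens = torch.randint(0, 39, (2, 20))  # (batch_size, sequence_length)
embeddings = token_embedding(tokens)
print(embeddings.shape)                 # torch.Size([2, 20, 256])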

Positional encoding

The positional encoding is used to allow the model to learn about the order of the input tokens. Every token is processed in parallel, so you could just see the input as a bag of shuffled words. Without positional encoding there is no way for the model to see the difference between the sentence "cats are bigger than mice" and "mice are bigger than cats", since they contain the same words.

To do this, the positional encoding adds a vector to each token embedding based on its position in the sequence. This encoding is based on sinusoidal functions:

for pos in range(self.max_sequence_length):
  for i in range(0, self.embedding_dimension, 2):
    P_E(pos, i)   = sin(pos / 10000^(i / embedding_dimension))
    P_E(pos, i+1) = cos(pos / 10000^(i / embedding_dimension))

The code to create a module that creates and adds such a positional embedding could be written as follows:

class PositionalEncoding(torch.nn.Module):
    """
    Pytorch module that creates a positional embedding with the same dimensions as the token embeddings.
    """

    def __init__(self, embedding_dimension, max_sequence_length):
        super().__init__()
        self.embedding_dimension = embedding_dimension
        self.max_sequence_length = max_sequence_length
        # Register as a buffer so it is stored with the model and moved to the right device (e.g. the GPU) with it
        self.register_buffer('positional_encoding', self.create_positional_encoding())

    def create_positional_encoding(self):
        """
        Creates a positional encoding matrix of size (max_sequence_length, embedding_dimension)
        """
        positional_encoding = np.zeros((self.max_sequence_length, self.embedding_dimension))
        for pos in range(self.max_sequence_length):
            for i in range(0, self.embedding_dimension, 2):
                positional_encoding[pos, i] = np.sin(pos / (10000 ** ((2 * i) / self.embedding_dimension)))
                positional_encoding[pos, i + 1] = np.cos(pos / (10000 ** ((2 * (i + 1)) / self.embedding_dimension)))
        return torch.from_numpy(positional_encoding).float()

    def forward(self, x):
        """
        Adds the positional encoding to the token embeddings.
        """
        # x.size(1) is the sequence length; slice the encoding to match it
        return x + self.positional_encoding[:x.size(1), :]

Attention


To be able to learn the relationships between the tokens in the input, the transformer uses attention. An attention mechanism can be thought of as a method to focus on the parts of the input that are relevant to each other. It does so by calculating an attention score. For each word, all other words in the sentence are considered; if another word provides important context, we want to put more emphasis on it, so it should get a higher attention score.

Say you have the sentence: "The man ate a sandwich he prepared in the morning. It was topped with cheese", and we want to calculate the attention scores for "It". "Sandwich" should get a high score, because this is what "It" refers to, but "in" might receive a lower attention score. The tokens with the highest attention scores have most influence on the output.

To calculate these attention scores we need a query, key and value vector. These vectors are created by multiplying the input embedding with learned matrices (linear layers in the neural network). The query, key and value vectors are then used to calculate the attention scores, as shown in the diagram below.

The attention scores for the tokens are determined by calculating the dot product between the query and key vectors.
After that the values of the resulting matrix are divided by the square root of the query/key/value dimension, because according to the paper this leads to more stable gradients.

Optionally we can apply a mask to the attention scores. This is used to prevent the model from paying attention to the padding. Where the mask is 0, the attention score is set to a very large negative value (effectively minus infinity), so after the softmax the model cannot attend to that token.

The attention scores are then normalized using the softmax function. After that, the values are multiplied with the attention scores and summed up. This gives us the output of the attention layer.
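
Putting these steps together gives the scaled dot-product attention from the paper, where d_k is the query/key/value dimension:

Attention(Q, K, V) = softmax( (Q · K^T) / sqrt(d_k) ) · V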

class MaskedSelfAttention(torch.nn.Module):
    """
    Pytorch module for a self attention layer.
    This layer is used in the MultiHeadedSelfAttention module.

    Input dimension is: (batch_size, sequence_length, embedding_dimension)
    Output dimension is: (batch_size, sequence_length, head_dimension)
    """

    def __init__(self, embedding_dimension, head_dimension):
        super().__init__()
        self.embedding_dimension = embedding_dimension
        self.head_dimension = head_dimension
        self.query_layer = torch.nn.Linear(embedding_dimension, self.head_dimension)
        self.key_layer = torch.nn.Linear(embedding_dimension, self.head_dimension)
        self.value_layer = torch.nn.Linear(embedding_dimension, self.head_dimension)
        self.softmax = torch.nn.Softmax(dim=-1)

    def forward(self, x, mask):
        """
        Compute the self attention.

        x dimension is: (batch_size, sequence_length, embedding_dimension)
        output dimension is: (batch_size, sequence_length, head_dimension)
        mask dimension is: (batch_size, sequence_length)

        mask values are: 0 or 1. 0 means the token is masked, 1 means the token is not masked.
        """

        # x dimensions are: (batch_size, sequence_length, embedding_dimension)
        # query, key, value dimensions are: (batch_size, sequence_length, head_dimension)
        query = self.query_layer(x)
        key = self.key_layer(x)
        value = self.value_layer(x)

        # Calculate the attention weights.
        # attention_weights dimensions are: (batch_size, sequence_length, sequence_length)
        attention_weights = torch.matmul(query, key.transpose(-2, -1))

        # Scale the attention weights.
        attention_weights = attention_weights / np.sqrt(self.head_dimension)

        # Apply the mask to the attention weights, by setting the masked tokens to a very low value.
        # This will make the softmax output 0 for these values.
        mask = mask.reshape(attention_weights.shape[0], 1, attention_weights.shape[2])
        attention_weights = attention_weights.masked_fill(mask == 0, -1e9)

        # Softmax makes sure all scores are between 0 and 1 and the sum of scores is 1.
        # attention_scores dimensions are: (batch_size, sequence_length, sequence_length)
        attention_scores = self.softmax(attention_weights)

        # The attention scores are multiplied by the value
        # Values of tokens with high attention score get highlighted because they are multiplied by a larger number,
        # and tokens with low attention score get drowned out because they are multiplied by a smaller number.
        # Output dimensions are: (batch_size, sequence_length, head_dimension)
        return torch.bmm(attention_scores, value)
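
Note that this layer only masks padding tokens. A GPT-style decoder usually also applies a causal (look-ahead) mask, so that a token cannot attend to tokens that come after it; that mask is not part of the code in this post. A minimal sketch of what such a mask looks like, and how it could be combined with the padding mask inside the forward method:

import torch

sequence_length = 5

# Lower triangular matrix: row t has ones up to and including column t
causal_mask = torch.tril(torch.ones(sequence_length, sequence_length))
print(causal_mask)

# Inside the forward method this could be combined with the padding mask, for example:
# attention_weights = attention_weights.masked_fill(causal_mask == 0, -1e9)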

Usually multi-headed self attention is used in transformers. A reason for using multiple heads is that it allows the model to focus on different parts of the input at the same time. Each head can learn a different representation of the data. This can be helpful for learning tasks that require understanding of the data from multiple perspectives.

In multi-headed attention there are multiple attention heads that each perform the attention step I explained above. Each head projects the input to its own query, key and value vectors with a certain head dimension, so its output also has that head dimension. In my example the head dimension is the embedding dimension divided by the number of heads, but you could choose a different value.

The outputs of all heads are concatenated into a single matrix. The concatenated matrix then goes through a linear layer to give an output with the same dimensions as the input.

class MaskedMultiHeadedSelfAttention(torch.nn.Module):
    """
    Pytorch module for a multi head attention layer.

    Input dimension is: (batch_size, sequence_length, embedding_dimension)
    Output dimension is: (batch_size, sequence_length, embedding_dimension)
    """

    def __init__(self, embedding_dimension, number_of_heads):
        super().__init__()
        self.embedding_dimension = embedding_dimension
        self.head_dimension = embedding_dimension // number_of_heads
        self.number_of_heads = number_of_heads

        # Create the self attention modules
        self.self_attentions = torch.nn.ModuleList(
            [MaskedSelfAttention(embedding_dimension, self.head_dimension) for _ in range(number_of_heads)])

        # Create a linear layer to combine the outputs of the self attention modules
        self.output_layer = torch.nn.Linear(number_of_heads * self.head_dimension, embedding_dimension)

    def forward(self, x, mask):
        """
        Compute the multi head attention.

        x dimensions are: (batch_size, sequence_length, embedding_dimension)
        mask dimensions are: (batch_size, sequence_length)
        mask values are: 0 or 1. 0 means the token is masked, 1 means the token is not masked.
        """
        # Compute the self attention for each head
        # self_attention_outputs dimensions are:
        # (number_of_heads, batch_size, sequence_length, head_dimension)
        self_attention_outputs = [self_attention(x, mask) for self_attention in self.self_attentions]

        # Concatenate the self attention outputs
        # self_attention_outputs_concatenated dimensions are:
        # (batch_size, sequence_length, number_of_heads * head_dimension)
        concatenated_self_attention_outputs = torch.cat(self_attention_outputs, dim=2)

        # Apply the output layer to the concatenated self attention outputs
        # output dimensions are: (batch_size, sequence_length, embedding_dimension)
        return self.output_layer(concatenated_self_attention_outputs)
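
A quick shape check of the multi-headed attention layer, using a random input and a mask of all ones (meaning nothing is masked):

import torch

attention = MaskedMultiHeadedSelfAttention(embedding_dimension=256, number_of_heads=4)
x = torch.rand(2, 20, 256)  # (batch_size, sequence_length, embedding_dimension)
mask = torch.ones(2, 20)    # 1 everywhere, so no token is masked
output = attention(x, mask)
print(output.shape)         # torch.Size([2, 20, 256])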

Decoder


The model consists of several decoders. Each decoder takes the output of the previous decoder as input; the first decoder takes the output of the positional encoding step as input. The final layer is a language model head, which outputs the probabilities for the next token.

When a decoder receives input from the previous layer, it is normalized first. Normalization is used to make sure that the gradients don't explode. In the normalization step a mean and variance is calculated for each token; the mean is subtracted and the result is divided by the standard deviation (the square root of the variance), so that the output has a mean of 0 and a variance of 1.
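
As a small illustration of what layer normalization does to a single token vector (using torch.nn.LayerNorm, which is also what the decoder layer below uses):

import torch

layer_norm = torch.nn.LayerNorm(4)
x = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
print(layer_norm(x))  # roughly [-1.34, -0.45, 0.45, 1.34]: mean 0, variance 1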

After normalization self attention will be applied, like explained in the previous section.

The input is then added to the attention output in a residual step. Residual connections help mitigate the vanishing gradient problem: gradients get smaller and smaller as they are propagated backwards through the network, so the weights in the earlier layers barely get updated.

Then the output is normalized again and a feed forward layer is applied. The feed forward layer is a fully connected layer with a ReLU activation function. In the GPT-2 paper this feed forward layer has a hidden dimension of four times the embedding dimension, but this value is not set in stone.

Finally, dropout is applied. In a dropout layer connections are randomly dropped (with a 10% probability by default in my code) to prevent the model from overfitting to the training data.

The decoder layer in code

class DecoderLayer(torch.nn.Module):
    """
    Pytorch module for a decoder layer.

    A decoder layer consists of a multi-headed self attention layer, a feed forward layer and dropout.

    Input dimension is: (batch_size, sequence_length, embedding_dimension)
    Output dimension is: (batch_size, sequence_length, embedding_dimension)
    """

    def __init__(
            self,
            embedding_dimension,
            number_of_heads,
            feed_forward_dimension,
            dropout_rate
    ):
        super().__init__()
        self.embedding_dimension = embedding_dimension
        self.number_of_heads = number_of_heads
        self.feed_forward_dimension = feed_forward_dimension
        self.dropout_rate = dropout_rate

        self.multi_headed_self_attention = MaskedMultiHeadedSelfAttention(embedding_dimension, number_of_heads)
        self.feed_forward = FeedForward(embedding_dimension, feed_forward_dimension)
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.layer_normalization_1 = torch.nn.LayerNorm(embedding_dimension)
        self.layer_normalization_2 = torch.nn.LayerNorm(embedding_dimension)

    def forward(self, x, mask):
        """
        Compute the decoder layer.

        x dimensions are: (batch_size, sequence_length, embedding_dimension)
        mask dimensions are: (batch_size, sequence_length)
        mask values are: 0 or 1. 0 means the token is masked, 1 means the token is not masked.
        """

        # Layer normalization 1
        normalized_x = self.layer_normalization_1(x)

        # Multi headed self attention
        attention_output = self.multi_headed_self_attention(normalized_x, mask)

        # Residual output
        residual_output = x + attention_output

        # Layer normalization 2
        normalized_residual_output = self.layer_normalization_2(residual_output)

        # Feed forward
        feed_forward_output = self.feed_forward(normalized_residual_output)

        # Dropout, only when training.
        if self.training:
            feed_forward_output = self.dropout(feed_forward_output)

        # Residual output
        return residual_output + feed_forward_output

The DecoderStack is a number of decoder layers in sequence.

class DecoderStack(torch.nn.Module):
    """
    Pytorch module for a stack of decoders.
    """

    def __init__(
            self,
            embedding_dimension,
            number_of_layers,
            number_of_heads,
            feed_forward_dimension,
            dropout_rate,
            max_sequence_length
    ):
        super().__init__()
        self.embedding_dimension = embedding_dimension
        self.number_of_layers = number_of_layers
        self.number_of_heads = number_of_heads
        self.feed_forward_dimension = feed_forward_dimension
        self.dropout_rate = dropout_rate
        self.max_sequence_length = max_sequence_length

        # Create the decoder layers
        self.decoder_layers = torch.nn.ModuleList(
            [DecoderLayer(embedding_dimension, number_of_heads, feed_forward_dimension, dropout_rate) for _ in
             range(number_of_layers)])

    def forward(self, x, mask):
        decoder_outputs = x
        for decoder_layer in self.decoder_layers:
            decoder_outputs = decoder_layer(decoder_outputs, mask)

        return decoder_outputs

The feed forward layer

class FeedForward(torch.nn.Module):
    """
    Pytorch module for a feed forward layer.

    A feed forward layer consists of two linear layers with a ReLU activation function in between.
    """

    def __init__(self, embedding_dimension, feed_forward_dimension):
        super().__init__()
        self.embedding_dimension = embedding_dimension
        self.feed_forward_dimension = feed_forward_dimension
        self.linear_1 = torch.nn.Linear(embedding_dimension, feed_forward_dimension)
        self.linear_2 = torch.nn.Linear(feed_forward_dimension, embedding_dimension)

    def forward(self, x):
        """
        Compute the feed forward layer.
        """
        return self.linear_2(torch.relu(self.linear_1(x)))


The model

Now that we have the embeddings and the decoders, we need to bring it all together into an autoregressive language model.

First we define the LanguageModel class, which brings the different layers together. The token embeddings are created and the positional encoding is added. The output of that is normalized and goes into the stack of decoders.

Finally we come to the language model head. This is a linear layer that maps the output of the decoder stack to the number of tokens in the dictionary, so we can compute probabilities for every token.

class LanguageModel(torch.nn.Module):
    """
    Pytorch module for a language model.
    """

    def __init__(
            self,
            number_of_tokens,  # The number of tokens in the vocabulary
            max_sequence_length=512,  # The maximum sequence length to use for attention
            embedding_dimension=512,  # The dimension of the token embeddings
            number_of_layers=6,  # The number of decoder layers to use
            number_of_heads=4,  # The number of attention heads to use
            feed_forward_dimension=None,  # The dimension of the feed forward layer
            dropout_rate=0.1  # The dropout rate to use
    ):
        super().__init__()
        self.number_of_tokens = number_of_tokens
        self.max_sequence_length = max_sequence_length
        self.embedding_dimension = embedding_dimension
        self.number_of_layers = number_of_layers
        self.number_of_heads = number_of_heads

        if feed_forward_dimension is None:
            # GPT-2 paper uses 4 * embedding_dimension for the feed forward dimension
            self.feed_forward_dimension = embedding_dimension * 4
        else:
            self.feed_forward_dimension = feed_forward_dimension

        self.dropout_rate = dropout_rate

        # Create the token embedding layer
        self.token_embedding = TokenEmbedding(embedding_dimension, number_of_tokens)

        # Create the positional encoding layer
        self.positional_encoding = PositionalEncoding(embedding_dimension, max_sequence_length)

        # Create the normalization layer
        self.layer_normalization = torch.nn.LayerNorm(embedding_dimension)

        # Create the decoder stack
        self.decoder = DecoderStack(
            embedding_dimension=embedding_dimension,
            number_of_layers=number_of_layers,
            number_of_heads=number_of_heads,
            feed_forward_dimension=self.feed_forward_dimension,
            dropout_rate=dropout_rate,
            max_sequence_length=max_sequence_length
        )

        # Create the language model head
        self.lm_head = LMHead(embedding_dimension, number_of_tokens)

    def forward(self, x, mask):
        # Compute the token embeddings
        # token_embeddings dimensions are: (batch_size, sequence_length, embedding_dimension)
        token_embeddings = self.token_embedding(x)

        # Compute the positional encoding
        # positional_encoding dimensions are: (batch_size, sequence_length, embedding_dimension)
        positional_encoding = self.positional_encoding(token_embeddings)

        # Post embedding layer normalization
        positional_encoding_normalized = self.layer_normalization(positional_encoding)

        decoder_outputs = self.decoder(positional_encoding_normalized, mask)
        lm_head_outputs = self.lm_head(decoder_outputs)

        return lm_head_outputs

Code for the language model head.

class LMHead(torch.nn.Module):
    """
    Pytorch module for the language model head.
    The language model head is a linear layer that maps the embedding dimension to the vocabulary size.
    """

    def __init__(self, embedding_dimension, number_of_tokens):
        super().__init__()
        self.embedding_dimension = embedding_dimension
        self.number_of_tokens = number_of_tokens
        self.linear = torch.nn.Linear(embedding_dimension, number_of_tokens)

    def forward(self, x):
        """
        Compute the language model head.

        x dimensions are: (batch_size, sequence_length, embedding_dimension)
        output dimensions are: (batch_size, sequence_length, number_of_tokens)
        """
        # Compute the linear layer
        # linear_output dimensions are: (batch_size, sequence_length, number_of_tokens)
        linear_output = self.linear(x)

        return linear_output

To complete the model we add an autoregressive wrapper (based on the implementation by lucidrains). Autoregressive means that the output of the previous step is used as input for the next step. This way we can generate text by adding one new token at a time.

The input of this wrapper is a (batch of) sequences of tokens with a length of max_sequence_length + 1. We add one, because this allows us to shift the target sequence by one step.

For example if you have the tokens

["badgers", "are", "nocturnal", "so", "they", "sleep", "during", "the", "day", "and", "are", "awake", "at", "night"]

our input would be

["badgers", "are", "nocturnal", "so", "they", "sleep", "during", "the", "day", "and", "are", "awake", "at"]

and our target would be the same sequence shifted by one token.

["are", "nocturnal", "so", "they", "sleep", "during", "the", "day", "and", "are", "awake", "at", "night"]

Given the input, we want the model to predict the next token in the sequence, which would in this case be the word "night".

We define a mask based on the padding tokens in the input sequence. Padding tokens are not going to be attended to.

Then we define a method for this wrapper to calculate the probabilities for the next token. This method takes an input sequence and predicts the probabilities for the token that comes next, based on the trained model. It does so by calculating the logits for the last token. The logits are the output of the neural network before the softmax function is applied. The softmax function is going to convert these logits into probabilities.

The temperature can be used to control how random the predictions are. The closer the temperature is to 0, the more the model favours the token with the highest probability; the higher the temperature, the more random the output will be.

class AutoregressiveWrapper(torch.nn.Module):
    """
    Pytorch module that wraps a GPT model and makes it autoregressive.
    """

    def __init__(self, gpt_model):
        super().__init__()
        self.model = gpt_model
        self.max_sequence_length = self.model.max_sequence_length

    def forward(self, x, mask):
        """
        Autoregressive forward pass
        """
        inp, target = x[:, :-1], x[:, 1:]
        mask = mask[:, :-1]

        output = self.model(inp, mask)
        return output, target

    def next_token_probabilities(self, x, mask, temperature=1.0):
        """
        Calculate the token probabilities for the next token in the sequence.
        """
        logits = self.model(x, mask)[:, -1]

        # Apply the temperature
        if temperature != 1.0:
            logits = logits / temperature

        # Apply the softmax
        probabilities = torch.softmax(logits, dim=-1)

        return probabilities
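
A small demonstration of the effect of temperature on a made-up set of logits; a lower temperature sharpens the distribution, a higher temperature flattens it:

import torch

logits = torch.tensor([2.0, 1.0, 0.1])
print(torch.softmax(logits / 0.5, dim=-1))  # low temperature: sharper, roughly [0.86, 0.12, 0.02]
print(torch.softmax(logits / 1.0, dim=-1))  # temperature 1: roughly [0.66, 0.24, 0.10]
print(torch.softmax(logits / 2.0, dim=-1))  # high temperature: flatter, roughly [0.50, 0.30, 0.19]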

Trainer


Now that we have a model that can learn how language works, we need to actually train it to do so.

First we create the tokenizer, so we can convert our dataset to tokens.
Then we create the autoregressive language model. Because the dataset is very small I set max_sequence_length to only 20; a more typical value would be 512.

I will then split the training text into sequences of length max_sequence_length + 1. To make sure the starting tokens are also considered, the data will be padded on the left.

Then we are going to train the model for 100 epochs, with a batch size of 8. An epoch is a complete pass over the training data. A batch size of 8 means that in every forward pass through the model we consider 8 sequences from the training data simultaneously. A larger batch size gives a more stable estimate of the gradients, but it also uses more memory.

def create_training_sequences(max_sequence_length, tokenized_training_data):
    # Create sequences of length max_sequence_length + 1
    # The extra token allows the target sequence to be shifted by one step
    sequences = []
    for i in range(0, len(tokenized_training_data) - max_sequence_length):
        sequences.append(tokenized_training_data[i: i + max_sequence_length + 1])
    return sequences


def tokenize_and_pad_training_data(max_sequence_length, tokenizer, training_data):
    # Tokenize the training data
    tokenized_training_data = tokenizer.tokenize(training_data)
    for _ in range(max_sequence_length):
        # Prepend padding tokens
        tokenized_training_data.insert(0, tokenizer.character_to_token('<pad>'))
    return tokenized_training_data


tokenizer = Tokenizer()

embedding_dimension = 256
max_sequence_length = 20
number_of_tokens = tokenizer.size()

# Create the model
model = AutoregressiveWrapper(LanguageModel(
    embedding_dimension=embedding_dimension,
    number_of_tokens=number_of_tokens,
    number_of_heads=4,
    number_of_layers=3,
    dropout_rate=0.1,
    max_sequence_length=max_sequence_length
))

# Create the training data
training_data = '. '.join([
    'cats rule the world',
    'dogs are the best',
    'elephants have long trunks',
    'monkeys like bananas',
    'pandas eat bamboo',
    'tigers are dangerous',
    'zebras have stripes',
    'lions are the kings of the savannah',
    'giraffes have long necks',
    'hippos are big and scary',
    'rhinos have horns',
    'penguins live in the arctic',
    'polar bears are white'
])

tokenized_and_padded_training_data = tokenize_and_pad_training_data(max_sequence_length, tokenizer, training_data)
sequences = create_training_sequences(max_sequence_length, tokenized_and_padded_training_data)

# Train the model
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
trainer = Trainer(model, tokenizer, optimizer)
loss_per_epoch = trainer.train(sequences, epochs=100, batch_size=8)

The trainer is a helper class that loops over the epochs and shuffles the data at the start of each epoch. This prevents the batches from being the same every time, which could cause the model to overfit to those specific batches.

While creating batches we also determine the mask. All padding tokens are masked, meaning they will not be considered in the attention step.

Then we do a forward pass through the model with a batch. This means we let the model make predictions on the given data. The predictions are compared to a target value, which is the sequence shifted by one step so the next token becomes visible. The model outputs probabilities for what the next token should be, and the loss function knows what the answer is. The further the prediction is from its target, the higher the loss value will be.

When the loss value is calculated the model can be updated. This is done by calculating gradients: the direction in which the weights should be adjusted to improve the model's predictions. The model is then slightly adjusted in the direction of the gradients, and a new batch can be processed.

If everything works as planned, the loss should go down over time. I return the loss per epoch, so it can be plotted.

class Trainer:

    def __init__(self, model, tokenizer: Tokenizer, optimizer=None):
        super().__init__()
        self.model = model
        if optimizer is None:
            self.optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
        else:
            self.optimizer = optimizer
        self.tokenizer = tokenizer
        self.loss_function = torch.nn.CrossEntropyLoss()

    def train(self, data: List[List[int]], epochs, batch_size):
        loss_per_epoch = []
        for epoch in range(epochs):
            losses = []

            # Shuffle the sequences
            random.shuffle(data)

            # Create batches of sequences and their respective mask.
            batches = []
            for i in range(0, len(data), batch_size):
                sequence_tensor = torch.tensor(data[i: i + batch_size], dtype=torch.long)

                # Create the mask tensor for the batch, where 1 means the token is not a padding token
                mask_tensor = torch.ones_like(sequence_tensor)
                mask_tensor[sequence_tensor == self.tokenizer.character_to_token('<pad>')] = 0

                batches.append((sequence_tensor, mask_tensor))

            # Train the model on each batch
            for batch in batches:
                self.model.train()

                # Get the input and mask tensors for this batch.
                # Note that the last batch may contain fewer than batch_size sequences.
                input_tensor, mask_tensor = batch

                # Compute the model output
                model_output, target = self.model.forward(x=input_tensor, mask=mask_tensor)

                # Compute the losses
                # The loss is computed on the model output and the target
                loss = self.loss_function(model_output.transpose(1, 2), target)

                # Backpropagate the loss.
                loss.backward()

                # Clip the gradients. This is used to prevent exploding gradients.
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 0.5)

                # Update the model parameters. This is done by taking a step in the direction of the gradient.
                self.optimizer.step()

                # Reset the gradients. This is done so that the gradients from the previous batch
                # are not used in the next step.
                self.optimizer.zero_grad()

                # Append the loss to the list of losses, so that the average loss can be computed for this epoch.
                losses.append(loss.item())

            # Print the loss
            epoch_loss = np.average(losses)
            loss_per_epoch.append(epoch_loss)
            print('Epoch:', epoch, 'Loss:', epoch_loss)

        return loss_per_epoch

Eventually, plotting the loss should give us a nice graph with a decreasing loss. It is plotted in log scale, so you can see the smaller variations towards the end of training.

# Plot the loss per epoch in log scale
plt.plot(loss_per_epoch)
plt.yscale('log')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.show()

Generator

Now the model is trained and I want to see if it actually learned to write. Let's try to generate text based on the prompt "elephants", and let it continue writing for 50 tokens.

Since the only mention of elephants in the training data is "elephants have long trunks", I expect the model to write this.

max_tokens_to_generate = 50
generator = Generator(model, tokenizer)
generated_text = generator.generate(
    max_tokens_to_generate=max_tokens_to_generate,
    prompt="elephants",
    padding_token=tokenizer.character_to_token('<pad>')
)
print(generated_text.replace('<pad>', ''))

But first we need to write the code for the Generator, a helper class for generating text.

First we switch the model from "training" mode to "eval" mode. In eval mode the model does not apply dropout.

The prompt we give is converted to tokens, and then padded so it has the correct sequence length.

Then we are going to auto-regressively generate new tokens and add them to the input sequence. After a token is added, we run the new input sequence (with the extra token) through the model again and append another token. We continue this process until we reach the maximum number of tokens we want to generate, or until we generate the eos_token (end of sequence token). This is a token that can be defined by the user as an indication that we need to stop generating.

def pad_left(sequence, final_length, padding_token):
    return [padding_token] * (final_length - len(sequence)) + sequence


class Generator:

    def __init__(
            self,
            model,
            tokenizer):
        self.model = model
        self.tokenizer = tokenizer

    def generate(
            self,
            max_tokens_to_generate: int,
            prompt: str = None,
            temperature: float = 1.0,
            eos_token: int = None,
            padding_token: int = 0):

        self.model.eval()

        if prompt is None:
            # The padding token is already a token index, so it can be used directly
            start_tokens = [padding_token]
        else:
            start_tokens = self.tokenizer.tokenize(prompt)

        input_tensor = torch.tensor(
            pad_left(
                sequence=start_tokens,
                final_length=self.model.max_sequence_length + 1,
                padding_token=padding_token
            ),
            dtype=torch.long
        )

        num_dims = len(input_tensor.shape)

        if num_dims == 1:
            input_tensor = input_tensor[None, :]

        out = input_tensor
        for _ in range(max_tokens_to_generate):

            x = out[:, -self.model.max_sequence_length:]

            mask = torch.ones_like(x)
            mask[x == padding_token] = 0

            # Compute the next token probabilities
            next_token_probabilities = self.model.next_token_probabilities(
                x=x,
                temperature=temperature,
                mask=mask
            )

            # Sample the next token from the probability distribution
            next_token = torch.multinomial(next_token_probabilities, num_samples=1)

            # Append the next token to the output
            out = torch.cat([out, next_token], dim=1)

            # If the end of sequence token is reached, stop generating tokens
            if eos_token is not None and next_token == eos_token:
                break

        generated_tokens = out[0].tolist()
        return ''.join([self.tokenizer.token_to_character(token) for token in generated_tokens])

The training is finished and the generator has run. aaaaaanndd... the model actually outputs the text we trained it on!

Of course, when training on such a small dataset the model will completely overfit it and learn to reproduce the whole dataset. Not what you are usually looking for in a language model, but in this "hello world" test it means success!

Saving and loading the model


Once you have trained the model, it is useful to be able to save it, so you don't have to train a new model every time.

To do this we add the following code to the LanguageModel class.

def save_checkpoint(self, path):
    print(f'Saving checkpoint {path}')
    torch.save({
        'number_of_tokens': self.number_of_tokens,
        'max_sequence_length': self.max_sequence_length,
        'embedding_dimension': self.embedding_dimension,
        'number_of_layers': self.number_of_layers,
        'number_of_heads': self.number_of_heads,
        'feed_forward_dimension': self.feed_forward_dimension,
        'dropout_rate': self.dropout_rate,
        'model_state_dict': self.state_dict()
    }, path)

@staticmethod
def load_checkpoint(path) -> 'LanguageModel':
    checkpoint = torch.load(path)
    model = LanguageModel(
        number_of_tokens=checkpoint['number_of_tokens'],
        max_sequence_length=checkpoint['max_sequence_length'],
        embedding_dimension=checkpoint['embedding_dimension'],
        number_of_layers=checkpoint['number_of_layers'],
        number_of_heads=checkpoint['number_of_heads'],
        feed_forward_dimension=checkpoint['feed_forward_dimension'],
        dropout_rate=checkpoint['dropout_rate']
    )
    model.load_state_dict(checkpoint['model_state_dict'])
    return model

Since we use the AutoregressiveWrapper as a convenience class, we give this wrapper save and load methods too.

def save_checkpoint(self, path):
    self.model.save_checkpoint(path)

@staticmethod
def load_checkpoint(path) -> 'AutoregressiveWrapper':
    model = LanguageModel.load_checkpoint(path)
    return AutoregressiveWrapper(model)

This makes it possible to easily save and load a trained model:

model.save_checkpoint('./trained_model')
model = model.load_checkpoint('./trained_model')

Running on GPU


If you have a GPU at your disposal and CUDA is configured, you can use it to speed up training.

First we need to define a function to determine if we can use a GPU:

def get_device():
    if torch.cuda.is_available():
        return torch.device('cuda')
    else:
        return torch.device('cpu')

Then we need to move the model and all the input tensors to our device. We can do this using the .to() method.

# Model
model = AutoregressiveWrapper(LanguageModel(
    embedding_dimension=embedding_dimension,
    number_of_tokens=number_of_tokens,
    number_of_heads=4,
    number_of_layers=3,
    dropout_rate=0.1,
    max_sequence_length=max_sequence_length
)).to(get_device())

# Training input
model_output, target = self.model.forward(
    x=input_tensor.to(get_device()),
    mask=mask_tensor.to(get_device())
)

# Generation input
input_tensor = torch.tensor(
    pad_left(
        sequence=start_tokens,
        final_length=self.model.max_sequence_length + 1,
        padding_token=padding_token
    ),
    dtype=torch.long
).to(get_device())

The speedup achieved by using a GPU can be really significant.

Conclusion


While writing this blogpost I learned a lot about how the transformer works, and how we can use it to generate text. I hope you enjoyed reading it, and learned something new too!

Feel free to leave a comment if you have remarks, questions or just want to let me know what you think :)

The code for this blogpost is available at https://github.com/wingedsheep/transformer

Sources


Some resources that were incredibly helpful for understanding attention, the transformer architecture and GPT models.