Building BERT with PyTorch from scratch

I have worked with BERT and its siblings for a while, but I never paid much attention to what it actually is. I decided to fill this shameful gap in my knowledge and build at least a minimal working version of it. Searching for a tutorial didn’t help much, so I had to gather the knowledge piece by piece to get the full picture of BERT.

This article is my attempt to create a thorough tutorial on how to build BERT architecture using PyTorch.

The full code to the tutorial is available at pytorch_bert.

Long Story Short about BERT

BERT stands for Bidirectional Encoder Representations from Transformers. The original paper, BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding, explains everything you need to know about BERT.

Honestly, there are much better articles on the Internet explaining what BERT is, for example, BERT Explained: State of the art language model for NLP. After reading it, you may have some questions about attention mechanisms; the article Illustrated: Self-Attention explains attention.

In this section I only want to run through the main ideas of BERT and give more attention to the practical implementation.

BERT solves two tasks simultaneously:

  • Next Sentence Prediction (NSP);
  • Masked Language Model (MLM).
Source [devlin et al, 2018].

Next Sentence Prediction

NSP is a binary classification task. Given two sentences as input, our model should be able to predict whether the second sentence is a true continuation of the first one.

For example, there is a paragraph

I have a cat named Tom. Tom likes to play with birds sitting on the window.
They like this game not. I also have a dog. We walk together everyday.

The potential dataset will look like

Sentence                                                                   | NSP class
I have a cat named Tom. Tom likes to play with birds sitting on the window | is next
I have a cat named Tom. We walk together everyday                          | is not next

Masked Language Model

Masked Language Model is the task to predict the hidden words in the sentence.

For example, having a sentence

Tom likes to [MASK] with birds [MASK] on the window.

the model should predict that masked words are play and sitting.

Building BERT

To build BERT we need to work out three steps:

  • Prepare Dataset;
  • Build a model;
  • Build a trainer.

Prepare Dataset

In the case of BERT, the dataset should be prepared in a certain way. I spent maybe 30% of the time and my brain power only to build the dataset for the BERT model. So, it’s worth a discussion in its own paragraph.

The original BERT uses BooksCorpus (800M words) and English Wikipedia (2,500M words) for pre-training. We use IMDB reviews data with ~72k words.

Download the dataset from Kaggle: IMDB Dataset of 50K Movie Reviews and put it under data/ directory in the root of your project.

Following PyTorch’s DATASETS & DATALOADERS tutorial, we have to create a dataset class that extends torch.utils.data.Dataset.

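import random
import typing
from collections import Counter

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import vocab
from tqdm import tqdm


class IMDBBertDataset(Dataset):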
    # Define Special tokens as attributes of class
    CLS = '[CLS]'
    PAD = '[PAD]'
    SEP = '[SEP]'
    MASK = '[MASK]'
    UNK = '[UNK]'

    MASK_PERCENTAGE = 0.15  # Fraction of words to mask

    MASKED_INDICES_COLUMN = 'masked_indices'
    TARGET_COLUMN = 'indices'
    NSP_TARGET_COLUMN = 'is_next'
    TOKEN_MASK_COLUMN = 'token_mask'

    OPTIMAL_LENGTH_PERCENTILE = 70

    def __init__(self, path, ds_from=None, ds_to=None, should_include_text=False):
        self.ds: pd.Series = pd.read_csv(path)['review']

        if ds_from is not None or ds_to is not None:
            self.ds = self.ds[ds_from:ds_to]

        self.tokenizer = get_tokenizer('basic_english')
        self.counter = Counter()
        self.vocab = None

        self.optimal_sentence_length = None
        self.should_include_text = should_include_text

        if should_include_text:
            self.columns = ['masked_sentence', self.MASKED_INDICES_COLUMN, 'sentence', self.TARGET_COLUMN,
                            self.TOKEN_MASK_COLUMN,
                            self.NSP_TARGET_COLUMN]
        else:
            self.columns = [self.MASKED_INDICES_COLUMN, self.TARGET_COLUMN, self.TOKEN_MASK_COLUMN,
                            self.NSP_TARGET_COLUMN]
        self.df = self.prepare_dataset()

    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        ...
    
    def prepare_dataset(self) -> pd.DataFrame:
        ...

A bit of a strange part in the __init__ is

...
if should_include_text:
    self.columns = ['masked_sentence', self.MASKED_INDICES_COLUMN, 'sentence', self.TARGET_COLUMN,
                    self.TOKEN_MASK_COLUMN,
                    self.NSP_TARGET_COLUMN]
else:
    self.columns = [self.MASKED_INDICES_COLUMN, self.TARGET_COLUMN, self.TOKEN_MASK_COLUMN,
                    self.NSP_TARGET_COLUMN]
...

We define the above columns to create self.df. Use should_include_text=True to include textual representation of the created sentences in the data frame. It’s useful to see what exactly was created by our preprocessing algorithm.

So, set should_include_text=True only for debug purpose.

Most of the work is done in the prepare_dataset method. In the __getitem__ method we prepare the tensors for a single training item.

To prepare the dataset, we do the following:

  • Split dataset on sentences
  • Create vocabulary for word - token pair, for example {'go': 45}
  • Create training dataset
    • Add special tokens to the sentence
    • Mask 15% of words in the sentence
    • Pad sentence to predefined length
    • Create NSP item from two sentences

Let’s review the code of prepare_dataset method step by step.

Split dataset on sentences and fill the vocabulary

Retrieving sentences is the first (and the simplest) operation we do in the prepare_dataset method. It is needed to fill the vocabulary of words.

sentences = []  
nsp = []  
sentence_lens = []

# Split dataset on sentences
for review in self.ds:
    review_sentences = review.split('. ')
    sentences += review_sentences
    self._update_length(review_sentences, sentence_lens)
self.optimal_sentence_length = self._find_optimal_sentence_length(sentence_lens)

Note that we split the text by '. ' (a period followed by a space). But as stated in [devlin et al, 2018], a “sentence” can be an arbitrary span of contiguous text; you can split it however you need.
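
Splitting on '. ' leaves HTML remnants such as <br /> inside the sentences and ignores '!' and '?'. Below is a minimal sketch of a slightly more careful splitter, assuming plain regular expressions are good enough for this dataset. split_review is a hypothetical helper and not part of the tutorial code; unlike the split above, it keeps the closing punctuation.

```python
import re
import typing


def split_review(review: str) -> typing.List[str]:
    """Split a raw review into sentence-like chunks (hypothetical helper)."""
    # Replace HTML line breaks, which are common in the IMDB reviews, with spaces
    text = re.sub(r'<br\s*/?>', ' ', review)
    # Split on whitespace that follows '.', '!' or '?'
    parts = re.split(r'(?<=[.!?])\s+', text)
    # Drop empty chunks and surrounding whitespace
    return [part.strip() for part in parts if part.strip()]


example = split_review("They are right.<br /><br />The first thing was brutal! Really?")
print(example)
# ['They are right.', 'The first thing was brutal!', 'Really?']
```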

If you print sentences[:2] you’ll see the following result

['One of the other reviewers has mentioned that after watching just 1 Oz '
 "episode you'll be hooked",
 'They are right, as this is exactly what happened with me.<br /><br />The '
 'first thing that struck me about Oz was its brutality and unflinching scenes '
 'of violence, which set in right from the word GO']

The interesting part is how we define the sentence length.

def _find_optimal_sentence_length(self, lengths: typing.List[int]):  
    arr = np.array(lengths)  
    return int(np.percentile(arr, self.OPTIMAL_LENGTH_PERCENTILE))

Instead of a hardcoded max length, we store the lengths of all sentences in a list and take the 70th percentile of sentence_lens. For the 50k IMDB reviews, the optimal sentence length is 27, which means that 70% of the sentences have a length less than or equal to 27.
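
To make the percentile idea concrete, here is a small standalone sketch. The _update_length method is not shown in this article, so update_length below is only an assumption about what it does (it counts whitespace-separated words); the list of lengths is made up.

```python
import typing

import numpy as np

OPTIMAL_LENGTH_PERCENTILE = 70


def update_length(sentences: typing.List[str], lengths: typing.List[int]) -> None:
    """Append the word count of every sentence to lengths (assumed behaviour)."""
    for sentence in sentences:
        lengths.append(len(sentence.split()))


def find_optimal_sentence_length(lengths: typing.List[int]) -> int:
    return int(np.percentile(np.array(lengths), OPTIMAL_LENGTH_PERCENTILE))


sentence_lens: typing.List[int] = []
update_length(["a b c", "a b c d e"], sentence_lens)
print(sentence_lens)  # [3, 5]

# Made-up sentence lengths: 70% of them are not longer than the returned value
toy_lengths = [3, 5, 8, 10, 12, 20, 27, 30, 45, 60]
print(find_optimal_sentence_length(toy_lengths))  # 27
```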

Then, we feed these sentences to the vocabulary. We tokenize each sentence and update the counter with the sentence tokens (words).

print("Create vocabulary")  
for sentence in tqdm(sentences):  
    s = self.tokenizer(sentence)  
    self.counter.update(s)  
  
self._fill_vocab()

The sentence after tokenization is the list of its words

"My cat is Tom" -> ['my', 'cat', 'is', 'tom']

Here is the output you should see after printing the self.counter

Counter({'the': 6929,
         ',': 5753,
         'and': 3409,
         'a': 3385,
         'of': 3073,
         'to': 2774,
         "'": 2692,
         '.': 2184,
         'is': 2123,
         ...

Note in this tutorial we omit the important step of the dataset cleaning. That’s the reason why the most popular tokens are the, ,, and, a, and so on.

Finally, we’re ready to build our vocabulary. This operation is moved to _fill_vocab method

def _fill_vocab(self):  
    # specials= argument is only in 0.12.0 version  
    # specials=[self.CLS, self.PAD, self.MASK, self.SEP, self.UNK]
    self.vocab = vocab(self.counter, min_freq=2)  

    # 0.11.0 uses this approach to insert specials  
    self.vocab.insert_token(self.CLS, 0)  
    self.vocab.insert_token(self.PAD, 1)  
    self.vocab.insert_token(self.MASK, 2)  
    self.vocab.insert_token(self.SEP, 3)  
    self.vocab.insert_token(self.UNK, 4)  
    self.vocab.set_default_index(4)

For this tutorial, we’ll add to the vocabulary only the words that appear 2 or more times in the dataset. After vocabulary is created, we add special tokens to the vocabulary and set [UNK] token as default.
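
If you want to see what the vocabulary does without torchtext, the same idea can be sketched with a plain dictionary. This is an illustration only, not the torchtext implementation; build_vocab and lookup_indices are hypothetical names.

```python
from collections import Counter

SPECIALS = ['[CLS]', '[PAD]', '[MASK]', '[SEP]', '[UNK]']


def build_vocab(counter: Counter, min_freq: int = 2) -> dict:
    """Map each token to an integer index; special tokens take indices 0..4."""
    token_to_index = {token: i for i, token in enumerate(SPECIALS)}
    for token, freq in counter.items():
        if freq >= min_freq and token not in token_to_index:
            token_to_index[token] = len(token_to_index)
    return token_to_index


def lookup_indices(token_to_index: dict, tokens: list) -> list:
    # Unknown tokens fall back to the index of [UNK], like set_default_index(4)
    unk = token_to_index['[UNK]']
    return [token_to_index.get(token, unk) for token in tokens]


counter = Counter("my cat is tom . my cat likes birds .".split())
vocab_map = build_vocab(counter)
print(lookup_indices(vocab_map, ['[CLS]', 'my', 'dog', '[MASK]', 'cat']))
# [0, 5, 4, 2, 6]
```

Only 'my', 'cat' and '.' appear at least twice, so every other word, including the unseen 'dog', maps to [UNK].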

Huh, half of the work is done 🎉 we have built a vocabulary. Let’s test it

self.vocab.lookup_indices(["[CLS]", "this", "works", "[MASK]", "well"])

and the output

[0, 29, 1555, 2, 152]

Create training dataset

For each review with more than one linguistic sentence we create true NSP item (when the second sentence is the next sentence in review) and false NSP item (when the second sentence is any random sentence from the sentences).

print("Preprocessing dataset")  
for review in tqdm(self.ds):  
    review_sentences = review.split('. ')  
    if len(review_sentences) > 1:  
        for i in range(len(review_sentences) - 1):  
            # True NSP item  
            first, second = self.tokenizer(review_sentences[i]), self.tokenizer(review_sentences[i + 1])  
            nsp.append(self._create_item(first, second, 1))  
  
            # False NSP item  
            first, second = self._select_false_nsp_sentences(sentences)  
            first, second = self.tokenizer(first), self.tokenizer(second)  
            nsp.append(self._create_item(first, second, 0))  
df = pd.DataFrame(nsp, columns=self.columns)
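
The _select_false_nsp_sentences method is not listed in this article. A minimal sketch of how such a negative pair could be drawn, assuming we only need to avoid picking a sentence that directly follows the first one in the flat sentences list (the function below is a standalone illustration, not necessarily the repository code):

```python
import random
import typing


def select_false_nsp_sentences(sentences: typing.List[str]) -> typing.Tuple[str, str]:
    """Pick two sentences such that the second does not directly follow the first."""
    first_index = random.randint(0, len(sentences) - 1)
    second_index = random.randint(0, len(sentences) - 1)
    # Re-draw while the pair happens to be a genuine "next sentence" pair
    while second_index == first_index + 1:
        second_index = random.randint(0, len(sentences) - 1)
    return sentences[first_index], sentences[second_index]


random.seed(0)
pool = ["I have a cat named Tom", "Tom likes to play with birds", "I also have a dog"]
print(select_false_nsp_sentences(pool))
```

Note that this simple version may return the same sentence twice; filter that case out if it matters for your data.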

The _create_item method does 99% of the work. The following code is trickier than the vocabulary creation, so do not hesitate to run it in debug mode. Let’s take a look at what happens to a sentence pair after each transformation, step by step. The full implementation of the self._create_item method is in the repository.

The first thing we should do is add the special tokens ([CLS], [PAD], [MASK], [SEP]) to our sentences

def _create_item(self, first: typing.List[str], second: typing.List[str], target: int = 1):  
    # Create masked sentence item  
    updated_first, first_mask = self._preprocess_sentence(first.copy())  
    updated_second, second_mask = self._preprocess_sentence(second.copy())
    nsp_sentence = updated_first + [self.SEP] + updated_second  
    nsp_indices = self.vocab.lookup_indices(nsp_sentence)  
    inverse_token_mask = first_mask + [True] + second_mask

Step #1. Mask sentence

Important, also, to see how we mask tokens of our sentences

def _mask_sentence(self, sentence: typing.List[str]):  
    len_s = len(sentence)  
    inverse_token_mask = [True for _ in range(max(len_s, self.optimal_sentence_length))]  
  
    mask_amount = round(len_s * self.MASK_PERCENTAGE)  
    for _ in range(mask_amount):  
        i = random.randint(0, len_s - 1)  
  
        if random.random() < 0.8:  
            sentence[i] = self.MASK  
        else:
            # Indices 0..4 are the special tokens inserted in _fill_vocab, so skip them
            j = random.randint(5, len(self.vocab) - 1)
            sentence[i] = self.vocab.lookup_token(j)  
        inverse_token_mask[i] = False  
    return sentence, inverse_token_mask

We change a random 15% of the tokens in the sentence. Note that in 80% of these cases we set the [MASK] token; otherwise we set a random word from the vocabulary.
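
The method above draws positions with random.randint, so the same position can be picked twice. Here is a standalone sketch of a variant that uses random.sample to select distinct positions. It is a swapped-in technique for illustration, not the tutorial’s method; mask_tokens and replacement_pool are hypothetical names. As in the code above, the mask holds False at the changed positions.

```python
import random
import typing

MASK = '[MASK]'
MASK_PERCENTAGE = 0.15


def mask_tokens(tokens: typing.List[str], replacement_pool: typing.List[str]):
    """Return a masked copy of tokens and the inverse token mask."""
    masked = tokens.copy()
    inverse_token_mask = [True] * len(tokens)

    mask_amount = round(len(tokens) * MASK_PERCENTAGE)
    # random.sample picks distinct positions, so exactly mask_amount positions are selected
    for i in random.sample(range(len(tokens)), mask_amount):
        if random.random() < 0.8:
            masked[i] = MASK
        else:
            masked[i] = random.choice(replacement_pool)
        inverse_token_mask[i] = False
    return masked, inverse_token_mask


random.seed(42)
tokens = "my cat tom likes to sleep and does not like little mice jerry".split()
masked, inverse_token_mask = mask_tokens(tokens, replacement_pool=['dog', 'bird', 'window'])
print(masked)
print(inverse_token_mask.count(False))  # 2, because round(13 * 0.15) == 2
```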

The less obvious part of the code above is inverse_token_mask. This list holds False where a token in the sentence was masked (or replaced) and True everywhere else, hence “inverse”. For example, let’s take a sentence

my cat tom likes to sleep and does not like little mice jerry

After masking, the sentence and the inverse token mask look like

sentence: my cat mice likes to sleep and does not like [MASK] mice jerry
inverse token mask: [True, True, False, True, True, True, True, True, True, True, False, True, True]

We’ll get back to the inverse token mask later, when we train our model.

Aside from the masked sentences, we store the original unmasked sentences that we will later use as the MLM training target

# Create sentence item without masking random words  
first, _ = self._preprocess_sentence(first.copy(), should_mask=False)  
second, _ = self._preprocess_sentence(second.copy(), should_mask=False)  
original_nsp_sentence = first + [self.SEP] + second  
original_nsp_indices = self.vocab.lookup_indices(original_nsp_sentence)

Step #2. Preprocessing: [CLS] and [PAD] sentence

Now we need to prepend [CLS] to the beginning of each sentence. Afterwards, we add [PAD] tokens to the end of each sentence so that they all have equal lengths. Let’s say we align all sentences to 13 words plus the [CLS] token.

After the transformation we have a sentence pair

[CLS] my cat mice likes to sleep and does not like [MASK] mice jerry
[SEP]
[CLS] jerry is treated as my pet too [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]

_pad_sentence method takes care of this transformation

def _pad_sentence(self, sentence: typing.List[str], inverse_token_mask: typing.List[bool] = None):  
    len_s = len(sentence)  
  
    if len_s >= self.optimal_sentence_length:  
        s = sentence[:self.optimal_sentence_length]  
    else:  
        s = sentence + [self.PAD] * (self.optimal_sentence_length - len_s)  
  
    # inverse token mask should be padded as well  
    if inverse_token_mask:  
        len_m = len(inverse_token_mask)  
        if len_m >= self.optimal_sentence_length:  
            inverse_token_mask = inverse_token_mask[:self.optimal_sentence_length]  
        else:  
            inverse_token_mask = inverse_token_mask + [True] * (self.optimal_sentence_length - len_m)  
    return s, inverse_token_mask

Note inverse token mask must have the same length as the sentence, so you should pad it as well.

Step #3. Translate words in the sentence to the integer tokens

Using the vocabulary we built earlier, we now translate the sentence into integer indices.

It’s done by two lines of code

...
nsp_sentence = updated_first + [self.SEP] + updated_second  
nsp_indices = self.vocab.lookup_indices(nsp_sentence)
...

Firstly, we join two sentences by [SEP] token and then translate to the list of integers.

After you run the dataset.py module as script, you should see preprocessed dataset

                                        masked_sentence  ... is_next
0     [[CLS], [MASK], of, the, other, reviewers, has...  ...       1
1     [[CLS], once, fifteen, arrived, in, the, ameri...  ...       0
2     [[CLS], they, [MASK], [MASK], ,, as, this, is,...  ...       1
3     [[CLS], just, a, [MASK], of, [MASK], young, ma...  ...       0
4     [[CLS], trust, me, [MASK], this, is, [MASK], a...  ...       1
                                                    ...  ...     ...
8873  [[CLS], freshness, crystal, is, here, to, sell...  ...       0
8874  [[CLS], pixar, have, proved, that, they, ', re...  ...       1
8875  [[CLS], [MASK], abandons, her, slapstick, [MAS...  ...       0
8876  [[CLS], they, raise, the, bar, [MASK], ,, and,...  ...       1
8877  [[CLS], he, is, an, amazing, [MASK], artist, ,...  ...       0
[8878 rows x 6 columns]

Printing the first item in the dataframe print(self.df.iloc[0]) gives us

masked_sentence    [[CLS], one, of, the, other, [MASK], has, ment...
masked_indices     [0, 5, 6, 7, 8, 2, 10, 11, 4825, 13, 2, 15, 16...
sentence           [[CLS], one, of, the, other, reviewers, has, m...
indices            [0, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,...
token_mask         [True, True, True, True, False, True, True, Fa...
is_next                                                            1
Name: 0, dtype: object

__getitem__

Now, we’re ready to write __getitem__ method.

item = self.df.iloc[idx]

inp = torch.Tensor(item[self.MASKED_INDICES_COLUMN]).long()
token_mask = torch.Tensor(item[self.TOKEN_MASK_COLUMN]).bool()

attention_mask = (inp == self.vocab[self.PAD]).unsqueeze(0)

At first, we select the item from the dataframe and create tensors that’ll be used for model training.

attention_mask is True where the input token is [PAD]. We use it during training to keep the model from attending to [PAD] tokens.
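
A tiny standalone check of what this mask looks like. The token indices are made up, and PAD_INDEX = 1 matches the position at which [PAD] was inserted into the vocabulary above.

```python
import torch

PAD_INDEX = 1  # index of [PAD] in the vocabulary built above

inp = torch.tensor([0, 6, 24, 565, 1, 1])
attention_mask = (inp == PAD_INDEX).unsqueeze(0)

print(attention_mask.shape)  # torch.Size([1, 6])
print(attention_mask)        # tensor([[False, False, False, False,  True,  True]])

# Stacking items into a batch, as a DataLoader would, gives (batch x 1 x sentence)
batch_mask = torch.stack([attention_mask, attention_mask])
print(batch_mask.shape)      # torch.Size([2, 1, 6])
```

The extra middle axis is what later lets the mask broadcast over the rows of the attention score matrix.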

We have inputs to the model, but we also should have the targets for training.

NSP target

NSP is a binary classification problem.

if item[self.NSP_TARGET_COLUMN] == 0:  
    t = [1, 0]  
else:  
    t = [0, 1]  
  
nsp_target = torch.Tensor(t)

We create NSP target as tensor of two items. It can be only in two states designating whether it is next sentence or not.

[1, 0] is NOT next
[0, 1] is next

To train the model for NSP we use BCEWithLogitsLoss class. It expects the target class to be in the above format.
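
As a quick illustration of this target format, here is a standalone sketch with made-up scores. Note that BCEWithLogitsLoss needs floating point targets of the same shape as the model output, which is what torch.Tensor(t) produces.

```python
import torch
from torch import nn

criterion = nn.BCEWithLogitsLoss()

# Raw, unnormalised scores for a batch of two sentence pairs (made-up numbers)
logits = torch.tensor([[2.0, -1.0], [-0.5, 1.5]])
# [1, 0] means "not next", [0, 1] means "next"
targets = torch.tensor([[1.0, 0.0], [0.0, 1.0]])

loss = criterion(logits, targets)
print(loss.item())

# The predicted class is the position of the larger score
predicted = logits.argmax(dim=-1)
print(predicted)  # tensor([0, 1])
```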

MLM target

We want our model to predict only masked tokens

mask_target = torch.Tensor(item[self.TARGET_COLUMN]).long()  
mask_target = mask_target.masked_fill_(token_mask, 0)

We directly set all non-masked integers in the target to 0. Looking ahead, we’ll do the same for the model output.
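
A small standalone example of this step with made-up indices. The loss at the end is only an assumption about how such a target could be consumed: nn.NLLLoss with ignore_index=0 skips every position whose target is 0, which fits the LogSoftmax output of the model built later in this article.

```python
import torch
from torch import nn

torch.manual_seed(0)

indices = torch.tensor([0, 5, 6, 7, 8, 9])  # original token indices
# False marks the positions that were masked in the input
token_mask = torch.tensor([True, True, False, True, True, False])

mask_target = indices.clone().masked_fill_(token_mask, 0)
print(mask_target)  # tensor([0, 0, 6, 0, 0, 9])

# Made-up model output: log-probabilities over a vocabulary of 10 tokens
log_probs = torch.log_softmax(torch.randn(6, 10), dim=-1)
loss = nn.NLLLoss(ignore_index=0)(log_probs, mask_target)
print(loss.item())
```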

Build PyTorch model

As you may have noticed, the project is split into submodules under the bert package. The full neural network model is located in the model.py file.

Firstly, I’d like to show you the object diagram of the model. Then we’ll go through code step by step.

Not so difficult, right? Let’s review it step by step.

JointEmbedding

We’re starting the model description from Embeddings. BERT has three embedding layers

  • Token embedding
  • Segment embedding
  • Position embedding
Source [devlin et al, 2018].

Token embedding is used to encode word tokens. Segment embedding encodes whether a token belongs to the first or to the second sentence. We preprocess the input sequence as follows: if the token belongs to the first sentence, set 0, otherwise set 1. For example,

Input tokens:   [0, 6, 24, 565, 67, 0, 443, 123, 5, 6, 5, 12, 1, 1, 1]
Input Segments: [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1]

Positional embedding encodes the position of the word in the sentence. There is an option to use embedding layer to encode positional information of token in a sequence. In the module’s code it’s done in numeric_position method. What it does is just arrange integer position.

Input tokens:   [0, 6, 24, 565, 67, 0, 443, 123, 5, 6, 5, 12, 1, 1, 1]
Input position: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]

However, instead of a learnable positional embedding, we use periodic functions to encode positions (as described in [vaswani et al, 2017]).
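
For reference, this is the sinusoidal encoding as defined in that paper, where pos is the position of the token in the sequence, i indexes pairs of embedding dimensions, and d_model is the embedding size:

```latex
PE_{(pos,\,2i)} = \sin\left(\frac{pos}{10000^{2i/d_{model}}}\right), \qquad
PE_{(pos,\,2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{model}}}\right)
```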

These functions of two variables are well explained in StackExchange: What is the positional encoding in the transformer model?

So, pos is the point at which we read a value from a periodic curve, and i selects which curve we read it from. For i = 0 we have one periodic curve to read pos from. For i = 4 we have another one with a longer period.

We keep all embeddings in one module, JointEmbedding. Here is the full code for the module.

class JointEmbedding(nn.Module):

    def __init__(self, vocab_size, size):
        super(JointEmbedding, self).__init__()

        self.size = size

        self.token_emb = nn.Embedding(vocab_size, size)
        self.segment_emb = nn.Embedding(vocab_size, size)

        self.norm = nn.LayerNorm(size)

    def forward(self, input_tensor):
        sentence_size = input_tensor.size(-1)
        pos_tensor = self.attention_position(self.size, input_tensor)

        segment_tensor = torch.zeros_like(input_tensor).to(device)
        segment_tensor[:, sentence_size // 2 + 1:] = 1

        output = self.token_emb(input_tensor) + self.segment_emb(segment_tensor) + pos_tensor
        return self.norm(output)

    def attention_position(self, dim, input_tensor):
        batch_size = input_tensor.size(0)
        sentence_size = input_tensor.size(-1)

        pos = torch.arange(sentence_size, dtype=torch.long).to(device)
        d = torch.arange(dim, dtype=torch.long).to(device)
        d = (2 * d / dim)

        pos = pos.unsqueeze(1)
        pos = pos / (1e4 ** d)

        pos[:, ::2] = torch.sin(pos[:, ::2])
        pos[:, 1::2] = torch.cos(pos[:, 1::2])

        return pos.expand(batch_size, *pos.size())

    def numeric_position(self, dim, input_tensor):
        pos_tensor = torch.arange(dim, dtype=torch.long).to(device)
        return pos_tensor.expand_as(input_tensor)

As you can see from the code, we create two embedding layers

self.token_emb = nn.Embedding(vocab_size, size)
self.segment_emb = nn.Embedding(vocab_size, size)

Then in the forward method we calculate positional encoding tensor and prepare a tensor for segment Embedding

pos_tensor = self.attention_position(self.size, input_tensor)

segment_tensor = torch.zeros_like(input_tensor).to(device)
segment_tensor[:, sentence_size // 2 + 1:] = 1

Then we just sum them and pass the resulting tensor through LayerNorm.

output = self.token_emb(input_tensor) + self.segment_emb(segment_tensor) + pos_tensor
return self.norm(output)

Using learnable positional embedding

If you want to use learnable positional embedding, you should create one more nn.Embedding instance in the constructor.

self.positional_emb = nn.Embedding(vocab_size, size)

Then in the forward method use numeric_position method to generate the input for self.positional_emb attribute.

# Pass the sentence length so that there is one position index per token
pos_tensor = self.numeric_position(input_tensor.size(-1), input_tensor)
...
output = self.token_emb(input_tensor) + self.segment_emb(segment_tensor) + self.positional_emb(pos_tensor)


The implementation of positional encoding in this code can be somewhat hard to understand, so let’s review it in more detail with an example.

We have a batch size of 2 and max sentence length of 5. dim attribute is the embedding size, we set it to 4. input_tensor is the word tokens tensor with size (batch_size x sentence_size), in our example, (2 x 5).

def attention_position(self, dim = 4, input_tensor):  # input_tensor shape (2 x 5) 
    batch_size = 2
    sentence_size = 5

Now, we should create a vector with the numerical position of each word in a sentence.

pos = tensor([0, 1, 2, 3, 4])

Vector d indexes the embedding axis; in the formula it plays the role of the exponent 2i / d_model

d = tensor([0, 1, 2, 3])
d = (2 * d / dim) = tensor([0, 0.5, 1, 1.5])

On the next two lines, we see the process that is called Broadcasting. PyTorch’s broadcasting is strongly inspired by NumPy (read more on NumPy Broadcasting).

Firstly, we unsqueeze the pos vector to the shape (5 x 1). Then we divide it by 1e4 ** d, and PyTorch broadcasting expands the result to the shape (5 x 4). In terms of the formula, we’re calculating the value that is passed on to the periodic function

pos = pos.unsqueeze(1) = tensor([[0],
                                 [1],
                                 [2],
                                 [3],
                                 [4]])

pos = pos / (1e4 ** d) = size (5 x 4) = 
= tensor([[0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
          [1.0000e+00, 1.0000e-02, 1.0000e-04, 1.0000e-06],
          [2.0000e+00, 2.0000e-02, 2.0000e-04, 2.0000e-06],
          [3.0000e+00, 3.0000e-02, 3.0000e-04, 3.0000e-06],
          [4.0000e+00, 4.0000e-02, 4.0000e-04, 4.0000e-06]])

A final operation is just to apply periodic functions to this pos tensor.

pos[:, ::2] = torch.sin(pos[:, ::2])    # Apply to 2i
pos[:, 1::2] = torch.cos(pos[:, 1::2])  # Apply to 2i + 1 

Finally, we expand pos to every element in the batch.

return pos.expand(batch_size,  *pos.size()) = size (2 x 5 x 4)

The returned tensor has the same size as our other embeddings. We just simulated embedding layer with simple mathematical operations.
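
To verify the walkthrough, the same computation can be run as a standalone function. This sketch drops self and the device handling from the module code and takes the batch and sentence sizes as explicit arguments.

```python
import torch


def attention_position(dim: int, batch_size: int, sentence_size: int) -> torch.Tensor:
    pos = torch.arange(sentence_size, dtype=torch.long)
    d = torch.arange(dim, dtype=torch.long)
    d = 2 * d / dim                       # tensor([0.0, 0.5, 1.0, 1.5]) for dim = 4

    pos = pos.unsqueeze(1)                # (sentence_size x 1)
    pos = pos / (1e4 ** d)                # broadcast to (sentence_size x dim)

    pos[:, ::2] = torch.sin(pos[:, ::2])    # even embedding indices
    pos[:, 1::2] = torch.cos(pos[:, 1::2])  # odd embedding indices

    return pos.expand(batch_size, *pos.size())


encoding = attention_position(dim=4, batch_size=2, sentence_size=5)
print(encoding.shape)  # torch.Size([2, 5, 4])
print(encoding[0, 0])  # position 0: tensor([0., 1., 0., 1.])
print(encoding[0, 1])  # position 1: sin(1), cos(0.01), sin(1e-4), cos(1e-6)
```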

AttentionHead

Attention is the heart of transformers. This is exactly what makes transformers so good. BERT uses the mechanism called self-attention. It is well described in this article Illustrated: Self-Attention. And a quote from this resource

A self-attention module takes in n inputs and returns n outputs. What happens in this module? In layman’s terms, the self-attention mechanism allows the inputs to interact with each other (“self”) and find out who they should pay more attention to (“attention”). The outputs are aggregates of these interactions and attention scores.

There are various types of attentions that can be used. We use the one from [vaswani et al, 2017].
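
For reference, this is the scaled dot-product attention from that paper, with d_k denoting the dimensionality of the keys:

```latex
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{Q K^{T}}{\sqrt{d_k}}\right) V
```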

where Q is query, K is key, and V is value.

For each of them we create linear layer with trainable weights. So, we’ll teach the network to ‘pay attention’. On the pictures below, you may see the visualization of attention multiplication. That’s exactly what we do in the code.

Visualization of attention operations from The Illustrated Transformer.

class AttentionHead(nn.Module):  
  
    def __init__(self, dim_inp, dim_out):  
        super(AttentionHead, self).__init__()  
  
        self.dim_inp = dim_inp  
  
        self.q = nn.Linear(dim_inp, dim_out)  
        self.k = nn.Linear(dim_inp, dim_out)  
        self.v = nn.Linear(dim_inp, dim_out)  
  
    def forward(self, input_tensor: torch.Tensor, attention_mask: torch.Tensor = None):  
        query, key, value = self.q(input_tensor), self.k(input_tensor), self.v(input_tensor)  
  
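        # Scale by the square root of the key dimension d_k, as in [vaswani et al, 2017]
        scale = query.size(-1) ** 0.5  
        scores = torch.bmm(query, key.transpose(1, 2)) / scale  
  
        # attention_mask defaults to None, so only apply it when it is given
        if attention_mask is not None:
            scores = scores.masked_fill_(attention_mask, -1e9)  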
        attn = f.softmax(scores, dim=-1)  
        context = torch.bmm(attn, value)  
  
        return context

As you can see in __init__, we create a linear module for each of query, key, and value. For simplicity, in this tutorial they all have the same shape.

Let’s continue the example above. dim_inp is the embedding size and equals 4. Let’s take the hidden attention size dim_out as 6.

Let’s follow the forward method step by step. Instead of printing the values of tensors we’ll just print their sizes (shapes).

# input tensor is the output of JointEmbedding module
# attention mask is the vector that masks [PAD] tokens
def forward(self, input_tensor: size (2 x 5 x 4), attention_mask: size (2 x 1 x 5)):

First thing we do is calculate query, key, value tensors

query, key, value = size (2 x 5 x 6), size (2 x 5 x 6), size (2 x 5 x 6)

Further, we calculate scaled multiplication of query and key.

scale = query.size(-1) ** 0.5 = 6 ** 0.5  
scores = torch.bmm(query, key.transpose(1, 2)) / scale = size (2 x 5 x 5) 

torch.bmm is batched matrix multiplication function. This multiplies each matrix in a batch, skipping the first axis. transpose method transposes tensor for 2 specific dimensions.

We don’t want our model to ‘pay attention’ to [PAD] tokens at all. This is why we have the attention mask vector. Using this vector we hide the scores of [PAD] tokens.

scores = scores.masked_fill_(attention_mask, -1e9) = size (2 x 5 x 5)

Now, we calculate the attention context itself.

attn = f.softmax(scores, dim=-1) = size (2 x 5 x 5)
context = torch.bmm(attn, value) = size (2 x 5 x 6)

So, each input value was weighted by the attention tensor.
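
The shapes above can be checked with a short standalone script. It is a sketch with random inputs and freshly initialised linear layers, scaling by the square root of the key dimension as in the formula from [vaswani et al, 2017]; the attention mask pretends that the last two tokens of every sequence are [PAD].

```python
import torch
import torch.nn.functional as f
from torch import nn

torch.manual_seed(0)

batch_size, sentence_size, dim_inp, dim_out = 2, 5, 4, 6

q, k, v = (nn.Linear(dim_inp, dim_out) for _ in range(3))
input_tensor = torch.randn(batch_size, sentence_size, dim_inp)

# True marks [PAD] positions; shape (batch_size x 1 x sentence_size)
attention_mask = torch.tensor([[[False, False, False, True, True]]])
attention_mask = attention_mask.expand(batch_size, 1, sentence_size)

query, key, value = q(input_tensor), k(input_tensor), v(input_tensor)
scores = torch.bmm(query, key.transpose(1, 2)) / dim_out ** 0.5
scores = scores.masked_fill(attention_mask, -1e9)
attn = f.softmax(scores, dim=-1)
context = torch.bmm(attn, value)

print(attn.shape, context.shape)  # torch.Size([2, 5, 5]) torch.Size([2, 5, 6])
print(attn[0, 0])                 # the last two weights are numerically zero
```

Every row of attn sums to 1, and the columns that belong to [PAD] tokens receive no weight.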

MultiHeadAttention

A single attention layer (head) is restricted to learning information from one particular subspace. Multi-head attention is a set of parallel attention heads that learn to retrieve information from different representations. You may think of them as filters in Convolutional Neural Networks.

Diagrams of single attention head and multi-head attention from [vaswani et al, 2017].

You may see how it works on the visualization of two-head attention on a picture below. We print the attentions for the word it. The first attention (orange) scores word animal the most while the second attention (green) scores word tired the most.

Two-head attention visualization from The Illustrated Transformer

Let’s get back to our code. As always, here is the full code of the module, then we go through it step by step

class MultiHeadAttention(nn.Module):  
  
    def __init__(self, num_heads, dim_inp, dim_out):  
        super(MultiHeadAttention, self).__init__()  
  
        self.heads = nn.ModuleList([  
            AttentionHead(dim_inp, dim_out) for _ in range(num_heads)  
        ])  
        self.linear = nn.Linear(dim_out * num_heads, dim_inp)  
        self.norm = nn.LayerNorm(dim_inp)  
  
    def forward(self, input_tensor: torch.Tensor, attention_mask: torch.Tensor):  
        s = [head(input_tensor, attention_mask) for head in self.heads]  
        scores = torch.cat(s, dim=-1)  
        scores = self.linear(scores)  
        return self.norm(scores)

dim_inp and dim_out have the same values as in the AttentionHead section: dim_inp equals 4 and dim_out equals 6. num_heads is 3. For simplicity, the output size of the linear layer is the same as the embedding size.

self.linear = nn.Linear(dim_out * num_heads, dim_inp) = nn.Linear(6 * 3, 4)

So, the input size of the linear layer is 18 and the output size is 4.

The forward method has the same arguments as the AttentionHead.

def forward(self, input_tensor: size (2 x 5 x 4), attention_mask: size (2 x 1 x 5)):

At the first operation we calculate the list of attentions s.

s = [head(input_tensor, attention_mask) for head in self.heads]
s = [
    tensor(2 x 5 x 6),
    tensor(2 x 5 x 6),
    tensor(2 x 5 x 6),
]

Further, we concatenate tensors by the last axis.

scores = torch.cat(s, dim=-1) = tensor(2 x 5 x 18)

Pass scores through linear layer and normalize.

scores = self.linear(scores) = tensor(2 x 5 x 4)
return self.norm(scores)
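
The concatenation step is easy to verify in isolation. In this sketch the head outputs are random tensors standing in for the three AttentionHead results.

```python
import torch
from torch import nn

num_heads, dim_inp, dim_out = 3, 4, 6

# Stand-ins for the outputs of three attention heads, each (2 x 5 x 6)
head_outputs = [torch.randn(2, 5, dim_out) for _ in range(num_heads)]

scores = torch.cat(head_outputs, dim=-1)
print(scores.shape)     # torch.Size([2, 5, 18])

linear = nn.Linear(dim_out * num_heads, dim_inp)
projected = linear(scores)
print(projected.shape)  # torch.Size([2, 5, 4])
```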

Encoder

The encoder consists of multi-head attention and a feed-forward neural network. The original Attention Is All You Need uses a stack of identical encoder layers. We use only one in this tutorial for simplicity.

Encoder layer diagram from [vaswani et al, 2017].

class Encoder(nn.Module):  
  
    def __init__(self, dim_inp, dim_out, attention_heads=4, dropout=0.1):  
        super(Encoder, self).__init__()  
  
        self.attention = MultiHeadAttention(attention_heads, dim_inp, dim_out) 
        self.feed_forward = nn.Sequential(  
            nn.Linear(dim_inp, dim_out),  
            nn.Dropout(dropout),  
            nn.GELU(),  
            nn.Linear(dim_out, dim_inp),  
            nn.Dropout(dropout)  
        )
        self.norm = nn.LayerNorm(dim_inp)  
  
    def forward(self, input_tensor: torch.Tensor, attention_mask: torch.Tensor):  
        context = self.attention(input_tensor, attention_mask)  
        res = self.feed_forward(context)  
        return self.norm(res)

The feed forward network deserves an explanation, as it’s slightly different from the one you may see in Attention Is All You Need.

self.feed_forward = nn.Sequential(  
    nn.Linear(dim_inp, dim_out),  
    nn.Dropout(dropout),  
    nn.GELU(),
    nn.Linear(dim_out, dim_inp),
    nn.Dropout(dropout)  
)

The original encoder uses ReLU as the activation function. We use GELU (read more in Gaussian Error Linear Units (GELUs)).

Visualization of the GELU function compared to ReLU and ELU from [hendrycks et al, 2016].

The formula representation of our feed forward network.
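
Ignoring dropout, the network defined above can be written as follows, where W_1, b_1 and W_2, b_2 are the weights and biases of the two linear layers (the symbols are introduced here for illustration):

```latex
\mathrm{FFN}(x) = \mathrm{GELU}(x W_1 + b_1)\, W_2 + b_2
```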

Also, we add dropout layer after each linear.

Why do we use GELU? The original BERT uses it instead of ReLU, and it tends to give better results. You may follow the Searching for Activation Functions paper for more details.

The forward method is simple:

  • Calculate attention context
  • Pass context through feed forward network
  • Normalize

def forward(self, input_tensor: torch.Tensor, attention_mask: torch.Tensor):  
    context = self.attention(input_tensor, attention_mask)  
    res = self.feed_forward(context)  
    return self.norm(res)

BERT

BERT module is a container that combines all the modules together and returns the output.

class BERT(nn.Module):  
  
    def __init__(self, vocab_size, dim_inp, dim_out, attention_heads=4):  
        super(BERT, self).__init__()  
  
        self.embedding = JointEmbedding(vocab_size, dim_inp)  
        self.encoder = Encoder(dim_inp, dim_out, attention_heads)  
  
        self.token_prediction_layer = nn.Linear(dim_inp, vocab_size)  
        self.softmax = nn.LogSoftmax(dim=-1)  
        self.classification_layer = nn.Linear(dim_inp, 2)  
  
    def forward(self, input_tensor: torch.Tensor, attention_mask: torch.Tensor):  
        embedded = self.embedding(input_tensor)  
        encoded = self.encoder(embedded, attention_mask)  
  
        token_predictions = self.token_prediction_layer(encoded)  
  
        first_word = encoded[:, 0, :]  
        return self.softmax(token_predictions), self.classification_layer(first_word)

For the token prediction task we use a linear layer whose output size equals the vocabulary size, followed by a log-softmax.

self.token_prediction_layer = nn.Linear(dim_inp, vocab_size)
self.softmax = nn.LogSoftmax(dim=-1)

For the next sentence prediction task we use a linear layer with an output size of 2:

self.classification_layer = nn.Linear(dim_inp, 2)

The NSP output is interpreted by its argmax, matching a one-hot target:

[1, 0] (argmax = 0) is NOT next sentence
[0, 1] (argmax = 1) is next sentence
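In index terms, the one-hot pattern [1, 0] corresponds to class index 0 and [0, 1] to class index 1. A small sketch of that decoding step, using made-up logits for three sentence pairs (the numbers and the labels list are hypothetical):

```python
import torch

# Hypothetical NSP logits for a batch of three sentence pairs.
nsp_logits = torch.tensor([[2.1, -0.3],
                           [-1.0, 0.7],
                           [0.2, 0.1]])

labels = ["NOT next sentence", "next sentence"]  # class index 0 / class index 1
predicted = nsp_logits.argmax(dim=1)

for idx in predicted.tolist():
    print(idx, labels[idx])
```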

The forward method is straightforward. First we compute the embeddings, then pass them to our encoder.

embedded = self.embedding(input_tensor)  
encoded = self.encoder(embedded, attention_mask)

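Then we compute the model outputs: log-probabilities over the vocabulary for every token, and the NSP scores computed from the first token of the sequence.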

token_predictions = self.token_prediction_layer(encoded)  
 
first_word = encoded[:, 0, :]
return self.softmax(token_predictions), self.classification_layer(first_word)
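To see what shapes come out of the two heads, here is a small standalone sketch. It replaces the real encoder output with a random tensor, so it says nothing about prediction quality; the sizes are assumptions chosen to keep the example tiny.

```python
import torch
from torch import nn

torch.manual_seed(0)

batch, seq_len, dim_inp, vocab_size = 2, 6, 64, 100  # illustrative sizes

# Stand-in for the encoder output; in the model it comes from self.encoder.
encoded = torch.randn(batch, seq_len, dim_inp)

token_prediction_layer = nn.Linear(dim_inp, vocab_size)
log_softmax = nn.LogSoftmax(dim=-1)
classification_layer = nn.Linear(dim_inp, 2)

token_log_probs = log_softmax(token_prediction_layer(encoded))
first_word = encoded[:, 0, :]  # representation of the first token
nsp_logits = classification_layer(first_word)

print(token_log_probs.shape)  # (batch, seq_len, vocab_size)
print(nsp_logits.shape)       # (batch, 2)

# Log-probabilities exponentiate to a distribution over the vocabulary.
prob_sums = token_log_probs.exp().sum(dim=-1)
print(prob_sums)
```

Note that the token head returns log-probabilities while the classification head returns raw scores; this matters for the choice of loss functions later.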

The full model graph is also available. To build it, run the graph.py script, which saves the graph to the data/logs directory. Then run TensorBoard:

tensorboard --logdir data/logs

Open http://localhost:6006 in the browser and go to the Graphs tab. You should see the graph of our BERT model.

graph

Train the model

All training operations live in the BertTrainer class of the bert.trainer module. Let’s take a look at the class constructor.

class BertTrainer:  
  
    def __init__(self,  
                 model: BERT,  
                 dataset: IMDBBertDataset,  
                 log_dir: Path,  
                 checkpoint_dir: Path = None,  
                 print_progress_every: int = 10,  
                 print_accuracy_every: int = 50,  
                 batch_size: int = 24,  
                 learning_rate: float = 0.005,  
                 epochs: int = 5,  
                 ):  
        self.model = model  
        self.dataset = dataset  
  
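        self.batch_size = batch_size  
        self.epochs = epochs  
        self.current_epoch = 0  
        # Logging intervals used by train() below
        self._print_every = print_progress_every  
        self._accuracy_every = print_accuracy_every  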
  
        self.loader = DataLoader(self.dataset, batch_size=self.batch_size, shuffle=True)  
  
        self.writer = SummaryWriter(str(log_dir))  
        self.checkpoint_dir = checkpoint_dir  
  
        self.criterion = nn.BCEWithLogitsLoss().to(device)  
        self.ml_criterion = nn.NLLLoss(ignore_index=0).to(device)  
        self.optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=0.015)

As you can see, we define the model to train, the data loader to use, and the log writer. We use TensorBoard to log the training progress. Read Visualizing Models, Data, and Training with TensorBoard to learn more about using TensorBoard with PyTorch.

The most important parts of the constructor are the loss and optimizer definitions.

self.criterion = nn.BCEWithLogitsLoss().to(device)  
self.ml_criterion = nn.NLLLoss(ignore_index=0).to(device)

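To train the NSP task we use binary cross-entropy with a built-in sigmoid (BCEWithLogitsLoss), so the classification layer can output raw scores. To train MLM we use negative log-likelihood (NLLLoss), which expects log-probabilities; that is why the model applies LogSoftmax to the token predictions. Setting ignore_index=0 excludes positions whose target id is 0 from the loss. We use the Adam optimizer with weight decay.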
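Here is a minimal sketch of how the two criteria behave on tiny hand-made tensors. All numbers are made up; the point is only to show the expected input shapes and the effect of ignore_index.

```python
import torch
from torch import nn

torch.manual_seed(0)

# NSP: raw scores plus one-hot float targets of the same shape.
nsp_logits = torch.tensor([[1.5, -1.0], [-0.5, 0.8]])
nsp_target = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
loss_nsp = nn.BCEWithLogitsLoss()(nsp_logits, nsp_target)
# Same value as applying the sigmoid first and then plain BCE.
loss_nsp_manual = nn.BCELoss()(torch.sigmoid(nsp_logits), nsp_target)

# MLM: log-probabilities of shape (batch, seq, vocab). NLLLoss wants the
# class dimension second, hence the transpose to (batch, vocab, seq).
log_probs = torch.log_softmax(torch.randn(1, 4, 10), dim=-1)
token_target = torch.tensor([[0, 3, 0, 7]])  # 0 marks positions to ignore
loss_mlm = nn.NLLLoss(ignore_index=0)(log_probs.transpose(1, 2), token_target)
# Only positions 1 and 3 contribute to the mean.
loss_mlm_manual = -(log_probs[0, 1, 3] + log_probs[0, 3, 7]) / 2

print(loss_nsp.item(), loss_nsp_manual.item())
print(loss_mlm.item(), loss_mlm_manual.item())
```

The fused BCEWithLogitsLoss is generally preferred over a separate sigmoid followed by BCELoss because it is more numerically stable.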

The training process occurs in the train method.

def train(self, epoch: int):  
    print(f"Begin epoch {epoch}")  
  
    prev = time.time()  
    average_nsp_loss = 0  
    average_mlm_loss = 0  
    for i, value in enumerate(self.loader):  
        index = i + 1  
        inp, mask, inverse_token_mask, token_target, nsp_target = value  
        self.optimizer.zero_grad()  
  
        token, nsp = self.model(inp, mask)  
  
        tm = inverse_token_mask.unsqueeze(-1).expand_as(token)  
        token = token.masked_fill(tm, 0)  
  
        loss_token = self.ml_criterion(token.transpose(1, 2), token_target)
        loss_nsp = self.criterion(nsp, nsp_target)  
  
        loss = loss_token + loss_nsp  
        average_nsp_loss += loss_nsp  
        average_mlm_loss += loss_token  
  
        loss.backward()  
        self.optimizer.step()  
  
        if index % self._print_every == 0:  
            elapsed = time.gmtime(time.time() - prev)  
            s = self.training_summary(elapsed, index, average_nsp_loss, average_mlm_loss)  
  
            if index % self._accuracy_every == 0:  
                s += self.accuracy_summary(index, token, nsp, token_target, nsp_target)  
  
            print(s)  
  
            average_nsp_loss = 0  
            average_mlm_loss = 0  
    return loss

Let’s review the training step.

inp, mask, inverse_token_mask, token_target, nsp_target = value  
self.optimizer.zero_grad()  

The first thing we do is retrieve a batch from the loader. Then we reset the gradients to zero. A good explanation of why we call the zero_grad method is posted on StackOverflow:

In PyTorch, for every mini-batch during the training phase, we typically want to explicitly set the gradients to zero before starting to do backpropagation (i.e., updating the weights and biases) because PyTorch accumulates the gradients on subsequent backward passes. This accumulating behaviour is convenient while training RNNs or when we want to compute the gradient of the loss summed over multiple mini-batches. So, the default action has been set to accumulate (i.e. sum) the gradients on every loss.backward() call.
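The accumulation is easy to observe on a single scalar parameter. In this toy sketch (not part of the trainer), the gradient of 3 * w with respect to w is always 3, so any other value in w.grad comes from accumulation:

```python
import torch

w = torch.tensor(2.0, requires_grad=True)

(3 * w).backward()
first = w.grad.item()        # d(3w)/dw = 3

(3 * w).backward()           # no reset in between
accumulated = w.grad.item()  # gradients were summed: 3 + 3

w.grad.zero_()               # reset by hand; optimizer.zero_grad() plays
                             # this role for all parameters it manages
(3 * w).backward()
after_reset = w.grad.item()

print(first, accumulated, after_reset)
```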

Run the forward pass.

token, nsp = self.model(inp, mask)

Then we zero out the rows of the model output token that correspond to anything other than [MASK] tokens, because we train the model to predict only the [MASK] tokens.

tm = inverse_token_mask.unsqueeze(-1).expand_as(token)  
token = token.masked_fill(tm, 0)  
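A tiny sketch of what these two lines do to the tensor. I assume here that inverse_token_mask is a boolean tensor of shape (batch, sequence) with True at the positions that are not [MASK]; the sizes are made up.

```python
import torch

torch.manual_seed(0)

batch, seq_len, vocab_size = 1, 4, 5  # tiny illustrative sizes
token = torch.log_softmax(torch.randn(batch, seq_len, vocab_size), dim=-1)

# Assumption: True marks positions that are NOT [MASK] and get blanked out.
inverse_token_mask = torch.tensor([[True, False, True, False]])

# (1, 4) -> (1, 4, 1) -> broadcast view of shape (1, 4, 5)
tm = inverse_token_mask.unsqueeze(-1).expand_as(token)
masked = token.masked_fill(tm, 0)

print(tm.shape)
print(masked)
```

Rows 0 and 2 become all zeros, while rows 1 and 3 keep their log-probabilities. masked_fill returns a new tensor, so the original token is left untouched.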

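Next we compute the two losses and sum them. NLLLoss expects the class dimension to come second, so we transpose the token output from (batch, sequence, vocabulary) to (batch, vocabulary, sequence).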

loss_token = self.ml_criterion(token.transpose(1, 2), token_target)
loss_nsp = self.criterion(nsp, nsp_target)  
  
loss = loss_token + loss_nsp  
average_nsp_loss += loss_nsp  
average_mlm_loss += loss_token  

Run the backward pass and update the weights.

loss.backward()  
self.optimizer.step()  

From time to time, we calculate the model’s accuracy:

if index % self._accuracy_every == 0:  
    s += self.accuracy_summary(index, token, nsp, token_target, nsp_target) 

The summary reports the MLM and NSP accuracies:

nsp_acc = nsp_accuracy(nsp, nsp_target)  
token_acc = token_accuracy(token, token_target, inverse_token_mask)

For NSP we compute the fraction of samples in the batch that were predicted correctly.

def nsp_accuracy(result: torch.Tensor, target: torch.Tensor):
    s = (result.argmax(1) == target.argmax(1)).sum()  
    return round(float(s / result.size(0)), 2)

For MLM we need an extra step: apply the mask to the model output and to the target. In the code, we simply select the masked tokens and compare them.

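def token_accuracy(result: torch.Tensor, target: torch.Tensor, inverse_token_mask: torch.Tensor):
    # Keep only the masked positions (~inverse_token_mask) in predictions and targets
    r = result.argmax(-1).masked_select(~inverse_token_mask)  
    t = target.masked_select(~inverse_token_mask)  
    s = (r == t).sum()  
    # Note: normalized by all batch * sequence positions, not only the masked ones
    return round(float(s / (result.size(0) * result.size(1))), 2)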
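Here is a small worked example of both functions on hand-made tensors. The last lines compute a hypothetical variant that divides by the number of masked positions instead; it is not what the tutorial code does, and is shown only to make the difference in normalization visible.

```python
import torch

def nsp_accuracy(result, target):
    s = (result.argmax(1) == target.argmax(1)).sum()
    return round(float(s / result.size(0)), 2)

def token_accuracy(result, target, inverse_token_mask):
    r = result.argmax(-1).masked_select(~inverse_token_mask)
    t = target.masked_select(~inverse_token_mask)
    s = (r == t).sum()
    return round(float(s / (result.size(0) * result.size(1))), 2)

# NSP: 3 of 4 pairs are classified correctly.
nsp_out = torch.tensor([[2.0, 0.1], [0.3, 1.2], [1.1, 0.4], [0.2, 0.9]])
nsp_tgt = torch.tensor([[1.0, 0.0], [0.0, 1.0], [0.0, 1.0], [0.0, 1.0]])
nsp_acc = nsp_accuracy(nsp_out, nsp_tgt)

# MLM: one sequence of 4 positions; positions 1 and 3 are masked.
inverse_token_mask = torch.tensor([[True, False, True, False]])
token_tgt = torch.tensor([[0, 2, 0, 1]])
token_out = torch.zeros(1, 4, 3)
token_out[0, 1, 2] = 1.0  # position 1 predicts id 2 (correct)
token_out[0, 3, 0] = 1.0  # position 3 predicts id 0 (wrong)
token_acc = token_accuracy(token_out, token_tgt, inverse_token_mask)

# Hypothetical variant: normalize by the number of masked positions.
r = token_out.argmax(-1).masked_select(~inverse_token_mask)
t = token_tgt.masked_select(~inverse_token_mask)
token_acc_masked_only = round(float((r == t).sum() / r.numel()), 2)

print(nsp_acc, token_acc, token_acc_masked_only)
```

One of the two masked tokens is predicted correctly, yet token_accuracy reports one correct token out of four positions. With this normalization the reported value cannot exceed the fraction of masked positions.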

Training Results and Summary

Finally, we are ready to train our model. Long story short: open the main.py script, check the training parameters, and run it.

I trained the model on an NVIDIA GeForce GTX 1050 Ti GPU. If CUDA is available, the model is trained on the GPU by default. The following model parameters were used:

EMB_SIZE = 64  
HIDDEN_SIZE = 36  
EPOCHS = 4  
BATCH_SIZE = 12  
NUM_HEADS = 4

The embedding size is 64, the hidden attention context size is 36, the batch size is 12, the number of attention heads is 4, and the number of encoders is 1. The learning rate is 7e-5.
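The "GPU by default" behaviour is usually implemented with a one-line device check. The sketch below shows that common pattern next to the parameters above; the constant name LEARNING_RATE is a hypothetical name used here for the stated learning rate, and the actual main.py may organize this differently.

```python
import torch

EMB_SIZE = 64
HIDDEN_SIZE = 36
EPOCHS = 4
BATCH_SIZE = 12
NUM_HEADS = 4
LEARNING_RATE = 7e-5  # hypothetical constant name, value taken from the text

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Training on: {device}")

# Tensors (and modules) have to be moved to the chosen device explicitly.
sample = torch.zeros(BATCH_SIZE, EMB_SIZE).to(device)
print(sample.device.type)
```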

We use TensorBoard to track the training process.

After you run the training script, you should see how it prepares the IMDB dataset

Prepare dataset
Create vocabulary
100%|██████████| 491161/491161 [00:05<00:00, 93957.36it/s]
Preprocessing dataset
100%|██████████| 50000/50000 [00:35<00:00, 1407.99it/s]

Then the trainer prints the model summary

Model Summary

===================================
Device: cuda
Training dataset len: 882322
Max / Optimal sentence len: 27
Vocab size: 71942
Batch size: 12
Batched dataset len: 73526
===================================

And the training begins

Begin epoch 0
00:00:02 | Epoch 1 | 20 / 73526 (0.03%) | NSP loss   0.72 | MLM loss  11.25
00:00:04 | Epoch 1 | 40 / 73526 (0.05%) | NSP loss   0.70 | MLM loss  11.22
00:00:06 | Epoch 1 | 60 / 73526 (0.08%) | NSP loss   0.70 | MLM loss  11.13
00:00:08 | Epoch 1 | 80 / 73526 (0.11%) | NSP loss   0.71 | MLM loss  11.13
00:00:11 | Epoch 1 | 100 / 73526 (0.14%) | NSP loss   0.69 | MLM loss  11.05
00:00:13 | Epoch 1 | 120 / 73526 (0.16%) | NSP loss   0.70 | MLM loss  10.98
00:00:15 | Epoch 1 | 140 / 73526 (0.19%) | NSP loss   0.69 | MLM loss  10.95
00:00:18 | Epoch 1 | 160 / 73526 (0.22%) | NSP loss   0.70 | MLM loss  10.90
00:00:20 | Epoch 1 | 180 / 73526 (0.24%) | NSP loss   0.71 | MLM loss  10.89
00:00:22 | Epoch 1 | 200 / 73526 (0.27%) | NSP loss   0.72 | MLM loss  10.83 | NSP accuracy 0.25 | Token accuracy 0.01
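A quick sanity check helps to read these starting values. A model that guesses uniformly over the vocabulary has a cross-entropy of ln(vocab size), and a 50/50 guess on a binary target has a cross-entropy of ln(2). Treat the sketch below as a rough baseline, not as an exact prediction of the logged values:

```python
import math

vocab_size = 71942  # "Vocab size" from the model summary above

# Cross-entropy of a uniform guess over the vocabulary (MLM baseline).
mlm_baseline = math.log(vocab_size)

# Binary cross-entropy of a 50/50 guess (NSP baseline).
nsp_baseline = math.log(2)

print(f"MLM baseline: {mlm_baseline:.2f}")  # 11.18
print(f"NSP baseline: {nsp_baseline:.2f}")  # 0.69
```

The initial MLM loss of about 11.2 is therefore what an untrained model would be expected to produce, and an NSP loss hovering around 0.69 suggests the model is still close to guessing on that task.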

BERT, and even our heavily simplified version of it, converges slowly and requires a lot of computational resources. I was able to train for only one epoch, which took a little more than two hours:

02:20:49 | Epoch 1 | 73440 / 73526 (99.88%) | NSP loss   0.69 | MLM loss   4.49
02:20:52 | Epoch 1 | 73460 / 73526 (99.91%) | NSP loss   0.69 | MLM loss   4.37
02:20:54 | Epoch 1 | 73480 / 73526 (99.94%) | NSP loss   0.69 | MLM loss   4.24
02:20:56 | Epoch 1 | 73500 / 73526 (99.96%) | NSP loss   0.69 | MLM loss   4.38
02:20:59 | Epoch 1 | 73520 / 73526 (99.99%) | NSP loss   0.70 | MLM loss   4.37

Let’s take a look at how the loss changed over time.

mlm_loss
Change of MLM loss with time.

You can see that our model’s MLM loss does converge toward some minimum, but the process is really slow. For example, here is the log message at about 44% of the data processed:

01:03:01 | Epoch 1 | 32880 / 73526 (44.72%) | NSP loss   0.69 | MLM loss   4.78

And here is the message at almost 100% of the data processed:

02:20:59 | Epoch 1 | 73520 / 73526 (99.99%) | NSP loss   0.70 | MLM loss   4.37

In more than an hour of training, the MLM loss dropped by only about 0.4 (from 4.78 to 4.37), while the NSP loss barely moved.

nsp_loss
Change of NSP loss with time.

Looking at the chart above, you might say that the NSP loss does not converge at all and even diverges. It does converge, just even more slowly than the MLM loss. We can see this if we smooth the values in the chart:

nsp_smoothed_loss
Smoothed change of NSP loss with time.
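The article does not say which smoothing was applied to the chart. One common choice for noisy training curves is an exponential moving average, so here is a minimal sketch of that idea; the function name ema_smooth and the weight 0.9 are assumptions made for this example.

```python
def ema_smooth(values, weight=0.9):
    """Exponential moving average; a higher weight means stronger smoothing."""
    smoothed = []
    last = values[0]
    for v in values:
        last = weight * last + (1 - weight) * v
        smoothed.append(last)
    return smoothed

# NSP losses taken from the first log lines of the training output.
nsp_losses = [0.72, 0.70, 0.70, 0.71, 0.69, 0.70, 0.69, 0.70, 0.71, 0.72]
smoothed = ema_smooth(nsp_losses)

print([round(s, 3) for s in smoothed])
```

Each smoothed point mixes the previous smoothed value with the new measurement, so single noisy steps move the curve only a little and a slow trend becomes easier to spot.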

I would attribute this result to our dataset. We train on IMDB reviews and split the text into sentences on the . character. Take a look at what the resulting sentences are like: many of them are short and carry little information, so it is hard for the model to pick up enough signal to solve this task. The original BERT was trained on English Wikipedia and BooksCorpus, which contain well-formed, long, and informative sentences.

Let’s take a look at how the training accuracies change over time.

mlm_accuracy
Change of MLM training accuracy.

nsp_accuracy
Change of NSP training accuracy.

The accuracy tracks the loss: as the MLM loss slowly decreases, the MLM accuracy slowly improves. The NSP accuracy is even noisier and, after the first epoch, is only slightly above 0.5 on average. The conclusion is that we definitely need to try a different dataset. Still, these are decent results for a tutorial :)

The model built in this tutorial is not the full BERT. At best, it is a simplified version that is useful only for understanding the architecture and how it works. Many pre-trained BERT models (and its variants) are available from Hugging Face. By now, you should understand how to build BERT from scratch (with PyTorch, of course). As a next step, you can try different datasets and model parameters to see whether they give better results, especially for the convergence of the NSP task.
