Python - NLP with BERT

Python
NLP
2021
Author

Cliff Weaver

Published

March 27, 2021

Sentiment Analysis Using BERT

Objective

In this notebook, we will fine-tune a pre-trained BERT model to perform sentiment analysis on Twitter data.

Setup

Using Google Colab for training

Google Colab offers free GPUs and TPUs! Since we are going to train a large neural network, it’s best to take advantage of the GPU/TPU (in this case we’ll attach a GPU), otherwise training will take a very long time.

A GPU can be added by going to the menu and selecting:

Runtime -> Change runtime type -> Hardware accelerator -> GPU

This version of the document did not leverage Colab.

We will identify and specify the GPU as the device. Later, in our training loop, we will load data onto the device.

Code
#import torch library
import torch

# check GPU availability; fall back to CPU if no GPU is present
if torch.cuda.is_available():    
    # select GPU    
    device = torch.device("cuda")
else:
    # select CPU
    device = torch.device("cpu")

device
device(type='cuda')
Code
# check GPU name
torch.cuda.get_device_name(0)
'Tesla T4'

Installing Hugging Face’s Transformers Library

Hugging Face 🤗 is one of the most popular Natural Language Processing communities for deep learning researchers, hands-on practitioners, and educators. It provides state-of-the-art architectures for everyone.

The Transformers library (formerly known as pytorch-transformers) provides a wide range of general-purpose architectures (BERT, GPT-2, RoBERTa, XLM, DistilBERT, etc.) for Natural Language Understanding (NLU) and Natural Language Generation (NLG), with pretrained models in 100+ languages and deep interoperability between TensorFlow 2.0 and PyTorch.

Code
#install hugging face transformers
!pip install transformers
Collecting transformers
  Downloading https://files.pythonhosted.org/packages/27/3c/91ed8f5c4e7ef3227b4119200fc0ed4b4fd965b1f0172021c25701087825/transformers-3.0.2-py3-none-any.whl (769kB)
     |████████████████████████████████| 778kB 5.5MB/s eta 0:00:01
Requirement already satisfied: packaging in /usr/local/lib/python3.6/dist-packages (from transformers) (20.4)
Requirement already satisfied: dataclasses; python_version < "3.7" in /usr/local/lib/python3.6/dist-packages (from transformers) (0.7)
Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.6/dist-packages (from transformers) (2019.12.20)
Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from transformers) (1.18.5)
Requirement already satisfied: requests in /usr/local/lib/python3.6/dist-packages (from transformers) (2.23.0)
Collecting sentencepiece!=0.1.92
  Downloading https://files.pythonhosted.org/packages/d4/a4/d0a884c4300004a78cca907a6ff9a5e9fe4f090f5d95ab341c53d28cbc58/sentencepiece-0.1.91-cp36-cp36m-manylinux1_x86_64.whl (1.1MB)
     |████████████████████████████████| 1.1MB 34.9MB/s 
Requirement already satisfied: filelock in /usr/local/lib/python3.6/dist-packages (from transformers) (3.0.12)
Collecting sacremoses
  Downloading https://files.pythonhosted.org/packages/7d/34/09d19aff26edcc8eb2a01bed8e98f13a1537005d31e95233fd48216eed10/sacremoses-0.0.43.tar.gz (883kB)
     |████████████████████████████████| 890kB 43.0MB/s 
Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.6/dist-packages (from transformers) (4.41.1)
Collecting tokenizers==0.8.1.rc1
  Downloading https://files.pythonhosted.org/packages/40/d0/30d5f8d221a0ed981a186c8eb986ce1c94e3a6e87f994eae9f4aa5250217/tokenizers-0.8.1rc1-cp36-cp36m-manylinux1_x86_64.whl (3.0MB)
     |████████████████████████████████| 3.0MB 53.3MB/s 
Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from packaging->transformers) (1.15.0)
Requirement already satisfied: pyparsing>=2.0.2 in /usr/local/lib/python3.6/dist-packages (from packaging->transformers) (2.4.7)
Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (3.0.4)
Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (1.24.3)
Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (2020.6.20)
Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (2.10)
Requirement already satisfied: click in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers) (7.1.2)
Requirement already satisfied: joblib in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers) (0.16.0)
Building wheels for collected packages: sacremoses
  Building wheel for sacremoses (setup.py) ... done
  Created wheel for sacremoses: filename=sacremoses-0.0.43-cp36-none-any.whl size=893257 sha256=96a7895687fabdc99603d071c0e4c96c12b8d225891507146da1692156aef9f6
  Stored in directory: /root/.cache/pip/wheels/29/3c/fd/7ce5c3f0666dab31a50123635e6fb5e19ceb42ce38d4e58f45
Successfully built sacremoses
Installing collected packages: sentencepiece, sacremoses, tokenizers, transformers
Successfully installed sacremoses-0.0.43 sentencepiece-0.1.91 tokenizers-0.8.1rc1 transformers-3.0.2

Loading & Understanding BERT

Download Pretrained BERT model

We will use the uncased pre-trained version of the BERT base model. It was trained on lower-cased English text.

You can find more pre-trained models here: https://huggingface.co/transformers/pretrained_models.html

Code
from transformers import BertModel

# download bert pretrained model
bert = BertModel.from_pretrained('bert-base-uncased')
Code
# print bert architecture
print(bert)
BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (intermediate): BertIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
        )
        (output): BertOutput(
          (dense): Linear(in_features=3072, out_features=768, bias=True)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (1): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (intermediate): BertIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
        )
        (output): BertOutput(
          (dense): Linear(in_features=3072, out_features=768, bias=True)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (2): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (intermediate): BertIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
        )
        (output): BertOutput(
          (dense): Linear(in_features=3072, out_features=768, bias=True)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (3): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (intermediate): BertIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
        )
        (output): BertOutput(
          (dense): Linear(in_features=3072, out_features=768, bias=True)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (4): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (intermediate): BertIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
        )
        (output): BertOutput(
          (dense): Linear(in_features=3072, out_features=768, bias=True)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (5): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (intermediate): BertIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
        )
        (output): BertOutput(
          (dense): Linear(in_features=3072, out_features=768, bias=True)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (6): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (intermediate): BertIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
        )
        (output): BertOutput(
          (dense): Linear(in_features=3072, out_features=768, bias=True)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (7): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (intermediate): BertIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
        )
        (output): BertOutput(
          (dense): Linear(in_features=3072, out_features=768, bias=True)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (8): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (intermediate): BertIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
        )
        (output): BertOutput(
          (dense): Linear(in_features=3072, out_features=768, bias=True)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (9): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (intermediate): BertIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
        )
        (output): BertOutput(
          (dense): Linear(in_features=3072, out_features=768, bias=True)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (10): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (intermediate): BertIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
        )
        (output): BertOutput(
          (dense): Linear(in_features=3072, out_features=768, bias=True)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (11): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (intermediate): BertIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
        )
        (output): BertOutput(
          (dense): Linear(in_features=3072, out_features=768, bias=True)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
  )
  (pooler): BertPooler(
    (dense): Linear(in_features=768, out_features=768, bias=True)
    (activation): Tanh()
  )
)
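Given how long this printout is, a quicker sanity check on model scale is to count the parameters; bert-base has roughly 110 million. A minimal sketch:

Code
# count the parameters of the pretrained model (~110M for bert-base)
num_params = sum(p.numel() for p in bert.parameters())
print("Number of parameters: {:,}".format(num_params))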

Tokenization and Input Formatting

Download BERT Tokenizer

Code
#importing fast "BERT" tokenizer
from transformers import BertTokenizerFast

# Load BERT tokenizer
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased', do_lower_case=True)

Steps Followed for Input Formatting

  1. Tokenization

  2. Special tokens

  • Prepend the [CLS] token to the start of the sequence.
  • Append the [SEP] token to the end of the sequence.

  3. Pad sequences

  4. Convert tokens to integers

  5. Create attention masks to avoid pad tokens

Code
#input text
text = "Jim Henson was a puppeteer"

sent_id = tokenizer.encode(text, 
                           # add [CLS] and [SEP] tokens
                           add_special_tokens=True,
                           # specify maximum length for the sequences                                  
                           max_length = 10,
                           truncation = True,
                           # pad shorter sequences up to max_length
                           # (pad tokens are added to the right side by default)
                           pad_to_max_length=True)
                           
# print integer sequence
print("Integer Sequence: {}".format(sent_id))
Integer Sequence: [101, 3958, 27227, 2001, 1037, 13997, 11510, 102, 0, 0]
Code
# convert integers back to text
print("Tokenized Text:",tokenizer.convert_ids_to_tokens(sent_id))
Tokenized Text: ['[CLS]', 'jim', 'henson', 'was', 'a', 'puppet', '##eer', '[SEP]', '[PAD]', '[PAD]']
Code
# decode the tokenized text
decoded = tokenizer.decode(sent_id)
print("Decoded String: {}".format(decoded))
Decoded String: [CLS] jim henson was a puppeteer [SEP] [PAD] [PAD]
Code
# mask to avoid performing attention on padding token indices. 
# mask values: 1 for tokens that are NOT MASKED, 0 for MASKED tokens.   
att_mask = [int(tok > 0) for tok in sent_id]

print("Attention Mask:",att_mask)
Attention Mask: [1, 1, 1, 1, 1, 1, 1, 1, 0, 0]
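
As a side note, the tokenizer can build the attention mask for us in the same call. A minimal sketch using encode_plus from this version of the library:

Code
# encode_plus returns both the integer sequence and the attention mask
encoded = tokenizer.encode_plus(text,
                                add_special_tokens=True,
                                max_length=10,
                                truncation=True,
                                pad_to_max_length=True,
                                return_attention_mask=True)

print("Integer Sequence:", encoded['input_ids'])
print("Attention Mask:", encoded['attention_mask'])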

Understanding Input and Output

Code
# convert lists to tensors
sent_id = torch.tensor(sent_id)
att_mask = torch.tensor(att_mask)

# reshaping tensor in form of (batch,text length)
sent_id = sent_id.unsqueeze(0)
att_mask = att_mask.unsqueeze(0)

# reshaped tensor
print(sent_id)
tensor([[  101,  3958, 27227,  2001,  1037, 13997, 11510,   102,     0,     0]])
Code
# pass integer sequence to bert model
outputs = bert(sent_id, attention_mask=att_mask)  
Code
## unpack the output of the bert model

# hidden states of every token at the last layer
all_hidden_states = outputs[0]
# pooled output: the [CLS] hidden state passed through a Linear + Tanh layer
cls_hidden_state = outputs[1]

print("Shape of last hidden states:",all_hidden_states.shape)
print("Shape of CLS hidden state:",cls_hidden_state.shape)
Shape of last hidden states: torch.Size([1, 10, 768])
Shape of CLS hidden state: torch.Size([1, 768])
Code
cls_hidden_state
tensor([[-0.8767, -0.4109, -0.1220,  0.4494,  0.1945, -0.2698,  0.8316,  0.3127,
          0.1178, -1.0000, -0.1561,  0.6677,  0.9891, -0.3451,  0.8812, -0.6753,
         -0.3079, -0.5580,  0.4380, -0.4588,  0.5831,  0.9956,  0.4467,  0.2863,
          0.3924,  0.6863, -0.7513,  0.9043,  0.9436,  0.8207, -0.6493,  0.3524,
         -0.9919, -0.2295, -0.0742, -0.9936,  0.3698, -0.7558,  0.0792, -0.2218,
         -0.8637,  0.4711,  0.9997, -0.4368,  0.0404, -0.3498, -1.0000,  0.2663,
         -0.8711,  0.0508,  0.0505, -0.1635,  0.1716,  0.4363,  0.4330, -0.0333,
         -0.0416,  0.2206, -0.2568, -0.6122, -0.5916,  0.2569, -0.2622, -0.9041,
          0.3221, -0.2394, -0.2634, -0.3454, -0.0723,  0.0081,  0.8297,  0.2279,
          0.1614, -0.6555, -0.2062,  0.3280, -0.4016,  1.0000, -0.0952, -0.9874,
         -0.0401,  0.0717,  0.3675,  0.3373, -0.3710, -1.0000,  0.4479, -0.1722,
         -0.9917,  0.2677,  0.4844, -0.2207, -0.3207,  0.3715, -0.2171, -0.2522,
         -0.3071, -0.3161, -0.1988, -0.0860, -0.0114, -0.1982, -0.1799, -0.3221,
          0.1751, -0.4442, -0.1570, -0.0434, -0.0893,  0.5717,  0.3112, -0.2900,
          0.3305, -0.9430,  0.6061, -0.2984, -0.9873, -0.3956, -0.9926,  0.7857,
         -0.1692, -0.2719,  0.9505,  0.5628,  0.2904, -0.1693,  0.1619, -1.0000,
         -0.1696, -0.1534,  0.2513, -0.2857, -0.9846, -0.9638,  0.5565,  0.9200,
          0.1805,  0.9995, -0.2122,  0.9391,  0.3246, -0.3937, -0.1248, -0.5209,
          0.0519,  0.1141, -0.6463,  0.3529, -0.0322, -0.3837, -0.3796, -0.2830,
          0.1280, -0.9191, -0.4201,  0.9145,  0.0713, -0.2455,  0.5212, -0.2642,
         -0.3675,  0.8082,  0.2577,  0.2755, -0.0157,  0.3675, -0.3107,  0.4502,
         -0.8224,  0.2841,  0.4360, -0.3193,  0.2165, -0.9851, -0.4444,  0.5759,
          0.9878,  0.7531,  0.3384,  0.2003, -0.2602,  0.4695, -0.9561,  0.9855,
         -0.1712,  0.2295,  0.1220, -0.1386, -0.8436, -0.3783,  0.8371, -0.3204,
         -0.8457, -0.0473, -0.4219, -0.3593, -0.2186,  0.5282, -0.3149, -0.4375,
         -0.0440,  0.9242,  0.9296,  0.7735, -0.3733,  0.3945, -0.9049, -0.2898,
          0.2695,  0.2910,  0.1695,  0.9932, -0.3069, -0.1611, -0.8349, -0.9827,
          0.1299, -0.8555, -0.0531, -0.6830,  0.3926,  0.2873, -0.1899,  0.2598,
         -0.9201, -0.7455,  0.3943, -0.3955,  0.4015, -0.2341,  0.7593,  0.3421,
         -0.6143,  0.5170,  0.8987,  0.1072, -0.6858,  0.6481, -0.2454,  0.8712,
         -0.5958,  0.9936,  0.3404,  0.4972, -0.9452, -0.2347, -0.8748, -0.0154,
         -0.1293, -0.5265,  0.4235,  0.4206,  0.3663,  0.7488, -0.4650,  0.9900,
         -0.8695, -0.9701, -0.5203, -0.0900, -0.9914,  0.0978,  0.2844, -0.0424,
         -0.4649, -0.4546, -0.9620,  0.8035,  0.2177,  0.9705, -0.0793, -0.7985,
         -0.3436, -0.9537, -0.0035, -0.0945,  0.4291,  0.0391, -0.9602,  0.4497,
          0.5135,  0.4913,  0.0608,  0.9948,  1.0000,  0.9810,  0.8865,  0.7961,
         -0.9894, -0.5122,  1.0000, -0.8521, -1.0000, -0.9412, -0.6633,  0.3110,
         -1.0000, -0.1468, -0.1235, -0.9465, -0.0891,  0.9796,  0.9700, -1.0000,
          0.9324,  0.9259, -0.4503,  0.4591, -0.1785,  0.9819,  0.2285,  0.4423,
         -0.2615,  0.4124, -0.5252, -0.8534,  0.0365, -0.0670,  0.8944,  0.1913,
         -0.4782, -0.9402,  0.2293, -0.1581, -0.2440, -0.9604, -0.1924, -0.0555,
          0.5484,  0.1915,  0.2038, -0.7367,  0.2698, -0.7307,  0.3715,  0.5640,
         -0.9386, -0.5717,  0.3818, -0.2775,  0.1536, -0.9608,  0.9702, -0.3502,
          0.1524,  1.0000,  0.3876, -0.9001,  0.2547,  0.1857,  0.0832,  1.0000,
          0.3811, -0.9852, -0.4053,  0.2576, -0.3923, -0.4125,  0.9994, -0.1463,
         -0.0428,  0.2818,  0.9899, -0.9923,  0.8351, -0.8563, -0.9634,  0.9617,
          0.9268, -0.4224, -0.7369,  0.1318,  0.1107,  0.2294, -0.8914,  0.6082,
          0.4665, -0.0720,  0.8555, -0.7973, -0.3478,  0.4201, -0.1762,  0.0761,
          0.2823,  0.4571, -0.1350,  0.1190, -0.3509, -0.4039, -0.9556,  0.0262,
          1.0000, -0.2164,  0.0569, -0.2296, -0.1003, -0.1827,  0.4036,  0.4715,
         -0.3293, -0.8471, -0.0518, -0.8453, -0.9935,  0.6732,  0.2284, -0.1968,
          0.9998,  0.5194,  0.2326,  0.1718,  0.7497, -0.0192,  0.4518, -0.0327,
          0.9765, -0.3259,  0.3491,  0.7471, -0.3186, -0.3019, -0.5725,  0.0563,
         -0.9206,  0.0572, -0.9589,  0.9565,  0.3109,  0.3348,  0.1635, -0.0619,
          1.0000, -0.6020,  0.5309, -0.3723,  0.6636, -0.9851, -0.6789, -0.4312,
         -0.1435, -0.0827, -0.2497,  0.1323, -0.9786, -0.0474, -0.0304, -0.9444,
         -0.9927,  0.2508,  0.6172,  0.1679, -0.7980, -0.6078, -0.4906,  0.4646,
         -0.1934, -0.9396,  0.5453, -0.3000,  0.4329, -0.3340,  0.4408, -0.2058,
          0.8344,  0.1265, -0.0307, -0.2098, -0.8340,  0.7114, -0.7410,  0.0518,
         -0.1481,  1.0000, -0.3100,  0.1461,  0.7011,  0.6334, -0.2857,  0.1618,
          0.0966,  0.2955, -0.0981, -0.1832, -0.6208, -0.3013,  0.4337,  0.0283,
         -0.2959,  0.7579,  0.4711,  0.3666, -0.0531,  0.0914,  0.9969, -0.2267,
         -0.1165, -0.5533, -0.1262, -0.3575, -0.2124,  1.0000,  0.3679,  0.0604,
         -0.9936, -0.1999, -0.9208,  0.9999,  0.8511, -0.8783,  0.5650,  0.2405,
         -0.2859,  0.6935, -0.2598, -0.2655,  0.2893,  0.2862,  0.9774, -0.4575,
         -0.9764, -0.5964,  0.3966, -0.9575,  0.9939, -0.5326, -0.2349, -0.4376,
         -0.0250,  0.2574,  0.0274, -0.9762, -0.1582,  0.1821,  0.9811,  0.3014,
         -0.3820, -0.9007, -0.1151,  0.3936, -0.0680, -0.9449,  0.9809, -0.9313,
          0.2600,  1.0000,  0.3860, -0.5243,  0.2401, -0.4410,  0.3253, -0.1412,
          0.5428, -0.9466, -0.2817, -0.3262,  0.4330, -0.2120, -0.2457,  0.7247,
          0.2134, -0.3430, -0.6305, -0.1214,  0.4871,  0.7498, -0.2957, -0.1829,
          0.1699, -0.1391, -0.9264, -0.4167, -0.2995, -0.9991,  0.6411, -1.0000,
         -0.1510, -0.5473, -0.2219,  0.8075,  0.3862, -0.1392, -0.7206, -0.0710,
          0.6995,  0.6656, -0.2889,  0.2902, -0.6951,  0.1622, -0.1298,  0.3182,
          0.1694,  0.6526, -0.2735,  1.0000,  0.1370, -0.3043, -0.9189,  0.3041,
         -0.2604,  1.0000, -0.7969, -0.9715,  0.2110, -0.5773, -0.7218,  0.2477,
         -0.0304, -0.7015, -0.6577,  0.9111,  0.8219, -0.3693,  0.4537, -0.3062,
         -0.3671,  0.0856,  0.1595,  0.9903,  0.2790,  0.8213, -0.2885, -0.0723,
          0.9636,  0.2213,  0.6892,  0.2070,  1.0000,  0.3249, -0.8999,  0.2644,
         -0.9700, -0.2610, -0.9228,  0.4016,  0.1170,  0.8570, -0.3587,  0.9672,
          0.0667,  0.1108, -0.1840,  0.4711,  0.3127, -0.9391, -0.9892, -0.9908,
          0.3962, -0.5013, -0.0640,  0.3811,  0.1530,  0.4712,  0.3781, -1.0000,
          0.9466,  0.3529,  0.2077,  0.9735,  0.2019,  0.4726,  0.4248, -0.9892,
         -0.9203, -0.3418, -0.2910,  0.6572,  0.5584,  0.8190,  0.4319, -0.4171,
         -0.4697,  0.4653, -0.8583, -0.9940,  0.4802,  0.0740, -0.8986,  0.9559,
         -0.4745, -0.1616,  0.4457,  0.1412,  0.8933,  0.8280,  0.4313,  0.2437,
          0.6787,  0.9043,  0.8940,  0.9903, -0.2561,  0.6986, -0.0055,  0.3281,
          0.6809, -0.9586,  0.1583,  0.0033, -0.2711,  0.3025, -0.1928, -0.9207,
          0.5260, -0.2139,  0.5709, -0.2302,  0.1593, -0.4779, -0.1577, -0.7036,
         -0.5208,  0.4676,  0.2335,  0.9372,  0.4775, -0.1995, -0.5655, -0.2336,
          0.0798, -0.9315,  0.8288, -0.0946,  0.5294,  0.0223, -0.0744,  0.7821,
          0.1236, -0.3705, -0.3958, -0.7528,  0.8145, -0.3204, -0.4786, -0.5135,
          0.7306,  0.3208,  0.9981, -0.3959, -0.3492, -0.1118, -0.2872,  0.3596,
         -0.1345, -1.0000,  0.2896,  0.2262,  0.1702, -0.3530,  0.1111, -0.0755,
         -0.9565, -0.2658,  0.2530, -0.0490, -0.5834, -0.4616,  0.3937,  0.2329,
          0.5620,  0.8138, -0.0288,  0.5621,  0.3811,  0.0852, -0.6049,  0.8452]],
       grad_fn=<TanhBackward>)
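
Note that outputs[1] is the pooled output: the [CLS] hidden state passed through the BertPooler (Linear + Tanh) shown in the architecture printout, not the raw [CLS] vector. To get the raw [CLS] hidden state instead, slice it from the last hidden states:

Code
# raw [CLS] vector = hidden state at the first timestep of the last layer
raw_cls = all_hidden_states[:, 0, :]
print("Shape of raw CLS hidden state:", raw_cls.shape)  # torch.Size([1, 768])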

Preparing Data

Loading and Reading the Twitter Airline Data

Code
# upload the data and extract the file
!unzip 'Airline_Tweets.zip'
Archive:  Airline_Tweets.zip
  inflating: Tweets.csv              
Code
import pandas as pd

# increase the output column width
pd.set_option('display.max_colwidth', 200)

# read CSV file
df = pd.read_csv('Tweets.csv')

# print first 5 rows
df.head()
tweet_id airline_sentiment airline_sentiment_confidence negativereason negativereason_confidence airline airline_sentiment_gold name negativereason_gold retweet_count text tweet_coord tweet_created tweet_location user_timezone
0 570306133677760513 neutral 1.0000 NaN NaN Virgin America NaN cairdin NaN 0 @VirginAmerica What @dhepburn said. NaN 2015-02-24 11:35:52 -0800 NaN Eastern Time (US & Canada)
1 570301130888122368 positive 0.3486 NaN 0.0000 Virgin America NaN jnardino NaN 0 @VirginAmerica plus you've added commercials to the experience... tacky. NaN 2015-02-24 11:15:59 -0800 NaN Pacific Time (US & Canada)
2 570301083672813571 neutral 0.6837 NaN NaN Virgin America NaN yvonnalynn NaN 0 @VirginAmerica I didn't today... Must mean I need to take another trip! NaN 2015-02-24 11:15:48 -0800 Lets Play Central Time (US & Canada)
3 570301031407624196 negative 1.0000 Bad Flight 0.7033 Virgin America NaN jnardino NaN 0 @VirginAmerica it's really aggressive to blast obnoxious "entertainment" in your guests' faces &amp; they have little recourse NaN 2015-02-24 11:15:36 -0800 NaN Pacific Time (US & Canada)
4 570300817074462722 negative 1.0000 Can't Tell 1.0000 Virgin America NaN jnardino NaN 0 @VirginAmerica and it's a really big bad thing about it NaN 2015-02-24 11:14:45 -0800 NaN Pacific Time (US & Canada)
Code
#shape of the dataframe
df.shape
(14640, 15)
Code
df['text'].sample(5)
4837                                                                                @SouthwestAir how do I get a companion pass
2874                                                           @united Looks like they came through. Thanks again for the help.
8393                               @JetBlue no, but we're on the flight leaving from Boston to Seattle right now. :) flight 597
1624    “@united: @rikrik__ What made you come to this? Can we help you with anything? ^JP” the service just hasn't been great.
6882                                                                  @JetBlue - Definitely no note from whoever stole from me.
Name: text, dtype: object
Code
# class distribution
df['airline_sentiment'].value_counts()
negative    9178
neutral     3099
positive    2363
Name: airline_sentiment, dtype: int64
Code
# class distribution
df['airline_sentiment'].value_counts(normalize = True)
negative    0.626913
neutral     0.211680
positive    0.161407
Name: airline_sentiment, dtype: float64
Code
# saving the value counts to a list
class_counts = df['airline_sentiment'].value_counts().tolist()

Text Cleaning

Code
#library for pattern matching
import re

#define a function for text cleaning
def preprocessor(text):
  
  #converting text to lower case
  text = text.lower()

  #remove user mentions (Twitter handles may also include underscores)
  text = re.sub(r'@[A-Za-z0-9_]+','',text)           
  
  #remove hashtags
  #text = re.sub(r'#[A-Za-z0-9]+','',text)         
  
  #remove links
  text = re.sub(r'http\S+', '', text)  
  
  #split token to remove extra spaces
  tokens = text.split()
  
  #join tokens by space
  return " ".join(tokens)
Code
# perform text cleaning
df['clean_text']= df['text'].apply(preprocessor)
Code
# save cleaned text and labels to a variable
text   = df['clean_text'].values
labels = df['airline_sentiment'].values
Code
#cleaned text
text[50:55]
array(['is flight 769 on it\'s way? was supposed to take off 30 minutes ago. website still shows "on time" not "in flight". thanks.',
       'julie andrews all the way though was very impressive! no to',
       'wish you flew out of atlanta... soon?',
       'julie andrews. hands down.',
       'will flights be leaving dallas for la on february 24th?'],
      dtype=object)

Preparing Input and Output Data

Preparing Output

Code
#importing label encoder
from sklearn.preprocessing import LabelEncoder

#define label encoder
le = LabelEncoder()

#fit and transform target strings to numbers
labels = le.fit_transform(labels)
Code
#classes
le.classes_
array(['negative', 'neutral', 'positive'], dtype=object)
Code
labels
array([1, 2, 1, ..., 1, 0, 1])
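
To make the encoding explicit, we can print the mapping the label encoder has learned:

Code
# show the label -> integer mapping produced by LabelEncoder
print(dict(zip(le.classes_, le.transform(le.classes_))))
# {'negative': 0, 'neutral': 1, 'positive': 2}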

Preparing Input Data

Code
# library for visualization
import matplotlib.pyplot as plt

# compute no. of words in each tweet
num = [len(i.split()) for i in text]

plt.hist(num, bins = 30)

plt.title("Histogram: Length of sentences")
Text(0.5, 1.0, 'Histogram: Length of sentences')

Code
# define maximum length of a text (based on the histogram above)
max_len = 25
Code
# library for progress bar
from tqdm import notebook

# create an empty list to save integer sequence
sent_id = []

# iterate over each tweet
for i in notebook.tqdm(range(len(text))):
  
  encoded_sent = tokenizer.encode(text[i],                      
                                  add_special_tokens = True,    
                                  max_length = max_len,
                                  truncation = True,         
                                  # pad to max_len (on the right by default)
                                  pad_to_max_length=True)    
  
  # saving integer sequence to a list
  sent_id.append(encoded_sent)
Code
print("Integer Sequence:",sent_id[0])
Integer Sequence: [101, 2054, 2056, 1012, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
Code
# create attention masks
attention_masks = []

# for each sentence...
for sent in sent_id:
  att_mask = [int(token_id > 0) for token_id in sent]
  
  # store the attention mask for this sentence.
  attention_masks.append(att_mask)

Training and Validation Data

Code
# Use train_test_split to split our data into train and validation sets
from sklearn.model_selection import train_test_split

# Use 90% for training and 10% for validation.
train_inputs, validation_inputs, train_labels, validation_labels = train_test_split(sent_id, labels, random_state=2018, test_size=0.1, stratify=labels)

# Do the same for the masks. Reusing the same random_state and stratify
# arguments reproduces the identical split, keeping masks aligned with the inputs.
train_masks, validation_masks, _, _ = train_test_split(attention_masks, labels, random_state=2018, test_size=0.1, stratify=labels)

Define Dataloaders

Code
# Convert all inputs and labels into torch tensors, the required datatype for our model.
train_inputs = torch.tensor(train_inputs)
validation_inputs = torch.tensor(validation_inputs)

train_labels = torch.tensor(train_labels)
validation_labels = torch.tensor(validation_labels)

train_masks = torch.tensor(train_masks)
validation_masks = torch.tensor(validation_masks)
Code
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler

# The DataLoader needs to know our batch size for training, so we specify it here.
# For fine-tuning BERT on a specific task, the authors recommend a batch size of 16 or 32.

#define a batch size
batch_size = 32

# Create the DataLoader for our training set.
#Dataset wrapping tensors.
train_data = TensorDataset(train_inputs, train_masks, train_labels)

#define a sampler for sampling the data during training
#  random sampler samples randomly from a dataset
#  sequential sampler samples sequentially, always in the same order
train_sampler = RandomSampler(train_data)

#represents an iterator over a dataset; supports batching and a customized data loading order
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size)

# Create the DataLoader for our validation set.
#Dataset wrapping tensors.
validation_data = TensorDataset(validation_inputs, validation_masks, validation_labels)

#define a sequential sampler 
#This samples data in a sequential order
validation_sampler = SequentialSampler(validation_data)

#create an iterator over the dataset
validation_dataloader = DataLoader(validation_data, sampler=validation_sampler, batch_size=batch_size)
Code
#create an iterator object
iterator = iter(train_dataloader)

#load one batch of data
sent_id, mask, target = next(iterator)
Code
sent_id.shape
torch.Size([32, 25])
Code
sent_id
tensor([[  101,  5064,  2090,  1040,  2546,  2860,  1998,  8764,  1045,  2288,
         19030,  2013,  2260,  2497,  2035,  1996,  2126,  2000,  4601,  2290,
          2006, 20304,  2475,  1029,   102],
        [  101,  2129,  2055,  2070,  4086, 19430,  3642,  2000,  2393,  2033,
          3942,  2026,  2155,   999,   999,   999,  1012,  1012,  1012,  1012,
          3531,   999,   999,   999,   102],
        [  101,  7987, 27225,  1523,  1024,  2735,  2091,  2005,  2054,  1012,
          1001,  1058,  2595,  3736,  7959,  3723, 25514,  1524,   102,     0,
             0,     0,     0,     0,     0],
        [  101,  4931,  4364,  1010,  2339,  2106,  2026,  2197,  3462,  7796,
          2033,  1014, 19637,  1029,   102,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0],
        [  101,  3357,  1015,  1024,  2022,  2625,  2915,  1012,  3357,  1016,
          1024, 13399,  6304,  2060,  3182,  2084, 10474,  1012,  3357,  1017,
          1024,  2123,  1005,  1056,   102],
        [  101,  6146,  2581,  3823,  2013,  7921,  2012,  6921,  3199,  1999,
          2258,  2286,  1001, 20704, 18372,  2243,   102,     0,     0,     0,
             0,     0,     0,     0,     0],
        [  101,  3403,  2012,  6522,  2078,  2006,  4946,  2062,  2084,  2019,
          3178,  2144,  4899,  1038,  1013,  1039,  1997,  2053,  4796,  1010,
          3531,  2131,  2149,  2125,   102],
        [  101,  2024,  2017, 14763,  2005,  3462, 26727,  2157,  2085,   102,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0],
        [  101,  2003,  1996,  4037,  2091,  1029,   102,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0],
        [  101,  1045,  1005,  1049,  2667,  2000,  4638,  2046,  2026,  2184,
          1024,  2753,  3286, 14931,  3462,  1056,  7382,  2006,  1996, 15363,
          4037,  1998,  2009,  1005,   102],
        [  101,  1040,  7583,  1996,  2171,   102,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0],
        [  101,  1523,  1024,  3376,  2915,  1012,  1012,  4283,  2005,  6631,
          1012,  2478,  1001,  4875,  8873,  2000,  2695,  1029,  1025,  1007,
          1524,  2115,  6160,   999,   102],
        [  101,  2058,  1016,  2847,  1998,  2403,  2781,  2006,  2907,  1998,
         10320,  1012,  4283,  2005,  1996,  2393,   999,  1001,  6304,  2121,
          7903,  2063,  1001, 12077,   102],
        [  101,  1045,  2052,  2293,  2000,  2175,  2000,  1996,  5865,  2265,
          1625,   102,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0],
        [  101,  2057,  1005,  2310,  1006,  3462,  4008, 23777,  1007,  2042,
          3564,  2006,  1996, 16985, 22911,  1997,  5887,  2050,  2005,  1037,
          2096,  2085,  3403,  2006,   102],
        [  101, 15536,  8873,  2347,  1005,  1056,  2551, 27120,  1012, 22333,
         16742,  2128, 22278,  1012,  2017,  2741,  2033,  2000,  1037,  3309,
          2005,  2484,  2847,  2007,   102],
        [  101,  2293,  2129,  2017,  2064,  1005,  1056,  2131,  2019,  4005,
          2006,  1996,  3042,  1998,  1996, 12978,  2291, 17991,  2039,  2006,
          2017,   102,     0,     0,     0],
        [  101,  1045,  4741,  2006,  2907,  2005,  2048,  2847,  1010,  2069,
          2000,  2031,  2026,  2655,  1012,  2428, 23579,  1012,   102,     0,
             0,     0,     0,     0,     0],
        [  101,  2748,  1012,  1045,  6618,  2019,  3178,  2001,  1037,  2146,
          2438,  2051,  2000,  2907,  2077,  3228,  2039,  1012,  2064,  8307,
          2655,  2033,  1029,   102,     0],
        [  101,  1011, 10885,  2420,  1024,  7882,  1010, 15544,  5753,  1999,
         10975,  2005,  1996,  2305,  1024,  1002,  6352,  1010,  3974,  1037,
          2154,  2000, 28781,  1998,   102],
        [  101, 13970, 12269,  2005,  2025,  8014,  3462,  2989,  7599,  2013,
          1040,  2546,  2860,  2023,  2851,  1012,  2142,  2788,  2034,  2000,
          6634,  1012,  1012,  1012,   102],
        [  101,  2044,  2035,  1045,  2031,  2042,  2083,  2006,  2023,  4440,
          1010,  2064,  2017,  2131,  2033,  2006,  2178,  8582,  2188,  1029,
           102,     0,     0,     0,     0],
        [  101,  4931,  2045,  2033,  2153,  2013,  7483, 10047,  2145,  2006,
          2907,   102,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0],
        [  101,  2045,  2003,  2242,  3308,  2007,  2017,  4037,  1999, 23591,
         18059,   102,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0],
        [  101,  2003,  1996,  2190,  2126,  2000,  2128,  1011, 15908,  2033,
          2007,  2026,  2028,  2995,  2293,  1010,  6023,  1999,  3915,  1005,
          1055,  4827,  3007,  1001,   102],
        [  101,  4223,  2003,  2200,  4129,   102,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0],
        [  101,  1040,  2386,  6914,  1012,  4012,   102,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0],
        [  101,  2009,  1005,  1055,  9145,  2108,  2006,  2907,  2000,  2017,
           999,   999,  2428,  2423,  2781,  2000,  2689,  1037,  3462,   999,
           999,   102,     0,     0,     0],
        [  101,  6228,  3277,  1012,  3504,  2066,  2027,  2288,  2009,  4964,
           999,  4283,  2005,  2115,  5142,  1012,   102,     0,     0,     0,
             0,     0,     0,     0,     0],
        [  101,  2092,  2049,  2340,  1024,  3429,  3286,  1998,  2074,  2288,
          2019, 10373,  2008,  2026,  2340,  3286,  3462,  2003,  8394,  1011,
          2008,  2015,  2025,  2157,   102],
        [  101,  4067,  2017,  2005, 14120,  1012,  1012,  1012,  2026, 12191,
          2003,  1999,  2026,  4524,  1998,  1045,  2342,  2009,  2005,  2147,
          1012,  1045,  2572,  5191,   102],
        [  101,  7864,  1010,  5327,  2005,  4248,  3247,  1010,  5079,  2000,
          2178, 13445,  2057,  2074,  2363,  1024,  6378,  1011, 15045,  3296,
         16122,  8917,  1011,  5391,   102]])
Code
#pass inputs to the model
outputs = bert(sent_id,             #integer sequence
               attention_mask=mask) #attention masks
Code
# hidden states
hidden_states = outputs[0]

# pooled [CLS] hidden state (BertPooler output)
CLS_hidden_state = outputs[1]

print("Shape of Hidden States:",hidden_states.shape)
print("Shape of CLS Hidden State:",CLS_hidden_state.shape)
Shape of Hidden States: torch.Size([32, 25, 768])
Shape of CLS Hidden State: torch.Size([32, 768])

Model Fine Tuning

The pretrained model is trained on a general-domain corpus, so fine-tuning it helps capture the domain-specific features of our custom dataset.

Every pretrained model consists of 2 components: a backbone and a head.

  • Backbone refers to the pretrained model architecture.
  • Head refers to the dense layer(s) added on top of the backbone. Generally, this layer is used for the classification task.

Hence, we can fine-tune the pretrained model in 2 ways:

  1. Fine-tuning only the head (or dense layer)
     1.1 CLS token
     1.2 Hidden states

  2. Fine-tuning both backbone & head
     2.1 CLS token
     2.2 Hidden states

Approach: Fine-Tuning Only Head

As the name suggests, in this approach, we freeze the backbone and train only the head or dense layer.

Steps to Follow

  1. Turn off Gradients
  2. Define Model Architecture
  3. Define Optimizer and Loss
  4. Define Train and Evaluate
  5. Train the model
  6. Evaluate the model
Code
# turn off the gradient of all the parameters
for param in bert.parameters():
    param.requires_grad = False

Define Model Architecture

Code
#importing nn module
import torch.nn as nn

class classifier(nn.Module):

    #define the layers and wrappers used by the model
    def __init__(self, bert):
        
        #constructor
        super(classifier, self).__init__()
        
        #bert model
        self.bert = bert 
        
        # dense layer 1
        self.fc1 = nn.Linear(768,512)
        
        #dense layer 2 (Output layer)
        self.fc2 = nn.Linear(512,3)
        
        #dropout layer
        self.dropout = nn.Dropout(0.1)
        
        #relu activation function
        self.relu =  nn.ReLU()
        
        #softmax activation function
        self.softmax = nn.LogSoftmax(dim=1)
        
    #define the forward pass
    def forward(self, sent_id, mask):
        
        #pass the inputs to the model  
        all_hidden_states, cls_hidden_state = self.bert(sent_id, attention_mask=mask)
        
        #pass CLS hidden state to dense layer
        x = self.fc1(cls_hidden_state)
        
        #Apply ReLU activation function
        x = self.relu(x)
        
        #Apply Dropout
        x = self.dropout(x)
        
        #pass input to the output layer
        x = self.fc2(x)
        
        #apply softmax activation
        x = self.softmax(x)
        
        return x
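
Because the forward pass ends with LogSoftmax, the model returns log-probabilities, and the matching loss is negative log-likelihood. A minimal sketch of a class-weighted loss, assuming the class_counts list saved earlier (its descending-frequency order happens to coincide with the encoder's negative/neutral/positive order here):

Code
# inverse-frequency class weights to counter the imbalanced labels
weights = torch.tensor([1.0 / c for c in class_counts], dtype=torch.float)
weights = weights / weights.sum()

# NLLLoss expects log-probabilities, which LogSoftmax provides
loss_fn = nn.NLLLoss(weight=weights.to(device))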
Code
#create the model
model = classifier(bert)

#push the model to GPU, if available
model = model.to(device)
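
With the backbone frozen, only the two dense layers remain trainable; a quick check confirms it:

Code
# count trainable parameters: fc1 (768*512 + 512) + fc2 (512*3 + 3) = 395,267
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Trainable parameters: {:,}".format(trainable))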
Code
#model architecture 
model
classifier(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )
          (intermediate): BertIntermediate(
            (dense): Linear(in_features=768, out_features=3072, bias=True)
          )
          (output): BertOutput(
            (dense): Linear(in_features=3072, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (1): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )
          (intermediate): BertIntermediate(
            (dense): Linear(in_features=768, out_features=3072, bias=True)
          )
          (output): BertOutput(
            (dense): Linear(in_features=3072, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (2): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )
          (intermediate): BertIntermediate(
            (dense): Linear(in_features=768, out_features=3072, bias=True)
          )
          (output): BertOutput(
            (dense): Linear(in_features=3072, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (3): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )
          (intermediate): BertIntermediate(
            (dense): Linear(in_features=768, out_features=3072, bias=True)
          )
          (output): BertOutput(
            (dense): Linear(in_features=3072, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (4): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )
          (intermediate): BertIntermediate(
            (dense): Linear(in_features=768, out_features=3072, bias=True)
          )
          (output): BertOutput(
            (dense): Linear(in_features=3072, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (5): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )
          (intermediate): BertIntermediate(
            (dense): Linear(in_features=768, out_features=3072, bias=True)
          )
          (output): BertOutput(
            (dense): Linear(in_features=3072, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (6): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )
          (intermediate): BertIntermediate(
            (dense): Linear(in_features=768, out_features=3072, bias=True)
          )
          (output): BertOutput(
            (dense): Linear(in_features=3072, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (7): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )
          (intermediate): BertIntermediate(
            (dense): Linear(in_features=768, out_features=3072, bias=True)
          )
          (output): BertOutput(
            (dense): Linear(in_features=3072, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (8): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )
          (intermediate): BertIntermediate(
            (dense): Linear(in_features=768, out_features=3072, bias=True)
          )
          (output): BertOutput(
            (dense): Linear(in_features=3072, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (9): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )
          (intermediate): BertIntermediate(
            (dense): Linear(in_features=768, out_features=3072, bias=True)
          )
          (output): BertOutput(
            (dense): Linear(in_features=3072, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (10): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )
          (intermediate): BertIntermediate(
            (dense): Linear(in_features=768, out_features=3072, bias=True)
          )
          (output): BertOutput(
            (dense): Linear(in_features=3072, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (11): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )
          (intermediate): BertIntermediate(
            (dense): Linear(in_features=768, out_features=3072, bias=True)
          )
          (output): BertOutput(
            (dense): Linear(in_features=3072, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
      )
    )
    (pooler): BertPooler(
      (dense): Linear(in_features=768, out_features=768, bias=True)
      (activation): Tanh()
    )
  )
  (fc1): Linear(in_features=768, out_features=512, bias=True)
  (fc2): Linear(in_features=512, out_features=3, bias=True)
  (dropout): Dropout(p=0.1, inplace=False)
  (relu): ReLU()
  (softmax): LogSoftmax(dim=1)
)
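The tail of this printout shows the custom classification head stacked on top of BERT's pooler: fc1 (768 → 512), ReLU, dropout, fc2 (512 → 3), and a final LogSoftmax. The class defining this head appears earlier in the notebook; as a reference, here is a minimal sketch of how those modules would be wired in the forward pass (the class and variable names here are illustrative assumptions, not the notebook's own):

Code
# a minimal sketch (assumed names) of the head inferred from the printout
import torch.nn as nn

class SketchHead(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(768, 512)
        self.fc2 = nn.Linear(512, 3)
        self.dropout = nn.Dropout(0.1)
        self.relu = nn.ReLU()
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, cls_hs):          # cls_hs: pooled 768-d [CLS] vector
        x = self.relu(self.fc1(cls_hs)) # 768 -> 512
        x = self.dropout(x)
        return self.softmax(self.fc2(x))  # 512 -> 3 log-probabilities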
Code
# push the tensors to GPU
sent_id = sent_id.to(device)
mask = mask.to(device)
target = target.to(device)

# pass inputs to the model
outputs = model(sent_id, mask)
Code
# understand outputs
print(outputs)
tensor([[-0.8960, -1.3895, -1.0712],
        [-0.8843, -1.2947, -1.1615],
        [-0.9099, -1.2691, -1.1509],
        [-0.8514, -1.2819, -1.2186],
        [-0.9650, -1.2425, -1.1076],
        [-0.9353, -1.3501, -1.0546],
        [-0.8627, -1.3478, -1.1452],
        [-0.9753, -1.2643, -1.0774],
        [-0.8849, -1.4208, -1.0620],
        [-0.9231, -1.2884, -1.1178],
        [-0.9613, -1.2478, -1.1073],
        [-0.9618, -1.2306, -1.1219],
        [-0.8743, -1.3402, -1.1361],
        [-0.9644, -1.2575, -1.0954],
        [-0.9228, -1.3100, -1.1003],
        [-0.8805, -1.2835, -1.1765],
        [-0.8361, -1.3504, -1.1793],
        [-0.8948, -1.3033, -1.1404],
        [-0.8801, -1.3966, -1.0853],
        [-0.9366, -1.2906, -1.0998],
        [-0.8680, -1.3155, -1.1652],
        [-0.8701, -1.2654, -1.2073],
        [-0.8920, -1.3223, -1.1281],
        [-0.8964, -1.3177, -1.1264],
        [-0.7724, -1.4183, -1.2175],
        [-0.9444, -1.3060, -1.0783],
        [-0.8942, -1.3986, -1.0668],
        [-0.9210, -1.2643, -1.1412],
        [-0.9566, -1.2440, -1.1161],
        [-0.9399, -1.2842, -1.1012],
        [-0.8473, -1.3528, -1.1617],
        [-0.9309, -1.3419, -1.0658]], device='cuda:0',
       grad_fn=<LogSoftmaxBackward>)
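Because the network ends in LogSoftmax, each row above holds log-probabilities rather than probabilities. A quick sketch (not in the original cell) to recover the probabilities, which should sum to 1 per row:

Code
# exponentiate the log-probabilities to get probabilities (rows sum to ~1)
probs = torch.exp(outputs)
print(probs.sum(dim=1))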
Code
# no. of trainable parameters
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'The model has {count_parameters(model):,} trainable parameters')
The model has 395,267 trainable parameters
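This count is exactly the classifier head: fc1 has 768 × 512 + 512 = 393,728 parameters and fc2 has 512 × 3 + 3 = 1,539, for a total of 395,267 — confirming that the BERT encoder's weights are frozen and only the new head will be trained.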

Define Optimizer and Loss function

Code
# Adam optimizer
optimizer = torch.optim.Adam(model.parameters(), lr = 0.001)
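A learning rate of 0.001 is reasonable here because only the small classifier head is being optimized. If you were instead to unfreeze and fine-tune the full encoder, a much smaller rate with AdamW is the usual choice — a sketch, not used in this notebook:

Code
# hedged alternative for full fine-tuning (not used above): AdamW with a
# transformer-scale learning rate and weight decay
from torch.optim import AdamW
# optimizer = AdamW(model.parameters(), lr=2e-5, weight_decay=0.01)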
Code
# understand the class distribution
keys=['0','1','2']

# set figure size
plt.figure(figsize=(5,5))

# plot bar chart
plt.bar(keys,class_counts)

# set title
plt.title('Class Distribution')
Text(0.5, 1.0, 'Class Distribution')

Code
#library for array processing
import numpy as np

#library for computing class weights
from sklearn.utils.class_weight import compute_class_weight

#compute the class weights (keyword arguments are required in newer sklearn)
class_weights = compute_class_weight(class_weight='balanced', classes=np.unique(labels), y=labels)

print("Class Weights:",class_weights)
Class Weights: [0.53170625 1.57470152 2.06517139]
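Under the hood, 'balanced' assigns each class the weight n_samples / (n_classes × count_i), so rarer classes get proportionally larger weights. A minimal sketch reproducing the printed values by hand (assuming labels is the integer label array used above):

Code
# reproduce the 'balanced' heuristic manually:
# weight_i = n_samples / (n_classes * count_i)
counts = np.bincount(np.asarray(labels).astype(int))
print(len(labels) / (len(counts) * counts))   # should match class_weights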
Code
# convert the class-weight array to a tensor
weights = torch.tensor(class_weights, dtype=torch.float)

# transfer to GPU
weights = weights.to(device)

# define the loss function
cross_entropy  = nn.NLLLoss(weight=weights) 
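NLLLoss is the right pairing here because the model already applies LogSoftmax; the combination is mathematically identical to nn.CrossEntropyLoss applied to raw logits. A quick self-contained check:

Code
# LogSoftmax + NLLLoss == CrossEntropyLoss on raw logits
import torch
import torch.nn as nn

logits  = torch.randn(4, 3)
targets = torch.tensor([0, 2, 1, 0])
a = nn.NLLLoss()(nn.LogSoftmax(dim=1)(logits), targets)
b = nn.CrossEntropyLoss()(logits, targets)
print(torch.allclose(a, b))   # True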
Code
#compute the loss
loss = cross_entropy(outputs, target)
print("Loss:",loss)
Loss: tensor(1.1441, device='cuda:0', grad_fn=<NllLossBackward>)
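As a sanity check: an untrained classifier that spreads probability uniformly over 3 classes would score -ln(1/3) ≈ 1.099, so a starting loss of 1.1441 (nudged slightly by the class weights) is right where we would expect.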
Code
import time
import datetime

# compute time in hh:mm:ss
def format_time(elapsed):
    # round to the nearest second.
    elapsed_rounded = int(round((elapsed)))
    
    # format as hh:mm:ss
    return str(datetime.timedelta(seconds = elapsed_rounded))
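A quick usage example of the helper:

Code
# e.g. 3725.6 seconds rounds to 3726 s and formats as 1:02:06
print(format_time(3725.6))   # -> 1:02:06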

Model Training and Evaluation

The deep learning model is trained in epochs, where each epoch consists of several batches.

During training, for each batch, we need to

  1. Perform Forward Pass
  2. Compute Loss
  3. Backpropagate Loss
  4. Update Weights

Whereas during evaluation, for each batch, we need to

  1. Perform forward pass
  2. Compute loss
Training: Epoch -> Batch -> Forward Pass -> Compute loss -> Backpropagate loss -> Update weights 
Evaluation: Epoch -> Batch -> Forward Pass -> Compute loss

Hence, for each epoch, we have a training phase and a validation phase. In each phase, for every batch we need to:

Training phase

  1. Load data onto the GPU for acceleration
  2. Unpack our data inputs and labels
  3. Clear out the gradients calculated in the previous pass.
  4. Forward pass (feed input data through the network)
  5. Backward pass (backpropagation)
  6. Update parameters with optimizer.step()
  7. Track variables for monitoring progress
Code
#define a function for training the model
def train():
    print("\nTraining.....")

    #set the model to training mode - Dropout layers are activated
    model.train()

    #record the current time
    t0 = time.time()

    #initialize loss and accuracy to 0
    total_loss, total_accuracy = 0, 0

    #create an empty list to save the model predictions
    total_preds=[]

    #for every batch
    for step,batch in enumerate(train_dataloader):

        # Progress update after every 40 batches.
        if step % 40 == 0 and not step == 0:

            # Calculate elapsed time in minutes.
            elapsed = format_time(time.time() - t0)

            # Report progress.
            print('  Batch {:>5,}  of  {:>5,}.    Elapsed: {:}.'.format(step, len(train_dataloader), elapsed))

        #push the batch to gpu
        batch = tuple(t.to(device) for t in batch)

        #unpack the batch into separate variables
        # `batch` contains three pytorch tensors:
        #   [0]: input ids
        #   [1]: attention masks
        #   [2]: labels
        sent_id, mask, labels = batch

        # Always clear any previously calculated gradients before performing a
        # backward pass. PyTorch doesn't do this automatically.
        model.zero_grad()

        # Perform a forward pass. This returns the model predictions
        preds = model(sent_id, mask)

        #compute the loss between actual and predicted values
        #(NLLLoss expects integer class targets, hence .long())
        loss = cross_entropy(preds, labels.long())

        # Accumulate the training loss over all of the batches so that we can
        # calculate the average loss at the end. `loss` is a Tensor containing a
        # single value; the `.item()` function just returns the Python value
        # from the tensor.
        total_loss = total_loss + loss.item()

        # Perform a backward pass to calculate the gradients.
        loss.backward()

        # Update parameters and take a step using the computed gradient.
        # The optimizer dictates the "update rule"--how the parameters are
        # modified based on their gradients, the learning rate, etc.
        optimizer.step()

        #The model predictions are stored on the GPU, so push them to the CPU
        preds = preds.detach().cpu().numpy()

        #accumulate the model predictions of each batch
        total_preds.append(preds)

    #compute the training loss of an epoch
    avg_loss = total_loss / len(train_dataloader)

    #The predictions are in the form of (no. of batches, size of batch, no. of classes).
    #So, reshape the predictions into the form (number of samples, no. of classes)
    total_preds = np.concatenate(total_preds, axis=0)

    #return the loss and predictions
    return avg_loss, total_preds
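One common addition when fine-tuning transformers (not used in this notebook) is gradient clipping between loss.backward() and optimizer.step(), which guards against exploding gradients:

Code
# hedged sketch: cap the global gradient norm at 1.0 before the update step
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)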

Evaluation phase

  1. Load data onto the GPU for acceleration
  2. Unpack our data inputs and labels
  3. Forward pass (feed input data through the network)
  4. Compute loss on our validation data
  5. Track variables for monitoring progress
Code
#define a function for evaluating the model
def evaluate():
  
  print("\nEvaluating.....")
  
  #set the model to evaluation mode - Dropout layers are deactivated
  model.eval()

  #record the current time
  t0 = time.time()

  #initialize the loss and accuracy to 0
  total_loss, total_accuracy = 0, 0
  
  #create an empty list to save the model predictions
  total_preds = []

  #for each batch  
  for step,batch in enumerate(validation_dataloader):
    
    # Progress update every 40 batches.
    if step % 40 == 0 and not step == 0:
      
      # Calculate elapsed time in minutes.
      elapsed = format_time(time.time() - t0)
            
      # Report progress.
      print('  Batch {:>5,}  of  {:>5,}.    Elapsed: {:}.'.format(step, len(validation_dataloader), elapsed))

    #push the batch to gpu
    batch = tuple(t.to(device) for t in batch)

    #unpack the batch into separate variables
    # `batch` contains three pytorch tensors:
    #   [0]: input ids 
    #   [1]: attention masks
    #   [2]: labels        
    sent_id, mask, labels = batch

    #deactivates autograd
    with torch.no_grad():
      
      # Perform a forward pass. This returns the model predictions
      preds = model(sent_id, mask)

      #compute the validation loss between actual and predicted values
      loss = cross_entropy(preds, labels.long())

      # Accumulate the validation loss over all of the batches so that we can
      # calculate the average loss at the end. `loss` is a Tensor containing a
      # single value; the `.item()` function just returns the Python value 
      # from the tensor.      
      total_loss = total_loss + loss.item()

      #The model predictions are stored on the GPU, so push them to the CPU
      preds=preds.detach().cpu().numpy()

      #Accumulate the model predictions of each batch
      total_preds.append(preds)

  #compute the validation loss of an epoch
  avg_loss = total_loss / len(validation_dataloader) 

  #The predictions are in the form of (no. of batches, size of batch, no. of classes).
  #So, reshape the predictions into the form (number of samples, no. of classes)
  total_preds  = np.concatenate(total_preds, axis=0)

  return avg_loss, total_preds

Train the Model

Code
#assign the initial loss to infinity
best_valid_loss = float('inf')

#create empty lists to store the training and validation loss of each epoch
train_losses=[]
valid_losses=[]

epochs = 5

#for each epoch
for epoch in range(epochs):
     
    print('\n....... epoch {:} / {:} .......'.format(epoch + 1, epochs))
    
    #train model
    train_loss, _ = train()
    
    #evaluate model
    valid_loss, _ = evaluate()
    
    #save the best model
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'saved_weights.pt')
    
    #accumulate training and validation loss
    train_losses.append(train_loss)
    valid_losses.append(valid_loss)
    
    print(f'\nTraining Loss: {train_loss:.3f}')
    print(f'Validation Loss: {valid_loss:.3f}')

print("")
print("Training complete!")

....... epoch 1 / 5 .......

Training.....
  Batch    40  of    412.    Elapsed: 0:00:02.
  Batch    80  of    412.    Elapsed: 0:00:04.
  Batch   120  of    412.    Elapsed: 0:00:06.
  Batch   160  of    412.    Elapsed: 0:00:08.
  Batch   200  of    412.    Elapsed: 0:00:10.
  Batch   240  of    412.    Elapsed: 0:00:11.
  Batch   280  of    412.    Elapsed: 0:00:13.
  Batch   320  of    412.    Elapsed: 0:00:15.
  Batch   360  of    412.    Elapsed: 0:00:17.
  Batch   400  of    412.    Elapsed: 0:00:19.

Evaluating.....
  Batch    40  of     46.    Elapsed: 0:00:02.

Training Loss: 0.981
Validation Loss: 0.836

....... epoch 2 / 5 .......

Training.....
  Batch    40  of    412.    Elapsed: 0:00:02.
  Batch    80  of    412.    Elapsed: 0:00:04.
  Batch   120  of    412.    Elapsed: 0:00:06.
  Batch   160  of    412.    Elapsed: 0:00:08.
  Batch   200  of    412.    Elapsed: 0:00:10.
  Batch   240  of    412.    Elapsed: 0:00:12.
  Batch   280  of    412.    Elapsed: 0:00:14.
  Batch   320  of    412.    Elapsed: 0:00:16.
  Batch   360  of    412.    Elapsed: 0:00:18.
  Batch   400  of    412.    Elapsed: 0:00:20.

Evaluating.....
  Batch    40  of     46.    Elapsed: 0:00:02.

Training Loss: 0.851
Validation Loss: 0.715

....... epoch 3 / 5 .......

Training.....
  Batch    40  of    412.    Elapsed: 0:00:02.
  Batch    80  of    412.    Elapsed: 0:00:04.
  Batch   120  of    412.    Elapsed: 0:00:06.
  Batch   160  of    412.    Elapsed: 0:00:08.
  Batch   200  of    412.    Elapsed: 0:00:10.
  Batch   240  of    412.    Elapsed: 0:00:12.
  Batch   280  of    412.    Elapsed: 0:00:14.
  Batch   320  of    412.    Elapsed: 0:00:16.
  Batch   360  of    412.    Elapsed: 0:00:18.
  Batch   400  of    412.    Elapsed: 0:00:20.

Evaluating.....
  Batch    40  of     46.    Elapsed: 0:00:02.

Training Loss: 0.799
Validation Loss: 0.872

....... epoch 4 / 5 .......

Training.....
  Batch    40  of    412.    Elapsed: 0:00:02.
  Batch    80  of    412.    Elapsed: 0:00:04.
  Batch   120  of    412.    Elapsed: 0:00:06.
  Batch   160  of    412.    Elapsed: 0:00:08.
  Batch   200  of    412.    Elapsed: 0:00:10.
  Batch   240  of    412.    Elapsed: 0:00:12.
  Batch   280  of    412.    Elapsed: 0:00:14.
  Batch   320  of    412.    Elapsed: 0:00:16.
  Batch   360  of    412.    Elapsed: 0:00:18.
  Batch   400  of    412.    Elapsed: 0:00:20.

Evaluating.....
  Batch    40  of     46.    Elapsed: 0:00:02.

Training Loss: 0.780
Validation Loss: 0.679

....... epoch 5 / 5 .......

Training.....
  Batch    40  of    412.    Elapsed: 0:00:02.
  Batch    80  of    412.    Elapsed: 0:00:04.
  Batch   120  of    412.    Elapsed: 0:00:06.
  Batch   160  of    412.    Elapsed: 0:00:08.
  Batch   200  of    412.    Elapsed: 0:00:10.
  Batch   240  of    412.    Elapsed: 0:00:12.
  Batch   280  of    412.    Elapsed: 0:00:14.
  Batch   320  of    412.    Elapsed: 0:00:16.
  Batch   360  of    412.    Elapsed: 0:00:18.
  Batch   400  of    412.    Elapsed: 0:00:20.

Evaluating.....
  Batch    40  of     46.    Elapsed: 0:00:02.

Training Loss: 0.760
Validation Loss: 0.677

Training complete!

Model Evaluation

Code
# load weights of best model
path='saved_weights.pt'
model.load_state_dict(torch.load(path))
<All keys matched successfully>
Code
# get the model predictions on the validation data
# returns 2 elements- Validation loss and Predictions
valid_loss, preds = evaluate()
print(valid_loss)

Evaluating.....
  Batch    40  of     46.    Elapsed: 0:00:02.
0.6771649757157201
Code
# convert the log-probabilities into class predictions by taking the index
# of the maximum value (log is monotonic, so argmax is the same either way)
y_pred = np.argmax(preds,axis=1)

# actual labels
y_true = validation_labels
Code
from sklearn.metrics import classification_report
print(classification_report(y_true,y_pred))
              precision    recall  f1-score   support

           0       0.91      0.68      0.78       918
           1       0.51      0.65      0.57       310
           2       0.54      0.86      0.66       236

    accuracy                           0.70      1464
   macro avg       0.65      0.73      0.67      1464
weighted avg       0.76      0.70      0.72      1464
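As a follow-up diagnostic (a sketch beyond the original notebook), a confusion matrix shows where the remaining errors concentrate — that is, which classes get mistaken for which:

Code
# rows = true class, columns = predicted class
from sklearn.metrics import confusion_matrix
print(confusion_matrix(y_true, y_pred))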