#import torch library
import torch

# check GPU availability
if torch.cuda.is_available():
    # select GPU
    device = torch.device("cuda")

device
device(type='cuda')
Cliff Weaver
March 27, 2021
In this notebook, we will fine-tune a pre-trained BERT model to perform sentiment analysis on Twitter data.
Google Colab offers free GPUs and TPUs! Since we are going to train a large neural network, it’s best to take advantage of the GPU/TPU (in this case we’ll attach a GPU), otherwise training will take a very long time.
A GPU can be added by going to the menu and selecting:
Runtime -> Change runtime type -> Hardware accelerator: GPU
This version of the document did not leverage Colab.
We will identify and specify the GPU as the device. Later, in our training loop, we will load data onto the device.
device(type='cuda')
Hugging Face 🤗 is one of the most popular Natural Language Processing communities for deep learning researchers, hands-on practitioners, and educators. It provides state-of-the-art architectures for everyone.
The Transformers library (formerly known as pytorch-transformers) provides a wide range of general-purpose architectures (BERT, GPT-2, RoBERTa, XLM, DistilBert, etc) for Natural Language Understanding (NLU) and Natural Language Generation (NLG) with a wide range of pretrained models in 100+ languages and deep interoperability between TensorFlow 2.0 and PyTorch.
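The install log below was presumably produced by a command along these lines in a notebook cell (a sketch, not necessarily the exact command used):

# install the Hugging Face Transformers library from PyPI
!pip install transformers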
Collecting transformers
Downloading https://files.pythonhosted.org/packages/27/3c/91ed8f5c4e7ef3227b4119200fc0ed4b4fd965b1f0172021c25701087825/transformers-3.0.2-py3-none-any.whl (769kB)
|████████████████████████████████| 778kB 5.5MB/s eta 0:00:01
Requirement already satisfied: packaging in /usr/local/lib/python3.6/dist-packages (from transformers) (20.4)
Requirement already satisfied: dataclasses; python_version < "3.7" in /usr/local/lib/python3.6/dist-packages (from transformers) (0.7)
Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.6/dist-packages (from transformers) (2019.12.20)
Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from transformers) (1.18.5)
Requirement already satisfied: requests in /usr/local/lib/python3.6/dist-packages (from transformers) (2.23.0)
Collecting sentencepiece!=0.1.92
Downloading https://files.pythonhosted.org/packages/d4/a4/d0a884c4300004a78cca907a6ff9a5e9fe4f090f5d95ab341c53d28cbc58/sentencepiece-0.1.91-cp36-cp36m-manylinux1_x86_64.whl (1.1MB)
|████████████████████████████████| 1.1MB 34.9MB/s
Requirement already satisfied: filelock in /usr/local/lib/python3.6/dist-packages (from transformers) (3.0.12)
Collecting sacremoses
Downloading https://files.pythonhosted.org/packages/7d/34/09d19aff26edcc8eb2a01bed8e98f13a1537005d31e95233fd48216eed10/sacremoses-0.0.43.tar.gz (883kB)
|████████████████████████████████| 890kB 43.0MB/s
Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.6/dist-packages (from transformers) (4.41.1)
Collecting tokenizers==0.8.1.rc1
Downloading https://files.pythonhosted.org/packages/40/d0/30d5f8d221a0ed981a186c8eb986ce1c94e3a6e87f994eae9f4aa5250217/tokenizers-0.8.1rc1-cp36-cp36m-manylinux1_x86_64.whl (3.0MB)
|████████████████████████████████| 3.0MB 53.3MB/s
Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from packaging->transformers) (1.15.0)
Requirement already satisfied: pyparsing>=2.0.2 in /usr/local/lib/python3.6/dist-packages (from packaging->transformers) (2.4.7)
Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (3.0.4)
Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (1.24.3)
Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (2020.6.20)
Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (2.10)
Requirement already satisfied: click in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers) (7.1.2)
Requirement already satisfied: joblib in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers) (0.16.0)
Building wheels for collected packages: sacremoses
Building wheel for sacremoses (setup.py) ... done
Created wheel for sacremoses: filename=sacremoses-0.0.43-cp36-none-any.whl size=893257 sha256=96a7895687fabdc99603d071c0e4c96c12b8d225891507146da1692156aef9f6
Stored in directory: /root/.cache/pip/wheels/29/3c/fd/7ce5c3f0666dab31a50123635e6fb5e19ceb42ce38d4e58f45
Successfully built sacremoses
Installing collected packages: sentencepiece, sacremoses, tokenizers, transformers
Successfully installed sacremoses-0.0.43 sentencepiece-0.1.91 tokenizers-0.8.1rc1 transformers-3.0.2
We will use the uncased pre-trained version of the BERT base model. It was trained on lower-cased English text.
You can find more pre-trained models here: https://huggingface.co/transformers/pretrained_models.html
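Loading the pre-trained model itself is not shown; a minimal sketch using the standard Transformers API (the variable name bert is an assumption, chosen to match the classifier code later on):

from transformers import BertModel

# download the pre-trained uncased BERT base model
bert = BertModel.from_pretrained('bert-base-uncased')
bert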
BertModel(
(embeddings): BertEmbeddings(
(word_embeddings): Embedding(30522, 768, padding_idx=0)
(position_embeddings): Embedding(512, 768)
(token_type_embeddings): Embedding(2, 768)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(encoder): BertEncoder(
(layer): ModuleList(
(0): BertLayer(
(attention): BertAttention(
(self): BertSelfAttention(
(query): Linear(in_features=768, out_features=768, bias=True)
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(output): BertSelfOutput(
(dense): Linear(in_features=768, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(intermediate): BertIntermediate(
(dense): Linear(in_features=768, out_features=3072, bias=True)
)
(output): BertOutput(
(dense): Linear(in_features=3072, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(1): BertLayer(
(attention): BertAttention(
(self): BertSelfAttention(
(query): Linear(in_features=768, out_features=768, bias=True)
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(output): BertSelfOutput(
(dense): Linear(in_features=768, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(intermediate): BertIntermediate(
(dense): Linear(in_features=768, out_features=3072, bias=True)
)
(output): BertOutput(
(dense): Linear(in_features=3072, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(2): BertLayer(
(attention): BertAttention(
(self): BertSelfAttention(
(query): Linear(in_features=768, out_features=768, bias=True)
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(output): BertSelfOutput(
(dense): Linear(in_features=768, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(intermediate): BertIntermediate(
(dense): Linear(in_features=768, out_features=3072, bias=True)
)
(output): BertOutput(
(dense): Linear(in_features=3072, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(3): BertLayer(
(attention): BertAttention(
(self): BertSelfAttention(
(query): Linear(in_features=768, out_features=768, bias=True)
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(output): BertSelfOutput(
(dense): Linear(in_features=768, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(intermediate): BertIntermediate(
(dense): Linear(in_features=768, out_features=3072, bias=True)
)
(output): BertOutput(
(dense): Linear(in_features=3072, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(4): BertLayer(
(attention): BertAttention(
(self): BertSelfAttention(
(query): Linear(in_features=768, out_features=768, bias=True)
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(output): BertSelfOutput(
(dense): Linear(in_features=768, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(intermediate): BertIntermediate(
(dense): Linear(in_features=768, out_features=3072, bias=True)
)
(output): BertOutput(
(dense): Linear(in_features=3072, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(5): BertLayer(
(attention): BertAttention(
(self): BertSelfAttention(
(query): Linear(in_features=768, out_features=768, bias=True)
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(output): BertSelfOutput(
(dense): Linear(in_features=768, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(intermediate): BertIntermediate(
(dense): Linear(in_features=768, out_features=3072, bias=True)
)
(output): BertOutput(
(dense): Linear(in_features=3072, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(6): BertLayer(
(attention): BertAttention(
(self): BertSelfAttention(
(query): Linear(in_features=768, out_features=768, bias=True)
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(output): BertSelfOutput(
(dense): Linear(in_features=768, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(intermediate): BertIntermediate(
(dense): Linear(in_features=768, out_features=3072, bias=True)
)
(output): BertOutput(
(dense): Linear(in_features=3072, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(7): BertLayer(
(attention): BertAttention(
(self): BertSelfAttention(
(query): Linear(in_features=768, out_features=768, bias=True)
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(output): BertSelfOutput(
(dense): Linear(in_features=768, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(intermediate): BertIntermediate(
(dense): Linear(in_features=768, out_features=3072, bias=True)
)
(output): BertOutput(
(dense): Linear(in_features=3072, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(8): BertLayer(
(attention): BertAttention(
(self): BertSelfAttention(
(query): Linear(in_features=768, out_features=768, bias=True)
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(output): BertSelfOutput(
(dense): Linear(in_features=768, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(intermediate): BertIntermediate(
(dense): Linear(in_features=768, out_features=3072, bias=True)
)
(output): BertOutput(
(dense): Linear(in_features=3072, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(9): BertLayer(
(attention): BertAttention(
(self): BertSelfAttention(
(query): Linear(in_features=768, out_features=768, bias=True)
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(output): BertSelfOutput(
(dense): Linear(in_features=768, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(intermediate): BertIntermediate(
(dense): Linear(in_features=768, out_features=3072, bias=True)
)
(output): BertOutput(
(dense): Linear(in_features=3072, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(10): BertLayer(
(attention): BertAttention(
(self): BertSelfAttention(
(query): Linear(in_features=768, out_features=768, bias=True)
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(output): BertSelfOutput(
(dense): Linear(in_features=768, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(intermediate): BertIntermediate(
(dense): Linear(in_features=768, out_features=3072, bias=True)
)
(output): BertOutput(
(dense): Linear(in_features=3072, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(11): BertLayer(
(attention): BertAttention(
(self): BertSelfAttention(
(query): Linear(in_features=768, out_features=768, bias=True)
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(output): BertSelfOutput(
(dense): Linear(in_features=768, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(intermediate): BertIntermediate(
(dense): Linear(in_features=768, out_features=3072, bias=True)
)
(output): BertOutput(
(dense): Linear(in_features=3072, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
)
)
(pooler): BertPooler(
(dense): Linear(in_features=768, out_features=768, bias=True)
(activation): Tanh()
)
)
Download BERT Tokenizer
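The matching tokenizer was presumably downloaded along these lines (a sketch; the variable name tokenizer matches the encoding code below):

from transformers import BertTokenizer

# download the vocabulary and word-piece tokenizer that match bert-base-uncased
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')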
Steps Followed for Input Formatting
- Tokenization
- Special tokens: add a [CLS] token to the start of the sequence and a [SEP] token to the end of the sequence
- Pad sequences
- Convert tokens to integers
- Create attention masks to avoid attending to pad tokens
#input text
text = "Jim Henson was a puppeteer"

sent_id = tokenizer.encode(text,
                           # add [CLS] and [SEP] tokens
                           add_special_tokens=True,
                           # specify maximum length for the sequences
                           max_length=10,
                           truncation=True,
                           # add pad tokens to the right side of the sequence
                           pad_to_max_length='right')

# print integer sequence
print("Integer Sequence: {}".format(sent_id))
Integer Sequence: [101, 3958, 27227, 2001, 1037, 13997, 11510, 102, 0, 0]
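The tokenized text, decoded string, input tensor, and hidden-state shapes that follow were presumably produced roughly as follows (a sketch under the assumption that the model object is called bert; not the author's exact code):

# map the integer ids back to word-piece tokens and to a readable string
print("Tokenized Text: {}".format(tokenizer.convert_ids_to_tokens(sent_id)))
print("Decoded String: {}".format(tokenizer.decode(sent_id)))

# wrap the ids in a tensor with a batch dimension and run them through BERT;
# this version of transformers returns a tuple: (last hidden states, pooled [CLS] hidden state)
input_ids = torch.tensor([sent_id])
print(input_ids)

last_hidden_states, cls_hidden_state = bert(input_ids)
print("Shape of last hidden states: {}".format(last_hidden_states.shape))
print("Shape of CLS hidden state: {}".format(cls_hidden_state.shape))
cls_hidden_state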
Tokenized Text: ['[CLS]', 'jim', 'henson', 'was', 'a', 'puppet', '##eer', '[SEP]', '[PAD]', '[PAD]']
Decoded String: [CLS] jim henson was a puppeteer [SEP] [PAD] [PAD]
tensor([[ 101, 3958, 27227, 2001, 1037, 13997, 11510, 102, 0, 0]])
Shape of last hidden states: torch.Size([1, 10, 768])
Shape of CLS hidden state: torch.Size([1, 768])
tensor([[-0.8767, -0.4109, -0.1220, 0.4494, 0.1945, -0.2698, 0.8316, 0.3127,
0.1178, -1.0000, -0.1561, 0.6677, 0.9891, -0.3451, 0.8812, -0.6753,
-0.3079, -0.5580, 0.4380, -0.4588, 0.5831, 0.9956, 0.4467, 0.2863,
0.3924, 0.6863, -0.7513, 0.9043, 0.9436, 0.8207, -0.6493, 0.3524,
-0.9919, -0.2295, -0.0742, -0.9936, 0.3698, -0.7558, 0.0792, -0.2218,
-0.8637, 0.4711, 0.9997, -0.4368, 0.0404, -0.3498, -1.0000, 0.2663,
-0.8711, 0.0508, 0.0505, -0.1635, 0.1716, 0.4363, 0.4330, -0.0333,
-0.0416, 0.2206, -0.2568, -0.6122, -0.5916, 0.2569, -0.2622, -0.9041,
0.3221, -0.2394, -0.2634, -0.3454, -0.0723, 0.0081, 0.8297, 0.2279,
0.1614, -0.6555, -0.2062, 0.3280, -0.4016, 1.0000, -0.0952, -0.9874,
-0.0401, 0.0717, 0.3675, 0.3373, -0.3710, -1.0000, 0.4479, -0.1722,
-0.9917, 0.2677, 0.4844, -0.2207, -0.3207, 0.3715, -0.2171, -0.2522,
-0.3071, -0.3161, -0.1988, -0.0860, -0.0114, -0.1982, -0.1799, -0.3221,
0.1751, -0.4442, -0.1570, -0.0434, -0.0893, 0.5717, 0.3112, -0.2900,
0.3305, -0.9430, 0.6061, -0.2984, -0.9873, -0.3956, -0.9926, 0.7857,
-0.1692, -0.2719, 0.9505, 0.5628, 0.2904, -0.1693, 0.1619, -1.0000,
-0.1696, -0.1534, 0.2513, -0.2857, -0.9846, -0.9638, 0.5565, 0.9200,
0.1805, 0.9995, -0.2122, 0.9391, 0.3246, -0.3937, -0.1248, -0.5209,
0.0519, 0.1141, -0.6463, 0.3529, -0.0322, -0.3837, -0.3796, -0.2830,
0.1280, -0.9191, -0.4201, 0.9145, 0.0713, -0.2455, 0.5212, -0.2642,
-0.3675, 0.8082, 0.2577, 0.2755, -0.0157, 0.3675, -0.3107, 0.4502,
-0.8224, 0.2841, 0.4360, -0.3193, 0.2165, -0.9851, -0.4444, 0.5759,
0.9878, 0.7531, 0.3384, 0.2003, -0.2602, 0.4695, -0.9561, 0.9855,
-0.1712, 0.2295, 0.1220, -0.1386, -0.8436, -0.3783, 0.8371, -0.3204,
-0.8457, -0.0473, -0.4219, -0.3593, -0.2186, 0.5282, -0.3149, -0.4375,
-0.0440, 0.9242, 0.9296, 0.7735, -0.3733, 0.3945, -0.9049, -0.2898,
0.2695, 0.2910, 0.1695, 0.9932, -0.3069, -0.1611, -0.8349, -0.9827,
0.1299, -0.8555, -0.0531, -0.6830, 0.3926, 0.2873, -0.1899, 0.2598,
-0.9201, -0.7455, 0.3943, -0.3955, 0.4015, -0.2341, 0.7593, 0.3421,
-0.6143, 0.5170, 0.8987, 0.1072, -0.6858, 0.6481, -0.2454, 0.8712,
-0.5958, 0.9936, 0.3404, 0.4972, -0.9452, -0.2347, -0.8748, -0.0154,
-0.1293, -0.5265, 0.4235, 0.4206, 0.3663, 0.7488, -0.4650, 0.9900,
-0.8695, -0.9701, -0.5203, -0.0900, -0.9914, 0.0978, 0.2844, -0.0424,
-0.4649, -0.4546, -0.9620, 0.8035, 0.2177, 0.9705, -0.0793, -0.7985,
-0.3436, -0.9537, -0.0035, -0.0945, 0.4291, 0.0391, -0.9602, 0.4497,
0.5135, 0.4913, 0.0608, 0.9948, 1.0000, 0.9810, 0.8865, 0.7961,
-0.9894, -0.5122, 1.0000, -0.8521, -1.0000, -0.9412, -0.6633, 0.3110,
-1.0000, -0.1468, -0.1235, -0.9465, -0.0891, 0.9796, 0.9700, -1.0000,
0.9324, 0.9259, -0.4503, 0.4591, -0.1785, 0.9819, 0.2285, 0.4423,
-0.2615, 0.4124, -0.5252, -0.8534, 0.0365, -0.0670, 0.8944, 0.1913,
-0.4782, -0.9402, 0.2293, -0.1581, -0.2440, -0.9604, -0.1924, -0.0555,
0.5484, 0.1915, 0.2038, -0.7367, 0.2698, -0.7307, 0.3715, 0.5640,
-0.9386, -0.5717, 0.3818, -0.2775, 0.1536, -0.9608, 0.9702, -0.3502,
0.1524, 1.0000, 0.3876, -0.9001, 0.2547, 0.1857, 0.0832, 1.0000,
0.3811, -0.9852, -0.4053, 0.2576, -0.3923, -0.4125, 0.9994, -0.1463,
-0.0428, 0.2818, 0.9899, -0.9923, 0.8351, -0.8563, -0.9634, 0.9617,
0.9268, -0.4224, -0.7369, 0.1318, 0.1107, 0.2294, -0.8914, 0.6082,
0.4665, -0.0720, 0.8555, -0.7973, -0.3478, 0.4201, -0.1762, 0.0761,
0.2823, 0.4571, -0.1350, 0.1190, -0.3509, -0.4039, -0.9556, 0.0262,
1.0000, -0.2164, 0.0569, -0.2296, -0.1003, -0.1827, 0.4036, 0.4715,
-0.3293, -0.8471, -0.0518, -0.8453, -0.9935, 0.6732, 0.2284, -0.1968,
0.9998, 0.5194, 0.2326, 0.1718, 0.7497, -0.0192, 0.4518, -0.0327,
0.9765, -0.3259, 0.3491, 0.7471, -0.3186, -0.3019, -0.5725, 0.0563,
-0.9206, 0.0572, -0.9589, 0.9565, 0.3109, 0.3348, 0.1635, -0.0619,
1.0000, -0.6020, 0.5309, -0.3723, 0.6636, -0.9851, -0.6789, -0.4312,
-0.1435, -0.0827, -0.2497, 0.1323, -0.9786, -0.0474, -0.0304, -0.9444,
-0.9927, 0.2508, 0.6172, 0.1679, -0.7980, -0.6078, -0.4906, 0.4646,
-0.1934, -0.9396, 0.5453, -0.3000, 0.4329, -0.3340, 0.4408, -0.2058,
0.8344, 0.1265, -0.0307, -0.2098, -0.8340, 0.7114, -0.7410, 0.0518,
-0.1481, 1.0000, -0.3100, 0.1461, 0.7011, 0.6334, -0.2857, 0.1618,
0.0966, 0.2955, -0.0981, -0.1832, -0.6208, -0.3013, 0.4337, 0.0283,
-0.2959, 0.7579, 0.4711, 0.3666, -0.0531, 0.0914, 0.9969, -0.2267,
-0.1165, -0.5533, -0.1262, -0.3575, -0.2124, 1.0000, 0.3679, 0.0604,
-0.9936, -0.1999, -0.9208, 0.9999, 0.8511, -0.8783, 0.5650, 0.2405,
-0.2859, 0.6935, -0.2598, -0.2655, 0.2893, 0.2862, 0.9774, -0.4575,
-0.9764, -0.5964, 0.3966, -0.9575, 0.9939, -0.5326, -0.2349, -0.4376,
-0.0250, 0.2574, 0.0274, -0.9762, -0.1582, 0.1821, 0.9811, 0.3014,
-0.3820, -0.9007, -0.1151, 0.3936, -0.0680, -0.9449, 0.9809, -0.9313,
0.2600, 1.0000, 0.3860, -0.5243, 0.2401, -0.4410, 0.3253, -0.1412,
0.5428, -0.9466, -0.2817, -0.3262, 0.4330, -0.2120, -0.2457, 0.7247,
0.2134, -0.3430, -0.6305, -0.1214, 0.4871, 0.7498, -0.2957, -0.1829,
0.1699, -0.1391, -0.9264, -0.4167, -0.2995, -0.9991, 0.6411, -1.0000,
-0.1510, -0.5473, -0.2219, 0.8075, 0.3862, -0.1392, -0.7206, -0.0710,
0.6995, 0.6656, -0.2889, 0.2902, -0.6951, 0.1622, -0.1298, 0.3182,
0.1694, 0.6526, -0.2735, 1.0000, 0.1370, -0.3043, -0.9189, 0.3041,
-0.2604, 1.0000, -0.7969, -0.9715, 0.2110, -0.5773, -0.7218, 0.2477,
-0.0304, -0.7015, -0.6577, 0.9111, 0.8219, -0.3693, 0.4537, -0.3062,
-0.3671, 0.0856, 0.1595, 0.9903, 0.2790, 0.8213, -0.2885, -0.0723,
0.9636, 0.2213, 0.6892, 0.2070, 1.0000, 0.3249, -0.8999, 0.2644,
-0.9700, -0.2610, -0.9228, 0.4016, 0.1170, 0.8570, -0.3587, 0.9672,
0.0667, 0.1108, -0.1840, 0.4711, 0.3127, -0.9391, -0.9892, -0.9908,
0.3962, -0.5013, -0.0640, 0.3811, 0.1530, 0.4712, 0.3781, -1.0000,
0.9466, 0.3529, 0.2077, 0.9735, 0.2019, 0.4726, 0.4248, -0.9892,
-0.9203, -0.3418, -0.2910, 0.6572, 0.5584, 0.8190, 0.4319, -0.4171,
-0.4697, 0.4653, -0.8583, -0.9940, 0.4802, 0.0740, -0.8986, 0.9559,
-0.4745, -0.1616, 0.4457, 0.1412, 0.8933, 0.8280, 0.4313, 0.2437,
0.6787, 0.9043, 0.8940, 0.9903, -0.2561, 0.6986, -0.0055, 0.3281,
0.6809, -0.9586, 0.1583, 0.0033, -0.2711, 0.3025, -0.1928, -0.9207,
0.5260, -0.2139, 0.5709, -0.2302, 0.1593, -0.4779, -0.1577, -0.7036,
-0.5208, 0.4676, 0.2335, 0.9372, 0.4775, -0.1995, -0.5655, -0.2336,
0.0798, -0.9315, 0.8288, -0.0946, 0.5294, 0.0223, -0.0744, 0.7821,
0.1236, -0.3705, -0.3958, -0.7528, 0.8145, -0.3204, -0.4786, -0.5135,
0.7306, 0.3208, 0.9981, -0.3959, -0.3492, -0.1118, -0.2872, 0.3596,
-0.1345, -1.0000, 0.2896, 0.2262, 0.1702, -0.3530, 0.1111, -0.0755,
-0.9565, -0.2658, 0.2530, -0.0490, -0.5834, -0.4616, 0.3937, 0.2329,
0.5620, 0.8138, -0.0288, 0.5621, 0.3811, 0.0852, -0.6049, 0.8452]],
grad_fn=<TanhBackward>)
Archive: Airline_Tweets.zip
inflating: Tweets.csv
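After unzipping, the CSV was presumably loaded with pandas along these lines (the DataFrame name df is an assumption):

import pandas as pd

# load the airline tweets dataset extracted above
df = pd.read_csv('Tweets.csv')
df.head()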
 | tweet_id | airline_sentiment | airline_sentiment_confidence | negativereason | negativereason_confidence | airline | airline_sentiment_gold | name | negativereason_gold | retweet_count | text | tweet_coord | tweet_created | tweet_location | user_timezone
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---
0 | 570306133677760513 | neutral | 1.0000 | NaN | NaN | Virgin America | NaN | cairdin | NaN | 0 | @VirginAmerica What @dhepburn said. | NaN | 2015-02-24 11:35:52 -0800 | NaN | Eastern Time (US & Canada) |
1 | 570301130888122368 | positive | 0.3486 | NaN | 0.0000 | Virgin America | NaN | jnardino | NaN | 0 | @VirginAmerica plus you've added commercials to the experience... tacky. | NaN | 2015-02-24 11:15:59 -0800 | NaN | Pacific Time (US & Canada) |
2 | 570301083672813571 | neutral | 0.6837 | NaN | NaN | Virgin America | NaN | yvonnalynn | NaN | 0 | @VirginAmerica I didn't today... Must mean I need to take another trip! | NaN | 2015-02-24 11:15:48 -0800 | Lets Play | Central Time (US & Canada) |
3 | 570301031407624196 | negative | 1.0000 | Bad Flight | 0.7033 | Virgin America | NaN | jnardino | NaN | 0 | @VirginAmerica it's really aggressive to blast obnoxious "entertainment" in your guests' faces & they have little recourse | NaN | 2015-02-24 11:15:36 -0800 | NaN | Pacific Time (US & Canada) |
4 | 570300817074462722 | negative | 1.0000 | Can't Tell | 1.0000 | Virgin America | NaN | jnardino | NaN | 0 | @VirginAmerica and it's a really big bad thing about it | NaN | 2015-02-24 11:14:45 -0800 | NaN | Pacific Time (US & Canada) |
4837 @SouthwestAir how do I get a companion pass
2874 @united Looks like they came through. Thanks again for the help.
8393 @JetBlue no, but we're on the flight leaving from Boston to Seattle right now. :) flight 597
1624 “@united: @rikrik__ What made you come to this? Can we help you with anything? ^JP” the service just hasn't been great.
6882 @JetBlue - Definitely no note from whoever stole from me.
Name: text, dtype: object
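The class counts and proportions below can be reproduced with something like this, assuming the DataFrame is named df as above:

# absolute and relative frequency of each sentiment class
print(df['airline_sentiment'].value_counts())
print(df['airline_sentiment'].value_counts(normalize=True))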
negative 9178
neutral 3099
positive 2363
Name: airline_sentiment, dtype: int64
negative 0.626913
neutral 0.211680
positive 0.161407
Name: airline_sentiment, dtype: float64
#library for pattern matching
import re

#define a function for text cleaning
def preprocessor(text):
    #convert text to lower case
    text = text.lower()
    #remove user mentions
    text = re.sub(r'@[A-Za-z0-9]+', '', text)
    #remove hashtags
    #text = re.sub(r'#[A-Za-z0-9]+','',text)
    #remove links
    text = re.sub(r'http\S+', '', text)
    #split the text into tokens to remove extra spaces
    tokens = text.split()
    #join the tokens back with single spaces
    return " ".join(tokens)
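Applying the cleaning function to every tweet presumably looked roughly like this (a sketch; df and the column name are assumptions based on the data shown above):

# clean every tweet and keep the result as an array of strings
text = df['text'].apply(preprocessor).values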
array(['is flight 769 on it\'s way? was supposed to take off 30 minutes ago. website still shows "on time" not "in flight". thanks.',
'julie andrews all the way though was very impressive! no to',
'wish you flew out of atlanta... soon?',
'julie andrews. hands down.',
'will flights be leaving dallas for la on february 24th?'],
dtype=object)
##3.3 Preparing Input and Output Data
Preparing Output
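Converting the sentiment strings into integer class labels is not shown; a minimal sketch, assuming the mapping negative=0, neutral=1, positive=2 (consistent with the class weights reported later):

# encode the three sentiment classes as integers (assumed mapping)
label_map = {'negative': 0, 'neutral': 1, 'positive': 2}
labels = df['airline_sentiment'].map(label_map).values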
Preparing Input Data
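The histogram referenced below was presumably produced roughly as follows, and a maximum sequence length chosen from it; the value of 25 matches the padded sequences shown later but is otherwise an assumption:

import matplotlib.pyplot as plt

# distribution of tokenized tweet lengths, used to pick a maximum sequence length
seq_len = [len(tokenizer.encode(t)) for t in text]
plt.hist(seq_len)
plt.title('Histogram: Length of sentences')

# chosen based on the distribution above
max_len = 25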
[Figure: Histogram: Length of sentences]
# library for progress bar
from tqdm import notebook

# create an empty list to save the integer sequences
sent_id = []

# iterate over each tweet
for i in notebook.tqdm(range(len(text))):
    encoded_sent = tokenizer.encode(text[i],
                                    add_special_tokens=True,
                                    max_length=max_len,
                                    truncation=True,
                                    pad_to_max_length='right')
    # save the integer sequence to the list
    sent_id.append(encoded_sent)
Integer Sequence: [101, 2054, 2056, 1012, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
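The attention masks used in the train/validation split below are not created in the code shown; a minimal sketch of the usual approach (1 for real tokens, 0 for padding):

# create attention masks: 1 for real tokens, 0 for [PAD] tokens (id 0)
attention_masks = [[int(token_id > 0) for token_id in seq] for seq in sent_id]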
##3.4 Training and Validation Data
# Use train_test_split to split our data into train and validation sets
from sklearn.model_selection import train_test_split
# Use 90% for training and 10% for validation.
train_inputs, validation_inputs, train_labels, validation_labels = train_test_split(sent_id, labels, random_state=2018, test_size=0.1, stratify=labels)
# Do the same for the masks.
train_masks, validation_masks, _, _ = train_test_split(attention_masks, labels, random_state=2018, test_size=0.1, stratify=labels)
# Convert all inputs and labels into torch tensors, the required datatype for our model.
train_inputs = torch.tensor(train_inputs)
validation_inputs = torch.tensor(validation_inputs)
train_labels = torch.tensor(train_labels)
validation_labels = torch.tensor(validation_labels)
train_masks = torch.tensor(train_masks)
validation_masks = torch.tensor(validation_masks)
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
# The DataLoader needs to know our batch size for training, so we specify it here.
# For fine-tuning BERT on a specific task, the authors recommend a batch size of 16 or 32.
#define a batch size
batch_size = 32
# Create the DataLoader for our training set.
#Dataset wrapping tensors.
train_data = TensorDataset(train_inputs, train_masks, train_labels)
#define a sampler for sampling the data during training
#random sampler samples randomly from a dataset
#sequential sampler samples sequentially, always in the same order
train_sampler = RandomSampler(train_data)
#represents an iterator over a dataset. Supports batching and a customized data loading order
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size)
# Create the DataLoader for our validation set.
#Dataset wrapping tensors.
validation_data = TensorDataset(validation_inputs, validation_masks, validation_labels)
#define a sequential sampler
#This samples data in a sequential order
validation_sampler = SequentialSampler(validation_data)
#create an iterator over the dataset
validation_dataloader = DataLoader(validation_data, sampler=validation_sampler, batch_size=batch_size)
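The batch of input ids and the hidden-state shapes printed below were presumably obtained by pulling one batch from the DataLoader and passing it through BERT, roughly like this (a sketch):

# fetch one batch from the training DataLoader
sent_id, mask, labels = next(iter(train_dataloader))
print(sent_id)

# run the batch through BERT and inspect the shapes of its outputs
hidden_states, cls_hidden_state = bert(sent_id, attention_mask=mask)
print("Shape of Hidden States: {}".format(hidden_states.shape))
print("Shape of CLS Hidden State: {}".format(cls_hidden_state.shape))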
tensor([[ 101, 5064, 2090, 1040, 2546, 2860, 1998, 8764, 1045, 2288,
19030, 2013, 2260, 2497, 2035, 1996, 2126, 2000, 4601, 2290,
2006, 20304, 2475, 1029, 102],
[ 101, 2129, 2055, 2070, 4086, 19430, 3642, 2000, 2393, 2033,
3942, 2026, 2155, 999, 999, 999, 1012, 1012, 1012, 1012,
3531, 999, 999, 999, 102],
[ 101, 7987, 27225, 1523, 1024, 2735, 2091, 2005, 2054, 1012,
1001, 1058, 2595, 3736, 7959, 3723, 25514, 1524, 102, 0,
0, 0, 0, 0, 0],
[ 101, 4931, 4364, 1010, 2339, 2106, 2026, 2197, 3462, 7796,
2033, 1014, 19637, 1029, 102, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0],
[ 101, 3357, 1015, 1024, 2022, 2625, 2915, 1012, 3357, 1016,
1024, 13399, 6304, 2060, 3182, 2084, 10474, 1012, 3357, 1017,
1024, 2123, 1005, 1056, 102],
[ 101, 6146, 2581, 3823, 2013, 7921, 2012, 6921, 3199, 1999,
2258, 2286, 1001, 20704, 18372, 2243, 102, 0, 0, 0,
0, 0, 0, 0, 0],
[ 101, 3403, 2012, 6522, 2078, 2006, 4946, 2062, 2084, 2019,
3178, 2144, 4899, 1038, 1013, 1039, 1997, 2053, 4796, 1010,
3531, 2131, 2149, 2125, 102],
[ 101, 2024, 2017, 14763, 2005, 3462, 26727, 2157, 2085, 102,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0],
[ 101, 2003, 1996, 4037, 2091, 1029, 102, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0],
[ 101, 1045, 1005, 1049, 2667, 2000, 4638, 2046, 2026, 2184,
1024, 2753, 3286, 14931, 3462, 1056, 7382, 2006, 1996, 15363,
4037, 1998, 2009, 1005, 102],
[ 101, 1040, 7583, 1996, 2171, 102, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0],
[ 101, 1523, 1024, 3376, 2915, 1012, 1012, 4283, 2005, 6631,
1012, 2478, 1001, 4875, 8873, 2000, 2695, 1029, 1025, 1007,
1524, 2115, 6160, 999, 102],
[ 101, 2058, 1016, 2847, 1998, 2403, 2781, 2006, 2907, 1998,
10320, 1012, 4283, 2005, 1996, 2393, 999, 1001, 6304, 2121,
7903, 2063, 1001, 12077, 102],
[ 101, 1045, 2052, 2293, 2000, 2175, 2000, 1996, 5865, 2265,
1625, 102, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0],
[ 101, 2057, 1005, 2310, 1006, 3462, 4008, 23777, 1007, 2042,
3564, 2006, 1996, 16985, 22911, 1997, 5887, 2050, 2005, 1037,
2096, 2085, 3403, 2006, 102],
[ 101, 15536, 8873, 2347, 1005, 1056, 2551, 27120, 1012, 22333,
16742, 2128, 22278, 1012, 2017, 2741, 2033, 2000, 1037, 3309,
2005, 2484, 2847, 2007, 102],
[ 101, 2293, 2129, 2017, 2064, 1005, 1056, 2131, 2019, 4005,
2006, 1996, 3042, 1998, 1996, 12978, 2291, 17991, 2039, 2006,
2017, 102, 0, 0, 0],
[ 101, 1045, 4741, 2006, 2907, 2005, 2048, 2847, 1010, 2069,
2000, 2031, 2026, 2655, 1012, 2428, 23579, 1012, 102, 0,
0, 0, 0, 0, 0],
[ 101, 2748, 1012, 1045, 6618, 2019, 3178, 2001, 1037, 2146,
2438, 2051, 2000, 2907, 2077, 3228, 2039, 1012, 2064, 8307,
2655, 2033, 1029, 102, 0],
[ 101, 1011, 10885, 2420, 1024, 7882, 1010, 15544, 5753, 1999,
10975, 2005, 1996, 2305, 1024, 1002, 6352, 1010, 3974, 1037,
2154, 2000, 28781, 1998, 102],
[ 101, 13970, 12269, 2005, 2025, 8014, 3462, 2989, 7599, 2013,
1040, 2546, 2860, 2023, 2851, 1012, 2142, 2788, 2034, 2000,
6634, 1012, 1012, 1012, 102],
[ 101, 2044, 2035, 1045, 2031, 2042, 2083, 2006, 2023, 4440,
1010, 2064, 2017, 2131, 2033, 2006, 2178, 8582, 2188, 1029,
102, 0, 0, 0, 0],
[ 101, 4931, 2045, 2033, 2153, 2013, 7483, 10047, 2145, 2006,
2907, 102, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0],
[ 101, 2045, 2003, 2242, 3308, 2007, 2017, 4037, 1999, 23591,
18059, 102, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0],
[ 101, 2003, 1996, 2190, 2126, 2000, 2128, 1011, 15908, 2033,
2007, 2026, 2028, 2995, 2293, 1010, 6023, 1999, 3915, 1005,
1055, 4827, 3007, 1001, 102],
[ 101, 4223, 2003, 2200, 4129, 102, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0],
[ 101, 1040, 2386, 6914, 1012, 4012, 102, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0],
[ 101, 2009, 1005, 1055, 9145, 2108, 2006, 2907, 2000, 2017,
999, 999, 2428, 2423, 2781, 2000, 2689, 1037, 3462, 999,
999, 102, 0, 0, 0],
[ 101, 6228, 3277, 1012, 3504, 2066, 2027, 2288, 2009, 4964,
999, 4283, 2005, 2115, 5142, 1012, 102, 0, 0, 0,
0, 0, 0, 0, 0],
[ 101, 2092, 2049, 2340, 1024, 3429, 3286, 1998, 2074, 2288,
2019, 10373, 2008, 2026, 2340, 3286, 3462, 2003, 8394, 1011,
2008, 2015, 2025, 2157, 102],
[ 101, 4067, 2017, 2005, 14120, 1012, 1012, 1012, 2026, 12191,
2003, 1999, 2026, 4524, 1998, 1045, 2342, 2009, 2005, 2147,
1012, 1045, 2572, 5191, 102],
[ 101, 7864, 1010, 5327, 2005, 4248, 3247, 1010, 5079, 2000,
2178, 13445, 2057, 2074, 2363, 1024, 6378, 1011, 15045, 3296,
16122, 8917, 1011, 5391, 102]])
Shape of Hidden States: torch.Size([32, 25, 768])
Shape of CLS Hidden State: torch.Size([32, 768])
The pre-trained model was trained on a general-domain corpus, so fine-tuning it helps capture the domain-specific features of our custom dataset.
Every pre-trained model consists of two parts: a backbone and a head.
Hence, we can fine-tune the pre-trained model in two ways:
1. Fine-tuning only the head (dense layers): 1.1 using the CLS token, 1.2 using the hidden states
2. Fine-tuning both the backbone and the head: 2.1 using the CLS token, 2.2 using the hidden states
As the name suggests, in this approach, we freeze the backbone and train only the head or dense layer.
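Freezing the backbone is not shown explicitly; a minimal sketch of how it is typically done:

# freeze all parameters of the BERT backbone so that only the new head is trained
for param in bert.parameters():
    param.requires_grad = False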
#importing nn module
import torch.nn as nn

class classifier(nn.Module):

    #define the layers and wrappers used by the model
    def __init__(self, bert):
        #constructor
        super(classifier, self).__init__()
        #bert model
        self.bert = bert
        #dense layer 1
        self.fc1 = nn.Linear(768, 512)
        #dense layer 2 (output layer)
        self.fc2 = nn.Linear(512, 3)
        #dropout layer
        self.dropout = nn.Dropout(0.1)
        #relu activation function
        self.relu = nn.ReLU()
        #softmax activation function
        self.softmax = nn.LogSoftmax(dim=1)

    #define the forward pass
    def forward(self, sent_id, mask):
        #pass the inputs to the BERT model
        all_hidden_states, cls_hidden_state = self.bert(sent_id, attention_mask=mask)
        #pass the CLS hidden state to dense layer 1
        x = self.fc1(cls_hidden_state)
        #apply ReLU activation function
        x = self.relu(x)
        #apply dropout
        x = self.dropout(x)
        #pass the result to the output layer
        x = self.fc2(x)
        #apply log-softmax activation
        x = self.softmax(x)
        return x
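The model printout below suggests the classifier was instantiated with the frozen BERT backbone and moved to the GPU, roughly like this (a sketch):

# wrap the frozen BERT backbone with the classification head and push it to the GPU
model = classifier(bert)
model = model.to(device)
model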
classifier(
(bert): BertModel(
(embeddings): BertEmbeddings(
(word_embeddings): Embedding(30522, 768, padding_idx=0)
(position_embeddings): Embedding(512, 768)
(token_type_embeddings): Embedding(2, 768)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(encoder): BertEncoder(
(layer): ModuleList(
(0): BertLayer(
(attention): BertAttention(
(self): BertSelfAttention(
(query): Linear(in_features=768, out_features=768, bias=True)
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(output): BertSelfOutput(
(dense): Linear(in_features=768, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(intermediate): BertIntermediate(
(dense): Linear(in_features=768, out_features=3072, bias=True)
)
(output): BertOutput(
(dense): Linear(in_features=3072, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(1): BertLayer(
(attention): BertAttention(
(self): BertSelfAttention(
(query): Linear(in_features=768, out_features=768, bias=True)
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(output): BertSelfOutput(
(dense): Linear(in_features=768, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(intermediate): BertIntermediate(
(dense): Linear(in_features=768, out_features=3072, bias=True)
)
(output): BertOutput(
(dense): Linear(in_features=3072, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(2): BertLayer(
(attention): BertAttention(
(self): BertSelfAttention(
(query): Linear(in_features=768, out_features=768, bias=True)
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(output): BertSelfOutput(
(dense): Linear(in_features=768, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(intermediate): BertIntermediate(
(dense): Linear(in_features=768, out_features=3072, bias=True)
)
(output): BertOutput(
(dense): Linear(in_features=3072, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(3): BertLayer(
(attention): BertAttention(
(self): BertSelfAttention(
(query): Linear(in_features=768, out_features=768, bias=True)
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(output): BertSelfOutput(
(dense): Linear(in_features=768, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(intermediate): BertIntermediate(
(dense): Linear(in_features=768, out_features=3072, bias=True)
)
(output): BertOutput(
(dense): Linear(in_features=3072, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(4): BertLayer(
(attention): BertAttention(
(self): BertSelfAttention(
(query): Linear(in_features=768, out_features=768, bias=True)
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(output): BertSelfOutput(
(dense): Linear(in_features=768, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(intermediate): BertIntermediate(
(dense): Linear(in_features=768, out_features=3072, bias=True)
)
(output): BertOutput(
(dense): Linear(in_features=3072, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(5): BertLayer(
(attention): BertAttention(
(self): BertSelfAttention(
(query): Linear(in_features=768, out_features=768, bias=True)
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(output): BertSelfOutput(
(dense): Linear(in_features=768, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(intermediate): BertIntermediate(
(dense): Linear(in_features=768, out_features=3072, bias=True)
)
(output): BertOutput(
(dense): Linear(in_features=3072, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(6): BertLayer(
(attention): BertAttention(
(self): BertSelfAttention(
(query): Linear(in_features=768, out_features=768, bias=True)
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(output): BertSelfOutput(
(dense): Linear(in_features=768, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(intermediate): BertIntermediate(
(dense): Linear(in_features=768, out_features=3072, bias=True)
)
(output): BertOutput(
(dense): Linear(in_features=3072, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(7): BertLayer(
(attention): BertAttention(
(self): BertSelfAttention(
(query): Linear(in_features=768, out_features=768, bias=True)
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(output): BertSelfOutput(
(dense): Linear(in_features=768, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(intermediate): BertIntermediate(
(dense): Linear(in_features=768, out_features=3072, bias=True)
)
(output): BertOutput(
(dense): Linear(in_features=3072, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(8): BertLayer(
(attention): BertAttention(
(self): BertSelfAttention(
(query): Linear(in_features=768, out_features=768, bias=True)
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(output): BertSelfOutput(
(dense): Linear(in_features=768, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(intermediate): BertIntermediate(
(dense): Linear(in_features=768, out_features=3072, bias=True)
)
(output): BertOutput(
(dense): Linear(in_features=3072, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(9): BertLayer(
(attention): BertAttention(
(self): BertSelfAttention(
(query): Linear(in_features=768, out_features=768, bias=True)
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(output): BertSelfOutput(
(dense): Linear(in_features=768, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(intermediate): BertIntermediate(
(dense): Linear(in_features=768, out_features=3072, bias=True)
)
(output): BertOutput(
(dense): Linear(in_features=3072, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(10): BertLayer(
(attention): BertAttention(
(self): BertSelfAttention(
(query): Linear(in_features=768, out_features=768, bias=True)
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(output): BertSelfOutput(
(dense): Linear(in_features=768, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(intermediate): BertIntermediate(
(dense): Linear(in_features=768, out_features=3072, bias=True)
)
(output): BertOutput(
(dense): Linear(in_features=3072, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(11): BertLayer(
(attention): BertAttention(
(self): BertSelfAttention(
(query): Linear(in_features=768, out_features=768, bias=True)
(key): Linear(in_features=768, out_features=768, bias=True)
(value): Linear(in_features=768, out_features=768, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(output): BertSelfOutput(
(dense): Linear(in_features=768, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(intermediate): BertIntermediate(
(dense): Linear(in_features=768, out_features=3072, bias=True)
)
(output): BertOutput(
(dense): Linear(in_features=3072, out_features=768, bias=True)
(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
)
)
(pooler): BertPooler(
(dense): Linear(in_features=768, out_features=768, bias=True)
(activation): Tanh()
)
)
(fc1): Linear(in_features=768, out_features=512, bias=True)
(fc2): Linear(in_features=512, out_features=3, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
(relu): ReLU()
(softmax): LogSoftmax(dim=1)
)
tensor([[-0.8960, -1.3895, -1.0712],
[-0.8843, -1.2947, -1.1615],
[-0.9099, -1.2691, -1.1509],
[-0.8514, -1.2819, -1.2186],
[-0.9650, -1.2425, -1.1076],
[-0.9353, -1.3501, -1.0546],
[-0.8627, -1.3478, -1.1452],
[-0.9753, -1.2643, -1.0774],
[-0.8849, -1.4208, -1.0620],
[-0.9231, -1.2884, -1.1178],
[-0.9613, -1.2478, -1.1073],
[-0.9618, -1.2306, -1.1219],
[-0.8743, -1.3402, -1.1361],
[-0.9644, -1.2575, -1.0954],
[-0.9228, -1.3100, -1.1003],
[-0.8805, -1.2835, -1.1765],
[-0.8361, -1.3504, -1.1793],
[-0.8948, -1.3033, -1.1404],
[-0.8801, -1.3966, -1.0853],
[-0.9366, -1.2906, -1.0998],
[-0.8680, -1.3155, -1.1652],
[-0.8701, -1.2654, -1.2073],
[-0.8920, -1.3223, -1.1281],
[-0.8964, -1.3177, -1.1264],
[-0.7724, -1.4183, -1.2175],
[-0.9444, -1.3060, -1.0783],
[-0.8942, -1.3986, -1.0668],
[-0.9210, -1.2643, -1.1412],
[-0.9566, -1.2440, -1.1161],
[-0.9399, -1.2842, -1.1012],
[-0.8473, -1.3528, -1.1617],
[-0.9309, -1.3419, -1.0658]], device='cuda:0',
grad_fn=<LogSoftmaxBackward>)
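The trainable-parameter count reported below can be reproduced with a sketch like this:

# count only the parameters that will actually be updated during training
n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'The model has {n_trainable:,} trainable parameters')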
The model has 395,267 trainable parameters
[Figure: Class Distribution]
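The class weights and the weighted loss used during training are presumably computed along these lines (a sketch using scikit-learn's balanced weighting; the exact code is not shown):

import numpy as np
from sklearn.utils.class_weight import compute_class_weight

# compute balanced class weights from the imbalanced label distribution
class_weights = compute_class_weight(class_weight='balanced',
                                     classes=np.unique(labels),
                                     y=labels)
print("Class Weights:", class_weights)

# weighted negative log-likelihood loss, matching the LogSoftmax output of the classifier
weights = torch.tensor(class_weights, dtype=torch.float).to(device)
cross_entropy = nn.NLLLoss(weight=weights)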
Class Weights: [0.53170625 1.57470152 2.06517139]
Loss: tensor(1.1441, device='cuda:0', grad_fn=<NllLossBackward>)
The deep learning model is trained over a number of epochs, where each epoch consists of several batches.
During training, for each batch we perform a forward pass, compute the loss, backpropagate the loss, and update the weights. During evaluation, for each batch we only perform a forward pass and compute the loss.
Training: Epoch -> Batch -> Forward Pass -> Compute loss -> Backpropagate loss -> Update weights
Evaluation: Epoch -> Batch -> Forward Pass -> Compute loss
Hence, each epoch has a training phase and a validation phase, defined in the functions below.
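The training and evaluation functions below also rely on a format_time helper and an optimizer defined elsewhere; a plausible sketch (the choice of AdamW and the learning rate are assumptions, not taken from the source):

import time
import datetime
import numpy as np
from transformers import AdamW

def format_time(elapsed):
    # round elapsed seconds and format them as hh:mm:ss
    return str(datetime.timedelta(seconds=int(round(elapsed))))

# optimizer over the trainable (head) parameters; AdamW and the learning rate are assumptions
optimizer = AdamW(model.parameters(), lr=1e-3)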
Training phase
#define a function for training the model
def train():
    print("\nTraining.....")

    #put the model in training mode - Dropout layers are activated
    model.train()

    #record the current time
    t0 = time.time()

    #initialize loss and accuracy to 0
    total_loss, total_accuracy = 0, 0

    #create an empty list to save the model predictions
    total_preds = []

    #for every batch
    for step, batch in enumerate(train_dataloader):

        # Progress update after every 40 batches.
        if step % 40 == 0 and not step == 0:
            # Calculate elapsed time in minutes.
            elapsed = format_time(time.time() - t0)
            # Report progress.
            print('  Batch {:>5,}  of  {:>5,}.    Elapsed: {:}.'.format(step, len(train_dataloader), elapsed))

        #push the batch to the gpu
        batch = tuple(t.to(device) for t in batch)

        #unpack the batch into separate variables
        # `batch` contains three pytorch tensors:
        #   [0]: input ids
        #   [1]: attention masks
        #   [2]: labels
        sent_id, mask, labels = batch

        # Always clear any previously calculated gradients before performing a
        # backward pass. PyTorch doesn't do this automatically.
        model.zero_grad()

        # Perform a forward pass. This returns the model predictions.
        preds = model(sent_id, mask)

        #compute the loss between actual and predicted values
        loss = cross_entropy(preds, labels.long())

        # Accumulate the training loss over all of the batches so that we can
        # calculate the average loss at the end. `loss` is a Tensor containing a
        # single value; the `.item()` function just returns the Python value
        # from the tensor.
        total_loss = total_loss + loss.item()

        # Perform a backward pass to calculate the gradients.
        loss.backward()

        # Update parameters and take a step using the computed gradient.
        # The optimizer dictates the "update rule"--how the parameters are
        # modified based on their gradients, the learning rate, etc.
        optimizer.step()

        #the model predictions are stored on the GPU, so push them to the CPU
        preds = preds.detach().cpu().numpy()

        #accumulate the model predictions of each batch
        total_preds.append(preds)

    #compute the training loss of the epoch
    avg_loss = total_loss / len(train_dataloader)

    #the predictions are in the form (no. of batches, batch size, no. of classes),
    #so reshape them to (number of samples, no. of classes)
    total_preds = np.concatenate(total_preds, axis=0)

    #return the loss and predictions
    return avg_loss, total_preds
Evaluation phase
#define a function for evaluating the model
def evaluate():
    print("\nEvaluating.....")

    #put the model in evaluation mode - Dropout layers are deactivated
    model.eval()

    #record the current time
    t0 = time.time()

    #initialize the loss and accuracy to 0
    total_loss, total_accuracy = 0, 0

    #create an empty list to save the model predictions
    total_preds = []

    #for each batch
    for step, batch in enumerate(validation_dataloader):

        # Progress update every 40 batches.
        if step % 40 == 0 and not step == 0:
            # Calculate elapsed time in minutes.
            elapsed = format_time(time.time() - t0)
            # Report progress.
            print('  Batch {:>5,}  of  {:>5,}.    Elapsed: {:}.'.format(step, len(validation_dataloader), elapsed))

        #push the batch to the gpu
        batch = tuple(t.to(device) for t in batch)

        #unpack the batch into separate variables
        # `batch` contains three pytorch tensors:
        #   [0]: input ids
        #   [1]: attention masks
        #   [2]: labels
        sent_id, mask, labels = batch

        #deactivate autograd - no gradients are needed for evaluation
        with torch.no_grad():
            # Perform a forward pass. This returns the model predictions.
            preds = model(sent_id, mask)

        #compute the validation loss between actual and predicted values
        loss = cross_entropy(preds, labels.long())

        # Accumulate the validation loss over all of the batches so that we can
        # calculate the average loss at the end. `loss` is a Tensor containing a
        # single value; the `.item()` function just returns the Python value
        # from the tensor.
        total_loss = total_loss + loss.item()

        #the model predictions are stored on the GPU, so push them to the CPU
        preds = preds.detach().cpu().numpy()

        #accumulate the model predictions of each batch
        total_preds.append(preds)

    #compute the validation loss of the epoch
    avg_loss = total_loss / len(validation_dataloader)

    #the predictions are in the form (no. of batches, batch size, no. of classes),
    #so reshape them to (number of samples, no. of classes)
    total_preds = np.concatenate(total_preds, axis=0)

    return avg_loss, total_preds
#assign the initial best loss to infinity
best_valid_loss = float('inf')

#create empty lists to store the training and validation loss of each epoch
train_losses = []
valid_losses = []

epochs = 5

#for each epoch
for epoch in range(epochs):

    print('\n....... epoch {:} / {:} .......'.format(epoch + 1, epochs))

    #train the model
    train_loss, _ = train()

    #evaluate the model
    valid_loss, _ = evaluate()

    #save the best model
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'saved_weights.pt')

    #accumulate training and validation loss
    train_losses.append(train_loss)
    valid_losses.append(valid_loss)

    print(f'\nTraining Loss: {train_loss:.3f}')
    print(f'Validation Loss: {valid_loss:.3f}')

print("")
print("Training complete!")
....... epoch 1 / 5 .......
Training.....
Batch 40 of 412. Elapsed: 0:00:02.
Batch 80 of 412. Elapsed: 0:00:04.
Batch 120 of 412. Elapsed: 0:00:06.
Batch 160 of 412. Elapsed: 0:00:08.
Batch 200 of 412. Elapsed: 0:00:10.
Batch 240 of 412. Elapsed: 0:00:11.
Batch 280 of 412. Elapsed: 0:00:13.
Batch 320 of 412. Elapsed: 0:00:15.
Batch 360 of 412. Elapsed: 0:00:17.
Batch 400 of 412. Elapsed: 0:00:19.
Evaluating.....
Batch 40 of 46. Elapsed: 0:00:02.
Training Loss: 0.981
Validation Loss: 0.836
....... epoch 2 / 5 .......
Training.....
Batch 40 of 412. Elapsed: 0:00:02.
Batch 80 of 412. Elapsed: 0:00:04.
Batch 120 of 412. Elapsed: 0:00:06.
Batch 160 of 412. Elapsed: 0:00:08.
Batch 200 of 412. Elapsed: 0:00:10.
Batch 240 of 412. Elapsed: 0:00:12.
Batch 280 of 412. Elapsed: 0:00:14.
Batch 320 of 412. Elapsed: 0:00:16.
Batch 360 of 412. Elapsed: 0:00:18.
Batch 400 of 412. Elapsed: 0:00:20.
Evaluating.....
Batch 40 of 46. Elapsed: 0:00:02.
Training Loss: 0.851
Validation Loss: 0.715
....... epoch 3 / 5 .......
Training.....
Batch 40 of 412. Elapsed: 0:00:02.
Batch 80 of 412. Elapsed: 0:00:04.
Batch 120 of 412. Elapsed: 0:00:06.
Batch 160 of 412. Elapsed: 0:00:08.
Batch 200 of 412. Elapsed: 0:00:10.
Batch 240 of 412. Elapsed: 0:00:12.
Batch 280 of 412. Elapsed: 0:00:14.
Batch 320 of 412. Elapsed: 0:00:16.
Batch 360 of 412. Elapsed: 0:00:18.
Batch 400 of 412. Elapsed: 0:00:20.
Evaluating.....
Batch 40 of 46. Elapsed: 0:00:02.
Training Loss: 0.799
Validation Loss: 0.872
....... epoch 4 / 5 .......
Training.....
Batch 40 of 412. Elapsed: 0:00:02.
Batch 80 of 412. Elapsed: 0:00:04.
Batch 120 of 412. Elapsed: 0:00:06.
Batch 160 of 412. Elapsed: 0:00:08.
Batch 200 of 412. Elapsed: 0:00:10.
Batch 240 of 412. Elapsed: 0:00:12.
Batch 280 of 412. Elapsed: 0:00:14.
Batch 320 of 412. Elapsed: 0:00:16.
Batch 360 of 412. Elapsed: 0:00:18.
Batch 400 of 412. Elapsed: 0:00:20.
Evaluating.....
Batch 40 of 46. Elapsed: 0:00:02.
Training Loss: 0.780
Validation Loss: 0.679
....... epoch 5 / 5 .......
Training.....
Batch 40 of 412. Elapsed: 0:00:02.
Batch 80 of 412. Elapsed: 0:00:04.
Batch 120 of 412. Elapsed: 0:00:06.
Batch 160 of 412. Elapsed: 0:00:08.
Batch 200 of 412. Elapsed: 0:00:10.
Batch 240 of 412. Elapsed: 0:00:12.
Batch 280 of 412. Elapsed: 0:00:14.
Batch 320 of 412. Elapsed: 0:00:16.
Batch 360 of 412. Elapsed: 0:00:18.
Batch 400 of 412. Elapsed: 0:00:20.
Evaluating.....
Batch 40 of 46. Elapsed: 0:00:02.
Training Loss: 0.760
Validation Loss: 0.677
Training complete!
##4.6 Model Evaluation
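The final evaluation presumably reloads the best checkpoint saved during training and runs the evaluation function once more; a sketch:

# load the weights of the best model saved during training
model.load_state_dict(torch.load('saved_weights.pt'))

# evaluate once more on the validation set and report the loss
valid_loss, preds = evaluate()
valid_loss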
<All keys matched successfully>
Evaluating.....
Batch 40 of 46. Elapsed: 0:00:02.
0.6771649757157201