diff --git a/finetuning.ipynb b/finetuning.ipynb index bdac2c2..4212625 100644 --- a/finetuning.ipynb +++ b/finetuning.ipynb @@ -59,16 +59,25 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "env: TOKENIZERS_PARALLELISM=true\n" + ] + } + ], "source": [ + "%env TOKENIZERS_PARALLELISM=true\n", "import os\n", "from sklearn.metrics import classification_report\n", "import torch\n", "import torch.nn as nn\n", "import transformers\n", - "from transformers import BertModel, BertTokenizer\n", + "from transformers import BertModel, DistilBertTokenizerFast\n", "\n", "from torch_shallow_neural_classifier import TorchShallowNeuralClassifier\n", "from torch_rnn_classifier import TorchRNNModel\n", @@ -76,21 +85,22 @@ "from torch_rnn_classifier import TorchRNNClassifierModel\n", "from torch_rnn_classifier import TorchRNNClassifier\n", "import sst\n", - "import utils" + "import utils\n" ] }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ - "utils.fix_random_seeds()" + "utils.fix_random_seeds()\n", + "device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n" ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -106,7 +116,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -124,7 +134,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -147,20 +157,20 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ - "bert_tokenizer = BertTokenizer.from_pretrained(weights_name)" + "bert_tokenizer = DistilBertTokenizerFast.from_pretrained(weights_name)" ] }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ - "bert_model = BertModel.from_pretrained(weights_name)" + "bert_model = BertModel.from_pretrained(weights_name).to(device)" ] }, { @@ -172,7 +182,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -190,7 +200,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ @@ -203,16 +213,16 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "dict_keys(['input_ids', 'token_type_ids', 'attention_mask'])" + "dict_keys(['input_ids', 'attention_mask'])" ] }, - "execution_count": 11, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } @@ -230,7 +240,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 11, "metadata": {}, "outputs": [ { @@ -240,7 +250,7 @@ " [101, 15035, 3520, 156, 14787, 13327, 4455, 28026, 1116, 102, 0, 0]]" ] }, - "execution_count": 12, + "execution_count": 11, "metadata": {}, "output_type": "execute_result" } @@ -260,7 +270,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 12, "metadata": {}, "outputs": [ { @@ -269,7 +279,7 @@ "[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]" ] }, - "execution_count": 13, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" } @@ -287,12 +297,12 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 13, "metadata": {}, "outputs": [], "source": [ - "X_example = torch.tensor(example_ids['input_ids'])\n", - "X_example_mask = torch.tensor(example_ids['attention_mask'])\n", + "X_example = torch.tensor(example_ids['input_ids']).to(device)\n", + "X_example_mask = torch.tensor(example_ids['attention_mask']).to(device)\n", "\n", "with torch.no_grad():\n", " reps = bert_model(X_example, attention_mask=X_example_mask)" @@ -307,7 +317,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 14, "metadata": {}, "outputs": [ { @@ -316,7 +326,7 @@ "torch.Size([2, 768])" ] }, - "execution_count": 15, + "execution_count": 14, "metadata": {}, "output_type": "execute_result" } @@ -336,7 +346,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 15, "metadata": {}, "outputs": [ { @@ -345,7 +355,7 @@ "torch.Size([2, 12, 768])" ] }, - "execution_count": 16, + "execution_count": 15, "metadata": {}, "output_type": "execute_result" } @@ -379,16 +389,16 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "def bert_phi(text):\n", " input_ids = bert_tokenizer.encode(text, add_special_tokens=True)\n", - " X = torch.tensor([input_ids])\n", + " X = torch.tensor([input_ids]).to(device)\n", " with torch.no_grad():\n", " reps = bert_model(X)\n", - " return reps.last_hidden_state.squeeze(0).numpy()" + " return reps.last_hidden_state.squeeze(0).to(\"cpu\").numpy()" ] }, { @@ -402,7 +412,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 17, "metadata": {}, "outputs": [], "source": [ @@ -421,7 +431,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 18, "metadata": {}, "outputs": [], "source": [ @@ -439,7 +449,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 19, "metadata": {}, "outputs": [], "source": [ @@ -466,8 +476,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "CPU times: user 32min 44s, sys: 52.8 s, total: 33min 37s\n", - "Wall time: 8min 24s\n" + "CPU times: user 2min 13s, sys: 835 ms, total: 2min 14s\n", + "Wall time: 2min 14s\n" ] } ], @@ -484,8 +494,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "CPU times: user 4min 14s, sys: 7.2 s, total: 4min 22s\n", - "Wall time: 1min 5s\n" + "CPU times: user 17.2 s, sys: 67.9 ms, total: 17.3 s\n", + "Wall time: 17.3 s\n" ] } ], @@ -502,33 +512,33 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 28, "metadata": {}, "outputs": [], "source": [ "model = TorchShallowNeuralClassifier(\n", " early_stopping=True,\n", - " hidden_dim=300)" + " hidden_dim=300, )" ] }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 29, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "Stopping after epoch 45. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 5.156181752681732" + "Stopping after epoch 130. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 5.291560411453247" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "CPU times: user 21.3 s, sys: 2.56 s, total: 23.9 s\n", - "Wall time: 8.85 s\n" + "CPU times: user 2min 3s, sys: 223 ms, total: 2min 3s\n", + "Wall time: 13.2 s\n" ] } ], @@ -538,7 +548,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 30, "metadata": {}, "outputs": [], "source": [ @@ -547,7 +557,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 31, "metadata": {}, "outputs": [ { @@ -556,13 +566,13 @@ "text": [ " precision recall f1-score support\n", "\n", - " negative 0.696 0.787 0.739 428\n", - " neutral 0.342 0.279 0.308 229\n", - " positive 0.756 0.732 0.744 444\n", + " negative 0.700 0.808 0.751 428\n", + " neutral 0.500 0.192 0.278 229\n", + " positive 0.715 0.836 0.771 444\n", "\n", - " accuracy 0.659 1101\n", - " macro avg 0.598 0.600 0.597 1101\n", - "weighted avg 0.647 0.659 0.651 1101\n", + " accuracy 0.691 1101\n", + " macro avg 0.638 0.612 0.600 1101\n", + "weighted avg 0.665 0.691 0.660 1101\n", "\n" ] } @@ -582,7 +592,7 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 35, "metadata": {}, "outputs": [], "source": [ @@ -596,14 +606,14 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 36, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "Stopping after epoch 39. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 5.242022633552551" + "Stopping after epoch 137. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 5.239248812198639" ] }, { @@ -612,16 +622,16 @@ "text": [ " precision recall f1-score support\n", "\n", - " negative 0.701 0.806 0.750 428\n", - " neutral 0.435 0.162 0.236 229\n", - " positive 0.714 0.842 0.773 444\n", + " negative 0.699 0.799 0.746 428\n", + " neutral 0.463 0.162 0.239 229\n", + " positive 0.705 0.845 0.768 444\n", "\n", - " accuracy 0.687 1101\n", - " macro avg 0.617 0.603 0.586 1101\n", - "weighted avg 0.651 0.687 0.652 1101\n", + " accuracy 0.685 1101\n", + " macro avg 0.622 0.602 0.585 1101\n", + "weighted avg 0.652 0.685 0.650 1101\n", "\n", - "CPU times: user 38min 14s, sys: 1min 2s, total: 39min 17s\n", - "Wall time: 9min 49s\n" + "CPU times: user 4min 40s, sys: 1.14 s, total: 4min 41s\n", + "Wall time: 2min 45s\n" ] } ], @@ -652,7 +662,7 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 37, "metadata": {}, "outputs": [], "source": [ @@ -667,14 +677,14 @@ }, { "cell_type": "code", - "execution_count": 30, + "execution_count": 38, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "Stopping after epoch 32. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 0.7171962857246399" + "Stopping after epoch 34. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 0.5571811199188232" ] }, { @@ -683,16 +693,16 @@ "text": [ " precision recall f1-score support\n", "\n", - " negative 0.702 0.776 0.737 428\n", - " neutral 0.351 0.236 0.282 229\n", - " positive 0.747 0.797 0.771 444\n", + " negative 0.741 0.694 0.717 428\n", + " neutral 0.370 0.266 0.310 229\n", + " positive 0.690 0.831 0.754 444\n", "\n", - " accuracy 0.672 1101\n", - " macro avg 0.600 0.603 0.597 1101\n", - "weighted avg 0.647 0.672 0.656 1101\n", + " accuracy 0.660 1101\n", + " macro avg 0.600 0.597 0.593 1101\n", + "weighted avg 0.643 0.660 0.647 1101\n", "\n", - "CPU times: user 38min 45s, sys: 1min 39s, total: 40min 24s\n", - "Wall time: 10min 6s\n" + "CPU times: user 6min 56s, sys: 1min 21s, total: 8min 17s\n", + "Wall time: 3min 7s\n" ] } ], @@ -726,7 +736,7 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": 39, "metadata": {}, "outputs": [], "source": [ @@ -781,14 +791,14 @@ }, { "cell_type": "code", - "execution_count": 32, + "execution_count": 44, "metadata": {}, "outputs": [], "source": [ "class HfBertClassifier(TorchShallowNeuralClassifier):\n", " def __init__(self, weights_name, *args, **kwargs):\n", " self.weights_name = weights_name\n", - " self.tokenizer = BertTokenizer.from_pretrained(self.weights_name)\n", + " self.tokenizer = DistilBertTokenizerFast.from_pretrained(self.weights_name)\n", " super().__init__(*args, **kwargs)\n", " self.params += ['weights_name']\n", "\n", @@ -827,7 +837,7 @@ }, { "cell_type": "code", - "execution_count": 33, + "execution_count": 45, "metadata": {}, "outputs": [], "source": [ @@ -837,7 +847,7 @@ }, { "cell_type": "code", - "execution_count": 34, + "execution_count": 46, "metadata": {}, "outputs": [], "source": [ @@ -862,34 +872,34 @@ }, { "cell_type": "code", - "execution_count": 35, + "execution_count": 47, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "Finished epoch 1 of 1; error is 184.64238105341792" + "Finished epoch 1 of 1; error is 93.697665332816541" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Best params: {'eta': 5e-05, 'gradient_accumulation_steps': 4, 'hidden_dim': 200}\n", - "Best score: 0.587\n", + "Best params: {'eta': 0.0001, 'gradient_accumulation_steps': 8, 'hidden_dim': 200}\n", + "Best score: 0.583\n", " precision recall f1-score support\n", "\n", - " negative 0.686 0.930 0.790 428\n", - " neutral 0.514 0.079 0.136 229\n", - " positive 0.763 0.836 0.798 444\n", + " negative 0.640 0.967 0.770 428\n", + " neutral 0.375 0.013 0.025 229\n", + " positive 0.794 0.797 0.796 444\n", "\n", - " accuracy 0.715 1101\n", - " macro avg 0.655 0.615 0.575 1101\n", - "weighted avg 0.682 0.715 0.657 1101\n", + " accuracy 0.700 1101\n", + " macro avg 0.603 0.593 0.530 1101\n", + "weighted avg 0.647 0.700 0.625 1101\n", "\n", - "CPU times: user 1h 27min 12s, sys: 11min 18s, total: 1h 38min 31s\n", - "Wall time: 1h 37min 44s\n" + "CPU times: user 1h 15min 20s, sys: 29 s, total: 1h 15min 49s\n", + "Wall time: 1h 11min 54s\n" ] } ], @@ -912,7 +922,7 @@ }, { "cell_type": "code", - "execution_count": 36, + "execution_count": 48, "metadata": {}, "outputs": [], "source": [ @@ -921,7 +931,7 @@ }, { "cell_type": "code", - "execution_count": 37, + "execution_count": 49, "metadata": {}, "outputs": [], "source": [ @@ -931,7 +941,7 @@ }, { "cell_type": "code", - "execution_count": 38, + "execution_count": 50, "metadata": {}, "outputs": [], "source": [ @@ -943,7 +953,7 @@ }, { "cell_type": "code", - "execution_count": 39, + "execution_count": 51, "metadata": {}, "outputs": [], "source": [ @@ -953,14 +963,14 @@ }, { "cell_type": "code", - "execution_count": 40, + "execution_count": 52, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "Stopping after epoch 9. Validation score did not improve by tol=1e-05 for more than 5 epochs. Final error is 11.503188711278199" + "Stopping after epoch 10. Validation score did not improve by tol=1e-05 for more than 5 epochs. Final error is 6.126567473358591" ] }, { @@ -969,16 +979,16 @@ "text": [ " precision recall f1-score support\n", "\n", - " negative 0.816 0.754 0.784 912\n", - " neutral 0.332 0.501 0.400 389\n", - " positive 0.881 0.756 0.813 909\n", + " negative 0.799 0.707 0.750 912\n", + " neutral 0.312 0.409 0.354 389\n", + " positive 0.830 0.815 0.822 909\n", "\n", - " accuracy 0.710 2210\n", - " macro avg 0.676 0.670 0.666 2210\n", - "weighted avg 0.758 0.710 0.728 2210\n", + " accuracy 0.699 2210\n", + " macro avg 0.647 0.644 0.642 2210\n", + "weighted avg 0.726 0.699 0.710 2210\n", "\n", - "CPU times: user 9min 54s, sys: 1min 22s, total: 11min 17s\n", - "Wall time: 11min 16s\n" + "CPU times: user 9min 1s, sys: 1.09 s, total: 9min 3s\n", + "Wall time: 8min 57s\n" ] } ], @@ -991,6 +1001,13 @@ " assess_dataframes=test_df,\n", " vectorize=False) # Pass in the BERT hidden state directly!" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { @@ -1009,7 +1026,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.13" + "version": "3.10.12" } }, "nbformat": 4,