{"cells":[{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"x4N1GRI4PB61","outputId":"02ff3dba-b805-47ac-d31c-faa12a301762"},"outputs":[{"name":"stdout","output_type":"stream","text":["Mounted at /content/drive\n"]}],"source":["from google.colab import drive\n","drive.mount('/content/drive')"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"8v3yfIgiPC-e"},"outputs":[],"source":["path = \"/content/drive/My Drive/KhoaLuan/\""]},{"cell_type":"code","execution_count":null,"metadata":{"id":"d2pFgULmOl9s"},"outputs":[],"source":["import torch\n","import numpy as np\n","import pandas as pd\n","import seaborn as sns\n","import matplotlib.pyplot as plt\n","\n","from sklearn.model_selection import StratifiedKFold\n","from sklearn.metrics import classification_report, confusion_matrix\n","\n","import torch.nn as nn\n","from torch.optim import AdamW\n","from torch.utils.data import Dataset, DataLoader\n","\n","from transformers import get_linear_schedule_with_warmup, AutoTokenizer, AutoModel, logging\n","\n","import warnings\n","import time\n","import pickle\n","warnings.filterwarnings(\"ignore\")\n","\n","logging.set_verbosity_error()"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"QoQ54muyOl9x"},"outputs":[],"source":["def seed_everything(seed_value):\n"," np.random.seed(seed_value)\n"," torch.manual_seed(seed_value)\n","\n"," if torch.cuda.is_available():\n"," torch.cuda.manual_seed(seed_value)\n"," torch.cuda.manual_seed_all(seed_value)\n"," torch.backends.cudnn.deterministic = True\n"," torch.backends.cudnn.benchmark = True\n","\n","seed_everything(86)"]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":177,"referenced_widgets":["cb4ab79526d440ca83477ff3a04a08a1","5b8c0cae78494560b77e662f62390b6a","a9fd54f8549448c7a0bcddb54f6291a0","975e44350bac401e9a252d91690f403c","96c86a9389b54e93a0b8bc4f1f8062c5","dcf1858dfb5d41cba8b389c8ce5e25df","f1283c390edb43d7a594b1698edb88d6","3194b7dfd003461f820aaf76f6f28e06","e6b4b08966444e48a27761a636f82d2f","29abb95da6bb456b8c93a908b224d595","9fce6620292549bd8680cab5fd9897e3","55efea9e72ee4c2a8dcd996d94b9d7d5","9f49bdbc16574981a5690557df4e86ed","ceec605689804459b4f4ef4a04ed350b","6c0a0e0995044ac68407772af686d110","e25ee9487fa74a46afa7312c114dc3c3","6b9b61cb50434953bdc9ffcc4927ce88","6e4032fac7cd44c287999dc2cc0140f3","b5e76d03302d4b1abff06a50cdcb1ae7","302f6ad8468d4b1ca5e84ff1be0a26cf","607a095cb82f46579d5225425c46d74c","37cfc2155b6942b0bdc2dc15d38d318b","b1a6ea2be64045e79757be9464101305","31f972a384f5484bb2e650f51ab0c07b","9087321a22a04369b3725d5e2c1b66ff","38afe0e6db004016a5041b44d20af4cd","ede18324dd254411adc81ec2060c1205","0b76d88f6b944d168ac720508995afb4","acdf3e8b1b6947fbbeaa4a9dc6bcfd1a","d4bd359fdc0d463fbf13a7939347c025","f0568aaa8be14d69938b16b28520bf58","72905e749d2f4319829bc847485ec12e","d31b652832bf4705a51240f52d3e123e","3f323509e43341078d14adba6014d4ac","b60001ec86c240819b4d1a3c6c36ef53","145fd1ddd7594e028b80d854b6ab1aee","1d3278e129d34725a5464966ea79096c","edf4474b02894f178f85e6753a9a8626","aa41fbef0b4440cbb3d286bb56e76221","304f1513ed774371af3f0a31a2ab62ae","4e66e5fd590049f3900ed96122be7765","c574f94bd30c4f919b27178f7663f04e","080e1c9cf237477abb8a4a886e24909d","454a25bd88f64938b4b479400167c1a0","14adb10381eb4f4481093c14aa902146","98f0555d5062403a9699411c89d313c7","d096bd334125477f9e566c5ee4ef77a4","91446629c2554a59aa3f32b845fb7551","d73cf8f1c1254f5d87846e0e35bf41b8","111f8a76f73a4752bd64cf80bf30c318","040ddedda1174ab6b9408e41c97e2e5c","83729c714d754ebfa84fef037d3c932e","578f14272fdf477dbc8693ab7bdf984e","be7cb164ab3e41eaa6c72209f0c6094e","d1b2dbd6bb054f70aac7f1fd493e3c19"]},"id":"au9SE5_YOl9x","outputId":"c0b300c6-23f9-44e8-b157-bce1e834caa8"},"outputs":[{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"cb4ab79526d440ca83477ff3a04a08a1","version_major":2,"version_minor":0},"text/plain":["tokenizer_config.json: 0%| | 0.00/311 [00:00, ?B/s]"]},"metadata":{},"output_type":"display_data"},{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"55efea9e72ee4c2a8dcd996d94b9d7d5","version_major":2,"version_minor":0},"text/plain":["vocab.txt: 0%| | 0.00/895k [00:00, ?B/s]"]},"metadata":{},"output_type":"display_data"},{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"b1a6ea2be64045e79757be9464101305","version_major":2,"version_minor":0},"text/plain":["bpe.codes: 0%| | 0.00/1.14M [00:00, ?B/s]"]},"metadata":{},"output_type":"display_data"},{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"3f323509e43341078d14adba6014d4ac","version_major":2,"version_minor":0},"text/plain":["added_tokens.json: 0%| | 0.00/22.0 [00:00, ?B/s]"]},"metadata":{},"output_type":"display_data"},{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"14adb10381eb4f4481093c14aa902146","version_major":2,"version_minor":0},"text/plain":["special_tokens_map.json: 0%| | 0.00/167 [00:00, ?B/s]"]},"metadata":{},"output_type":"display_data"}],"source":["model_name = \"bluenguyen/longformer-phobert-base-4096\" # vinai/phobert-base-v2\n","max_len = 512 # 256\n","n_classes = 13\n","tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"ItGXacmyOl9y"},"outputs":[],"source":["device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n","EPOCHS = 5\n","N_SPLITS = 5"]},{"cell_type":"markdown","metadata":{"id":"01RxDq7COl9y"},"source":["## Data"]},{"cell_type":"markdown","metadata":{"id":"l4ynf_f2Ol9z"},"source":["### Get data"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"nYjLbabaOl90"},"outputs":[],"source":["# Function to read data from JSON file\n","def get_data(path):\n"," df = pd.read_json(path, lines=True)\n"," return df\n","\n","TRAIN_PATH = path + \"train_data_162k.json\"\n","TEST_PATH = path + \"test_data_162k.json\"\n","VAL_PATH = path + \"val_data_162k.json\"\n","\n","# Read the data from JSON files\n","train_df = get_data(TRAIN_PATH)\n","test_df = get_data(TEST_PATH)\n","valid_df = get_data(VAL_PATH)\n","\n","# Combine train and validation data\n","train_df = pd.concat([train_df, valid_df], ignore_index=True)\n","\n","# Apply StratifiedKFold\n","skf = StratifiedKFold(n_splits=N_SPLITS)\n","for fold, (_, val_) in enumerate(skf.split(X=train_df, y=train_df.category)):\n"," train_df.loc[val_, \"kfold\"] = fold"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"Lv08-h2nOl90","outputId":"417804bf-64a2-4f7a-ad2f-4d4990eb59c3"},"outputs":[{"data":{"text/html":["
\n","\n","
\n"," \n","
\n","
\n","
category
\n","
processed_content
\n","
kfold
\n","
\n"," \n"," \n","
\n","
96994
\n","
Xe co
\n","
fiat giới_thiệu động_cơ siêu_sạch air triển_lã...
\n","
3.0
\n","
\n","
\n","
22834
\n","
Phap luat
\n","
cận_cảnh hiện_trường quán karaoke xảy vụ bắn c...
\n","
1.0
\n","
\n","
\n","
14391
\n","
Nha dat
\n","
thổi giá kích_cầu đợt ấm đột_biến thị_trường đ...
\n","
2.0
\n","
\n","
\n","
100151
\n","
Xa hoi
\n","
quảng_ninh người_dân nuôi ngao run phóng_viên ...
\n","
0.0
\n","
\n","
\n","
31397
\n","
Giao duc
\n","
hội tư_vấn xét tuyển đại_học cao_đẳng diễn onl...
\n","
1.0
\n","
\n"," \n","
\n","
"],"text/plain":[" category processed_content kfold\n","96994 Xe co fiat giới_thiệu động_cơ siêu_sạch air triển_lã... 3.0\n","22834 Phap luat cận_cảnh hiện_trường quán karaoke xảy vụ bắn c... 1.0\n","14391 Nha dat thổi giá kích_cầu đợt ấm đột_biến thị_trường đ... 2.0\n","100151 Xa hoi quảng_ninh người_dân nuôi ngao run phóng_viên ... 0.0\n","31397 Giao duc hội tư_vấn xét tuyển đại_học cao_đẳng diễn onl... 1.0"]},"execution_count":6,"metadata":{},"output_type":"execute_result"}],"source":["train_df.sample(5)"]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"0NQYr2vPOl91","outputId":"857865a2-3ddd-4d20-bced-2d101e767b7d"},"outputs":[{"name":"stdout","output_type":"stream","text":["\n","RangeIndex: 138223 entries, 0 to 138222\n","Data columns (total 3 columns):\n"," # Column Non-Null Count Dtype \n","--- ------ -------------- ----- \n"," 0 category 138223 non-null object \n"," 1 processed_content 138223 non-null object \n"," 2 kfold 138223 non-null float64\n","dtypes: float64(1), object(2)\n","memory usage: 3.2+ MB\n","\n","RangeIndex: 24126 entries, 0 to 24125\n","Data columns (total 2 columns):\n"," # Column Non-Null Count Dtype \n","--- ------ -------------- ----- \n"," 0 category 24126 non-null object\n"," 1 processed_content 24126 non-null object\n","dtypes: object(2)\n","memory usage: 377.1+ KB\n"]},{"data":{"text/plain":["(None, None)"]},"execution_count":9,"metadata":{},"output_type":"execute_result"}],"source":["train_df.info(), test_df.info()"]},{"cell_type":"markdown","metadata":{"id":"3su2QfYwOl92"},"source":["### Distribution of Categories"]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":767},"id":"QceYF-mFOl92","outputId":"6d47e4c4-5613-45ae-ff2e-bc83ff1fd557"},"outputs":[{"data":{"image/png":"","text/plain":["
"]},"metadata":{},"output_type":"display_data"}],"source":["# Plotting with Seaborn\n","plt.figure(figsize=(12, 8))\n","\n","# Get the unique categories\n","news_categories = train_df['category'].unique()\n","\n","colors = sns.color_palette(\"husl\", len(news_categories))\n","\n","color_mapping = {article: colors[i] for i, article in enumerate(news_categories)}\n","\n","# Plot with each bar having a different color\n","ax = sns.countplot(x='category', data=train_df, palette=color_mapping)\n","\n","# Add title and labels\n","plt.title('Distribution of Categories')\n","plt.xlabel('News Category')\n","plt.ylabel('Count')\n","plt.xticks(rotation=45) # Rotate x labels for better readability\n","\n","# Display the plot\n","plt.show()"]},{"cell_type":"markdown","metadata":{"id":"ZNxd_KCGOl93"},"source":["### Distribution of length of Sentence"]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":564},"id":"j81aV78YOl93","outputId":"01d28bdb-0f7f-4b6c-e437-b2296f58056f"},"outputs":[{"data":{"image/png":"","text/plain":["
"]},"metadata":{},"output_type":"display_data"}],"source":["\n","# Combine processed content from train and test dataframes\n","all_data = train_df.processed_content.tolist() + test_df.processed_content.tolist()\n","# Encode the text\n","encoded_text = [tokenizer.encode(text, add_special_tokens=True) for text in all_data]\n","\n","# Calculate the length of each encoded text\n","token_lens = [len(text) for text in encoded_text]\n","\n","# Plot the distribution of token lengths\n","plt.figure(figsize=(10, 6))\n","sns.histplot(token_lens, kde=True)\n","plt.xlim([0, max(token_lens)])\n","plt.xlabel('Token Count')\n","plt.ylabel('Frequency')\n","plt.title('Distribution of Token Count per Sentence')\n","plt.show()"]},{"cell_type":"markdown","metadata":{"id":"BjMv5-vvOl93"},"source":["## Model PhoBert"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"Mz4mEyJqOl93"},"outputs":[],"source":["class NewsDataset(Dataset):\n"," def __init__(self, df, tokenizer, max_len):\n"," self.df = df\n"," self.max_len = max_len\n"," self.tokenizer = tokenizer\n","\n"," def __len__(self):\n"," return len(self.df)\n","\n"," def __getitem__(self, index):\n"," \"\"\"\n"," To customize dataset, inherit from Dataset class and implement\n"," __len__ & __getitem__\n"," __getitem__ should return\n"," data:\n"," input_ids\n"," attention_masks\n"," text\n"," targets\n"," \"\"\"\n"," row = self.df.iloc[index]\n"," text, label = self.get_input_data(row)\n","\n"," # Encode_plus will:\n"," # (1) split text into token\n"," # (2) Add the '[CLS]' and '[SEP]' token to the start and end\n"," # (3) Truncate/Pad sentence to max length\n"," # (4) Map token to their IDS\n"," # (5) Create attention mask\n"," # (6) Return a dictionary of outputs\n"," encoding = self.tokenizer.encode_plus(\n"," text,\n"," truncation=True,\n"," add_special_tokens=True,\n"," max_length=self.max_len,\n"," padding='max_length',\n"," return_attention_mask=True,\n"," return_token_type_ids=False,\n"," return_tensors='pt',\n"," )\n","\n"," return {\n"," 'text': text,\n"," 'input_ids': encoding['input_ids'].flatten(),\n"," 'attention_masks': encoding['attention_mask'].flatten(),\n"," 'targets': torch.tensor(label, dtype=torch.long),\n"," }\n","\n","\n"," def labelencoder(self, text):\n"," label_map = {\n"," 'Cong nghe': 0, 'Doi song': 1, 'Giai tri': 2, 'Giao duc': 3, 'Khoa hoc': 4,\n"," 'Kinh te': 5, 'Nha dat': 6, 'Phap luat': 7, 'The gioi': 8, 'The thao': 9,\n"," 'Van hoa': 10, 'Xa hoi': 11, 'Xe co': 12\n"," }\n"," return label_map.get(text, -1)\n","\n"," def get_input_data(self, row):\n"," text = row['processed_content']\n"," label = self.labelencoder(row['category'])\n"," return text, label"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"zMq3qvM0Ol94"},"outputs":[],"source":["class NewsClassifier(nn.Module):\n"," def __init__(self, n_classes, model_name):\n"," super(NewsClassifier, self).__init__()\n"," # Load a pre-trained BERT model\n"," self.bert = AutoModel.from_pretrained(model_name)\n"," # Dropout layer to prevent overfitting\n"," self.drop = nn.Dropout(p=0.3)\n"," # Fully-connected layer to convert BERT's hidden state to the number of classes to predict\n"," self.fc = nn.Linear(self.bert.config.hidden_size, n_classes)\n"," # Initialize weights and biases of the fully-connected layer using the normal distribution method\n"," nn.init.normal_(self.fc.weight, std=0.02)\n"," nn.init.normal_(self.fc.bias, 0)\n","\n"," def forward(self, input_ids, attention_mask):\n"," # Get the output from the BERT model\n"," last_hidden_state, output = self.bert(\n"," input_ids=input_ids,\n"," attention_mask=attention_mask,\n"," return_dict=False\n"," )\n"," # Apply dropout\n"," x = self.drop(output)\n"," # Pass through the fully-connected layer to get predictions\n"," x = self.fc(x)\n"," return x"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"X7wuM4oAOl94"},"outputs":[],"source":["def prepare_loaders(df, fold):\n"," df_train = df[df.kfold != fold].reset_index(drop=True)\n"," df_valid = df[df.kfold == fold].reset_index(drop=True)\n","\n"," train_dataset = NewsDataset(df_train, tokenizer, max_len)\n"," valid_dataset = NewsDataset(df_valid, tokenizer, max_len)\n","\n"," train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=2)\n"," valid_loader = DataLoader(valid_dataset, batch_size=16, shuffle=True, num_workers=2)\n","\n"," return train_loader, valid_loader"]},{"cell_type":"markdown","metadata":{"id":"h4gy1xRfOl94"},"source":["### Train"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"BRIOWDxqOl94"},"outputs":[],"source":["# Function to train the model for one epoch\n","def train(model, criterion, optimizer, train_loader, lr_scheduler):\n"," model.train()\n"," losses = []\n"," correct = 0\n","\n"," for batch_idx, data in enumerate(train_loader):\n"," input_ids = data['input_ids'].to(device)\n"," attention_mask = data['attention_masks'].to(device)\n"," targets = data['targets'].to(device)\n","\n"," optimizer.zero_grad()\n"," outputs = model(\n"," input_ids=input_ids,\n"," attention_mask=attention_mask\n"," )\n","\n"," loss = criterion(outputs, targets)\n"," _, pred = torch.max(outputs, dim=1)\n","\n"," correct += torch.sum(pred == targets)\n"," losses.append(loss.item())\n"," loss.backward()\n"," nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)\n"," optimizer.step()\n"," lr_scheduler.step()\n","\n"," if batch_idx % 1000 == 0:\n"," print(f'Batch {batch_idx}/{len(train_loader)} - Loss: {loss.item():.4f}, Accuracy: {correct.double() / ((batch_idx + 1) * train_loader.batch_size):.4f}')\n","\n"," train_accuracy = correct.double() / len(train_loader.dataset)\n"," avg_loss = np.mean(losses)\n"," print(f'Train Accuracy: {train_accuracy:.4f} Loss: {avg_loss:.4f}')\n","\n","# Function to evaluate the model\n","def eval(model, criterion, valid_loader, test_loader=None):\n"," model.eval()\n"," losses = []\n"," correct = 0\n","\n"," with torch.no_grad():\n"," data_loader = test_loader if test_loader else valid_loader\n"," for batch_idx, data in enumerate(data_loader):\n"," input_ids = data['input_ids'].to(device)\n"," attention_mask = data['attention_masks'].to(device)\n"," targets = data['targets'].to(device)\n","\n"," outputs = model(\n"," input_ids=input_ids,\n"," attention_mask=attention_mask\n"," )\n","\n"," loss = criterion(outputs, targets)\n"," _, pred = torch.max(outputs, dim=1)\n","\n"," correct += torch.sum(pred == targets)\n"," losses.append(loss.item())\n","\n"," dataset_size = len(test_loader.dataset) if test_loader else len(valid_loader.dataset)\n"," accuracy = correct.double() / dataset_size\n"," avg_loss = np.mean(losses)\n","\n"," if test_loader:\n"," print(f'Test Accuracy: {accuracy:.4f} Loss: {avg_loss:.4f}')\n"," else:\n"," print(f'Valid Accuracy: {accuracy:.4f} Loss: {avg_loss:.4f}')\n","\n"," return accuracy\n","\n","\n","total_start_time = time.time()\n","\n","# Main training loop\n","for fold in range(skf.n_splits):\n"," print(f'----------- Fold: {fold + 1} ------------------')\n"," train_loader, valid_loader = prepare_loaders(train_df, fold=fold)\n"," model = NewsClassifier(n_classes=13).to(device)\n"," criterion = nn.CrossEntropyLoss()\n"," optimizer = AdamW(model.parameters(), lr=2e-5)\n","\n"," lr_scheduler = get_linear_schedule_with_warmup(\n"," optimizer,\n"," num_warmup_steps=0,\n"," num_training_steps=len(train_loader) * EPOCHS\n"," )\n"," best_acc = 0\n","\n"," for epoch in range(EPOCHS):\n"," print(f'Epoch {epoch + 1}/{EPOCHS}')\n"," print('-' * 30)\n","\n"," train(model, criterion, optimizer, train_loader, lr_scheduler)\n"," val_acc = eval(model, criterion, valid_loader)\n","\n"," if val_acc > best_acc:\n"," torch.save(model.state_dict(), f'phobert_fold{fold + 1}.pth')\n"," best_acc = val_acc\n"," print(f'Best Accuracy for Fold {fold + 1}: {best_acc:.4f}')\n"," print()\n"," print(f'Finished Fold {fold + 1} with Best Accuracy: {best_acc:.4f}')\n"," print('--------------------------------------')\n","\n","\n","total_end_time = time.time()\n","\n","total_duration = total_end_time - total_start_time\n","print(f'Total training time: {total_duration:.2f} seconds')\n"]},{"cell_type":"markdown","metadata":{"id":"qwlF0GZOOl95"},"source":["### Test"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"_SCp99aOOl95"},"outputs":[],"source":["# Function to decode numeric labels into their corresponding text names\n","def decode_labels(labels):\n"," label_map = {\n"," 0: 'Cong nghe', 1: 'Doi song', 2: 'Giai tri', 3: 'Giao duc', 4: 'Khoa hoc',\n"," 5: 'Kinh te', 6: 'Nha dat', 7: 'Phap luat', 8: 'The gioi', 9: 'The thao',\n"," 10: 'Van hoa', 11: 'Xa hoi', 12: 'Xe co'\n"," }\n"," return [label_map[label] for label in labels]\n","\n","# Function to test the model\n","def test(data_loader):\n"," models = []\n"," for fold in range(skf.n_splits):\n"," model = NewsClassifier(n_classes=13)\n"," model.to(device)\n"," model.load_state_dict(torch.load(f'{path}phobert_fold{fold+1}.pth'))\n"," model.eval()\n"," models.append(model)\n","\n"," texts = []\n"," predicts = []\n"," predict_probs = []\n"," real_values = []\n","\n"," for data in data_loader:\n"," text = data['text']\n"," input_ids = data['input_ids'].to(device)\n"," attention_mask = data['attention_masks'].to(device)\n"," targets = data['targets'].to(device)\n","\n"," total_outs = []\n"," for model in models:\n"," with torch.no_grad():\n"," outputs = model(\n"," input_ids=input_ids,\n"," attention_mask=attention_mask\n"," )\n"," total_outs.append(outputs)\n","\n"," # Taking the average of predictions from 5 models\n"," total_outs = torch.stack(total_outs)\n"," _, pred = torch.max(total_outs.mean(0), dim=1)\n"," texts.extend(text)\n"," predicts.extend(pred)\n"," predict_probs.extend(total_outs.mean(0))\n"," real_values.extend(targets)\n","\n"," predicts = torch.stack(predicts).cpu().numpy()\n"," predict_probs = torch.stack(predict_probs).cpu().numpy()\n"," real_values = torch.stack(real_values).cpu().numpy()\n","\n"," # Decode numeric labels into text labels\n"," decoded_real_values = decode_labels(real_values)\n"," decoded_predicts = decode_labels(predicts)\n","\n"," # Generate classification report\n"," report = classification_report(decoded_real_values, decoded_predicts, output_dict=True)\n","\n"," # Convert to DataFrame\n"," df_report = pd.DataFrame(report).transpose()\n","\n"," return df_report, real_values, predicts"]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":81,"referenced_widgets":["df3cbd2c4920474bad84ec99eda619c0","85effd0ef07646eebdd4f3f11600d7a9","20b90fbcd3e847e9ba32754533006005","17b8728e992b4492b462fc43e818aeb0","3f8b887fb7734f5c8f6039635e56bd1e","fdebb25d116543588433193b23754ee3","19bbaf575d524985a9500f970cc26ae5","fd5ffeec51ad471880fe0727281064bb","33f3c907f87d457f9668159dc751474d","5a1c448958014944975d9aa23ea6e731","c861c40db6514b1c8edeef8807893947","2596e14f143b4b84993ee6c11f2a70c1","e851ec492a5f4e55aabab75132452164","9806430ca6034e7993a4a10db28688a8","244f1f1d736748fab96fca0e84953a81","82c39e0cbb9349e085d2cd661fb653f5","baefa134fff04b9f930d9970256453b7","ba6cdff06d4f4813989193ada3e9efb6","e0a021045cb74752b3d4766d26226cc0","25086f17dc4747e38effd3c43acff41f","37c0a263d21046758c860ab98154dc0a","43291225a3294dd0a90d480264337e72"]},"id":"p7laJqSVOl96","outputId":"6739c89c-841d-4339-d811-c26e36516e6b"},"outputs":[{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"df3cbd2c4920474bad84ec99eda619c0","version_major":2,"version_minor":0},"text/plain":["config.json: 0%| | 0.00/916 [00:00, ?B/s]"]},"metadata":{},"output_type":"display_data"},{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"2596e14f143b4b84993ee6c11f2a70c1","version_major":2,"version_minor":0},"text/plain":["pytorch_model.bin: 0%| | 0.00/637M [00:00, ?B/s]"]},"metadata":{},"output_type":"display_data"}],"source":["# Create dataloader for test set\n","test_dataset = NewsDataset(test_df, tokenizer, max_len)\n","test_loader = DataLoader(test_dataset, batch_size=16, shuffle=True, num_workers=2)\n","df_report, real_values, predicts = test(test_loader)"]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":551},"id":"3PfIY0jkg16s","outputId":"7ea15cdf-afb6-4e0b-9be3-c91653646a30"},"outputs":[{"data":{"application/vnd.google.colaboratory.intrinsic+json":{"summary":"{\n \"name\": \"df_report\",\n \"rows\": 16,\n \"fields\": [\n {\n \"column\": \"precision\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.03124322214613196,\n \"min\": 0.8457655636567583,\n \"max\": 0.9698681732580038,\n \"num_unique_values\": 16,\n \"samples\": [\n 0.9432234432234432,\n 0.8915956151035322,\n 0.9139344262295082\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"recall\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.03650359365888964,\n \"min\": 0.8543909348441926,\n \"max\": 0.9826187717265353,\n \"num_unique_values\": 15,\n \"samples\": [\n 0.9826187717265353,\n 0.8543909348441926,\n 0.9426479560707749\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"f1-score\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.0333532065632639,\n \"min\": 0.8500563697857947,\n \"max\": 0.974152785755313,\n \"num_unique_values\": 16,\n \"samples\": [\n 0.9429356118400978,\n 0.8870039382005452,\n 0.8946840521564694\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"support\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 7668.022603263057,\n \"min\": 0.9163972477824753,\n \"max\": 24126.0,\n \"num_unique_values\": 15,\n \"samples\": [\n 1726.0,\n 1765.0,\n 1639.0\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n }\n ]\n}","type":"dataframe","variable_name":"df_report"},"text/html":["\n","