File size: 14,223 Bytes
c7c2507 |
1 |
# %% [markdown]
# # Multi-label emotion classification for Russian text
#
# Fine-tunes `cointegrated/rubert-tiny2` on the `ru-go-emotions-raw` CSV with one
# independent (sigmoid) output per emotion column, saves the model and tokenizer,
# runs a few smoke-test predictions and zips the saved model for download.
#
# The `# %%` markers below correspond to the cells of the original notebook;
# the script is meant to be run top to bottom on a fresh kernel.

# %% Environment switches (set before the training libraries are imported)
import os
import warnings

os.environ["WANDB_DISABLED"] = "true"
warnings.filterwarnings("ignore")

# %% Imports (all in one place)
import subprocess

import numpy as np
import pandas as pd
import torch
from IPython.display import FileLink, display
from sklearn.metrics import f1_score
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)

# %% Configuration: everything a reader might want to tune
DATA_PATH = "/kaggle/input/ru-go-emotions-raw/ru-go-emotions-raw.csv"
TEXT_COLUMN = "ru_text"
# Label columns are taken positionally: every column from this index onwards.
# NOTE(review): this presumes the first 10 columns are text/metadata - confirm
# against the CSV header if the dataset version changes.
FIRST_EMOTION_COLUMN = 10
MODEL_NAME = "cointegrated/rubert-tiny2"
MAX_LENGTH = 128          # tokens per example after truncation/padding
VAL_FRACTION = 0.1
SEED = 42
MODEL_DIR = "./emotion_model"
WORKING_DIR = "/kaggle/working/"

# English label name -> Russian display name, used only to format the keys of
# the dictionary returned by predict_emotions().
EMOTION_TRANSLATIONS = {
    "admiration": "восхищение",
    "amusement": "веселье",
    "anger": "злость",
    "annoyance": "раздражение",
    "approval": "одобрение",
    "caring": "забота",
    "confusion": "непонимание",
    "curiosity": "любопытство",
    "desire": "желание",
    "disappointment": "разочарование",
    "disapproval": "неодобрение",
    "disgust": "отвращение",
    "embarrassment": "смущение",
    "excitement": "возбуждение",
    "fear": "страх",
    "gratitude": "признательность",
    "grief": "горе",
    "joy": "радость",
    "love": "любовь",
    "nervousness": "нервозность",
    "optimism": "оптимизм",
    "pride": "гордость",
    "realization": "осознание",
    "relief": "облегчение",
    "remorse": "раскаяние",
    "sadness": "грусть",
    "surprise": "удивление",
    "neutral": "нейтральность",
}

EXAMPLE_TEXTS = [
    "Как же я рад!",
    "Как же я не рад!",
    "Всё очень плохо!",
    "ого! вот это да!",
]

# %% Load data
# Keep the raw frame under its own name so this cell is idempotent on re-run.
data_raw = pd.read_csv(DATA_PATH)
emotion_columns = list(data_raw.columns[FIRST_EMOTION_COLUMN:])
assert emotion_columns, "no label columns found after FIRST_EMOTION_COLUMN"

# Rows without a Russian text cannot be tokenized.
data = data_raw.dropna(subset=[TEXT_COLUMN])

# %% Train / validation split
# NOTE(review): if the raw CSV holds several rows per text (e.g. one per
# annotator), a row-level random split can put the same text into both train
# and validation and inflate the F1 - confirm, and consider a grouped split.
train_texts, val_texts, train_labels, val_labels = train_test_split(
    data[TEXT_COLUMN].tolist(),
    data[emotion_columns].values,
    test_size=VAL_FRACTION,
    random_state=SEED,
)
assert len(train_texts) + len(val_texts) == len(data)

# %% Tokenization
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)


def tokenize_function(texts):
    """Tokenize a list of texts, truncating and padding each to MAX_LENGTH tokens."""
    return tokenizer(texts, padding="max_length", truncation=True, max_length=MAX_LENGTH)


train_encodings = tokenize_function(train_texts)
val_encodings = tokenize_function(val_texts)


# %% Dataset wrapper
class EmotionDataset(Dataset):
    """Pairs pre-computed tokenizer encodings with their multi-hot label rows."""

    def __init__(self, encodings, labels):
        # encodings: mapping of field name -> per-example sequences
        # labels: indexable collection with one label row per example
        self.encodings = encodings
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        # Labels are cast to float before being handed to the model.
        item["labels"] = torch.tensor(self.labels[idx]).float()
        return item


train_dataset = EmotionDataset(train_encodings, train_labels)
val_dataset = EmotionDataset(val_encodings, val_labels)

# %% Model
# problem_type is stated explicitly instead of relying on the library inferring
# it from the label dtype; id2label/label2id store the emotion names in the
# saved config so the exported model is self-describing.
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels=len(emotion_columns),
    problem_type="multi_label_classification",
    id2label={index: name for index, name in enumerate(emotion_columns)},
    label2id={name: index for index, name in enumerate(emotion_columns)},
)

# %% Training configuration
training_args = TrainingArguments(
    output_dir="./results",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    logging_dir="./logs",
    logging_steps=100,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=1,
    weight_decay=0.01,
    learning_rate=1e-5,
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model="f1",
    greater_is_better=True,
)


# %% Metric
def compute_metrics(eval_pred):
    """Return the weighted F1 of the thresholded (sigmoid > 0.5) predictions.

    Args:
        eval_pred: a ``(logits, labels)`` pair as passed by ``Trainer``.

    Returns:
        ``{"f1": <weighted F1 score>}``.
    """
    logits, labels = eval_pred
    preds = (torch.sigmoid(torch.as_tensor(logits)) > 0.5).int().numpy()
    # Labels arrive as floats; cast to int so both arrays are 0/1 indicators.
    # zero_division=0 yields the same score as the default but without a
    # warning for labels that are never predicted.
    f1 = f1_score(
        np.asarray(labels).astype(int),
        preds,
        average="weighted",
        zero_division=0,
    )
    return {"f1": f1}


# %% Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

# %% Train
trainer.train()

# %% Save model + tokenizer (must run after training)
model.save_pretrained(MODEL_DIR)
tokenizer.save_pretrained(MODEL_DIR)


# %% Inference helper
def predict_emotions(
    text,
    model,
    tokenizer,
    emotion_columns,
    device="cuda" if torch.cuda.is_available() else "cpu",
    threshold=0.1,
):
    """Score a single text and return the emotions whose probability exceeds ``threshold``.

    Args:
        text: one input string (a batch of several texts is rejected).
        model: sequence-classification model with one logit per emotion.
        tokenizer: tokenizer matching ``model``.
        emotion_columns: label names, in the same order as the model's logits.
        device: device the model and inputs are moved to.
        threshold: minimum sigmoid probability for an emotion to be reported.

    Returns:
        dict mapping ``"<emotion> (<Russian name>)"`` to its probability as a
        Python float, sorted from most to least probable. Labels without a
        known translation use the label name itself inside the parentheses.

    Raises:
        ValueError: if the tokenized input contains more than one example.
    """
    model.to(device)
    model.eval()
    inputs = tokenizer(
        text,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=MAX_LENGTH,
    ).to(device)
    with torch.no_grad():
        logits = model(**inputs).logits

    if logits.shape[0] != 1:
        raise ValueError(
            f"predict_emotions expects a single text, got a batch of {logits.shape[0]}"
        )
    # Index the single row instead of squeeze(): squeeze() would also drop the
    # label axis when there is exactly one label.
    probabilities = torch.sigmoid(logits)[0].cpu().numpy()

    predictions = {
        # .get() keeps labels that have no translation instead of raising KeyError.
        f"{emotion} ({EMOTION_TRANSLATIONS.get(emotion, emotion)})": float(prob)
        for emotion, prob in zip(emotion_columns, probabilities)
        if prob > threshold
    }

    return dict(sorted(predictions.items(), key=lambda item: item[1], reverse=True))


# %% Smoke-test predictions
for example_text in EXAMPLE_TEXTS:
    predictions = predict_emotions(example_text, model, tokenizer, emotion_columns)
    print("Emotions:", predictions)

# %% Save via the trainer as well (writes to training_args.output_dir)
trainer.save_model()


# %% Download helper
def download_file(path, download_file_name):
    """Zip ``path`` recursively into the working directory and display a download link.

    Args:
        path: file or directory to archive.
        download_file_name: archive name without the ``.zip`` extension.

    Side effects:
        Changes the current working directory to ``WORKING_DIR`` (the displayed
        link is relative to it) and creates/updates the zip archive there.
    """
    os.chdir(WORKING_DIR)
    zip_name = os.path.join(WORKING_DIR, f"{download_file_name}.zip")
    try:
        # Argument list instead of a shell string: paths containing spaces or
        # shell metacharacters are passed through verbatim.
        result = subprocess.run(
            ["zip", "-r", zip_name, path],
            capture_output=True,
            text=True,
        )
    except FileNotFoundError as error:
        # Raised when the `zip` executable itself is not installed.
        print("Unable to run zip command!")
        print(error)
        return
    if result.returncode != 0:
        print("Unable to run zip command!")
        print(result.stderr)
        return
    display(FileLink(f"{download_file_name}.zip"))


# %% Package the saved model
download_file('/kaggle/working/emotion_model', 'Emotions_model-04')