{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "VtuWHaKEQdEq", "outputId": "9f28174a-c296-4af7-a700-e143970403e1" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Requirement already satisfied: transformers in /usr/local/lib/python3.10/dist-packages (4.42.4)\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers) (3.15.4)\n", "Requirement already satisfied: huggingface-hub<1.0,>=0.23.2 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.23.5)\n", "Requirement already satisfied: numpy<2.0,>=1.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (1.26.4)\n", "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers) (24.1)\n", "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (6.0.2)\n", "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (2024.5.15)\n", "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers) (2.32.3)\n", "Requirement already satisfied: safetensors>=0.4.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.4.4)\n", "Requirement already satisfied: tokenizers<0.20,>=0.19 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.19.1)\n", "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers) (4.66.5)\n", "Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.23.2->transformers) (2024.6.1)\n", "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.23.2->transformers) (4.12.2)\n", "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.3.2)\n", "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.7)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2.0.7)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2024.7.4)\n", "Collecting datasets\n", " Downloading datasets-2.21.0-py3-none-any.whl.metadata (21 kB)\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from datasets) (3.15.4)\n", "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from datasets) (1.26.4)\n", "Collecting pyarrow>=15.0.0 (from datasets)\n", " Downloading pyarrow-17.0.0-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (3.3 kB)\n", "Collecting dill<0.3.9,>=0.3.0 (from datasets)\n", " Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)\n", "Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets) (2.1.4)\n", "Requirement already satisfied: requests>=2.32.2 in /usr/local/lib/python3.10/dist-packages (from datasets) (2.32.3)\n", "Requirement already satisfied: tqdm>=4.66.3 in /usr/local/lib/python3.10/dist-packages (from datasets) (4.66.5)\n", "Collecting xxhash (from datasets)\n", " Downloading xxhash-3.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)\n", 
"Collecting multiprocess (from datasets)\n", " Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)\n", "Requirement already satisfied: fsspec<=2024.6.1,>=2023.1.0 in /usr/local/lib/python3.10/dist-packages (from fsspec[http]<=2024.6.1,>=2023.1.0->datasets) (2024.6.1)\n", "Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets) (3.10.2)\n", "Requirement already satisfied: huggingface-hub>=0.21.2 in /usr/local/lib/python3.10/dist-packages (from datasets) (0.23.5)\n", "Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from datasets) (24.1)\n", "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from datasets) (6.0.2)\n", "Requirement already satisfied: aiohappyeyeballs>=2.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (2.3.5)\n", "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.3.1)\n", "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (24.2.0)\n", "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.4.1)\n", "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (6.0.5)\n", "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.9.4)\n", "Requirement already satisfied: async-timeout<5.0,>=4.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (4.0.3)\n", "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.21.2->datasets) (4.12.2)\n", "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) (3.3.2)\n", "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) (3.7)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) (2.0.7)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) (2024.7.4)\n", "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2.8.2)\n", "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2024.1)\n", "Requirement already satisfied: tzdata>=2022.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2024.1)\n", "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.2->pandas->datasets) (1.16.0)\n", "Downloading datasets-2.21.0-py3-none-any.whl (527 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m527.3/527.3 kB\u001b[0m \u001b[31m13.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m116.3/116.3 kB\u001b[0m \u001b[31m9.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hDownloading pyarrow-17.0.0-cp310-cp310-manylinux_2_28_x86_64.whl (39.9 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m 
\u001b[32m39.9/39.9 MB\u001b[0m \u001b[31m19.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hDownloading multiprocess-0.70.16-py310-none-any.whl (134 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.8/134.8 kB\u001b[0m \u001b[31m6.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hDownloading xxhash-3.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (194 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m194.1/194.1 kB\u001b[0m \u001b[31m15.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hInstalling collected packages: xxhash, pyarrow, dill, multiprocess, datasets\n", " Attempting uninstall: pyarrow\n", " Found existing installation: pyarrow 14.0.2\n", " Uninstalling pyarrow-14.0.2:\n", " Successfully uninstalled pyarrow-14.0.2\n", "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n", "cudf-cu12 24.4.1 requires pyarrow<15.0.0a0,>=14.0.1, but you have pyarrow 17.0.0 which is incompatible.\n", "ibis-framework 8.0.0 requires pyarrow<16,>=2, but you have pyarrow 17.0.0 which is incompatible.\u001b[0m\u001b[31m\n", "\u001b[0mSuccessfully installed datasets-2.21.0 dill-0.3.8 multiprocess-0.70.16 pyarrow-17.0.0 xxhash-3.4.1\n", "Requirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (2.3.1+cu121)\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch) (3.15.4)\n", "Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch) (4.12.2)\n", "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch) (1.13.1)\n", "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch) (3.3)\n", "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch) (3.1.4)\n", "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch) (2024.6.1)\n", "Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch)\n", " Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)\n", "Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch)\n", " Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)\n", "Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch)\n", " Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)\n", "Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch)\n", " Using cached nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)\n", "Collecting nvidia-cublas-cu12==12.1.3.1 (from torch)\n", " Using cached nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)\n", "Collecting nvidia-cufft-cu12==11.0.2.54 (from torch)\n", " Using cached nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)\n", "Collecting nvidia-curand-cu12==10.3.2.106 (from torch)\n", " Using cached nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)\n", "Collecting nvidia-cusolver-cu12==11.4.5.107 (from torch)\n", " Using cached nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)\n", "Collecting nvidia-cusparse-cu12==12.1.0.106 (from torch)\n", " Using cached 
nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)\n", "Collecting nvidia-nccl-cu12==2.20.5 (from torch)\n", " Using cached nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_x86_64.whl.metadata (1.8 kB)\n", "Collecting nvidia-nvtx-cu12==12.1.105 (from torch)\n", " Using cached nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.7 kB)\n", "Requirement already satisfied: triton==2.3.1 in /usr/local/lib/python3.10/dist-packages (from torch) (2.3.1)\n", "Collecting nvidia-nvjitlink-cu12 (from nvidia-cusolver-cu12==11.4.5.107->torch)\n", " Using cached nvidia_nvjitlink_cu12-12.6.20-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n", "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch) (2.1.5)\n", "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy->torch) (1.3.0)\n", "Using cached nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl (410.6 MB)\n", "Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)\n", "Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)\n", "Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)\n", "Using cached nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl (731.7 MB)\n", "Using cached nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl (121.6 MB)\n", "Using cached nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl (56.5 MB)\n", "Using cached nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl (124.2 MB)\n", "Using cached nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl (196.0 MB)\n", "Using cached nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_x86_64.whl (176.2 MB)\n", "Using cached nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (99 kB)\n", "Using cached nvidia_nvjitlink_cu12-12.6.20-py3-none-manylinux2014_x86_64.whl (19.7 MB)\n", "Installing collected packages: nvidia-nvtx-cu12, nvidia-nvjitlink-cu12, nvidia-nccl-cu12, nvidia-curand-cu12, nvidia-cufft-cu12, nvidia-cuda-runtime-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, nvidia-cusparse-cu12, nvidia-cudnn-cu12, nvidia-cusolver-cu12\n", "Successfully installed nvidia-cublas-cu12-12.1.3.1 nvidia-cuda-cupti-cu12-12.1.105 nvidia-cuda-nvrtc-cu12-12.1.105 nvidia-cuda-runtime-cu12-12.1.105 nvidia-cudnn-cu12-8.9.2.26 nvidia-cufft-cu12-11.0.2.54 nvidia-curand-cu12-10.3.2.106 nvidia-cusolver-cu12-11.4.5.107 nvidia-cusparse-cu12-12.1.0.106 nvidia-nccl-cu12-2.20.5 nvidia-nvjitlink-cu12-12.6.20 nvidia-nvtx-cu12-12.1.105\n" ] } ], "source": [ "!pip install transformers\n", "!pip install datasets\n", "!pip install torch\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "1vXYwybdRDD_", "outputId": "19a76d2e-3b8e-4aba-a6a1-2e410b5806bb" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Requirement already satisfied: transformers in /usr/local/lib/python3.10/dist-packages (4.42.4)\n", "Requirement already satisfied: datasets in /usr/local/lib/python3.10/dist-packages (2.21.0)\n", "Requirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (2.3.1+cu121)\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers) (3.15.4)\n", "Requirement already satisfied: huggingface-hub<1.0,>=0.23.2 in 
/usr/local/lib/python3.10/dist-packages (from transformers) (0.23.5)\n", "Requirement already satisfied: numpy<2.0,>=1.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (1.26.4)\n", "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers) (24.1)\n", "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (6.0.2)\n", "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (2024.5.15)\n", "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers) (2.32.3)\n", "Requirement already satisfied: safetensors>=0.4.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.4.4)\n", "Requirement already satisfied: tokenizers<0.20,>=0.19 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.19.1)\n", "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers) (4.66.5)\n", "Requirement already satisfied: pyarrow>=15.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (17.0.0)\n", "Requirement already satisfied: dill<0.3.9,>=0.3.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (0.3.8)\n", "Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets) (2.1.4)\n", "Requirement already satisfied: xxhash in /usr/local/lib/python3.10/dist-packages (from datasets) (3.4.1)\n", "Requirement already satisfied: multiprocess in /usr/local/lib/python3.10/dist-packages (from datasets) (0.70.16)\n", "Requirement already satisfied: fsspec<=2024.6.1,>=2023.1.0 in /usr/local/lib/python3.10/dist-packages (from fsspec[http]<=2024.6.1,>=2023.1.0->datasets) (2024.6.1)\n", "Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets) (3.10.2)\n", "Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch) (4.12.2)\n", "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch) (1.13.1)\n", "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch) (3.3)\n", "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch) (3.1.4)\n", "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch) (12.1.105)\n", "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch) (12.1.105)\n", "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch) (12.1.105)\n", "Requirement already satisfied: nvidia-cudnn-cu12==8.9.2.26 in /usr/local/lib/python3.10/dist-packages (from torch) (8.9.2.26)\n", "Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /usr/local/lib/python3.10/dist-packages (from torch) (12.1.3.1)\n", "Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /usr/local/lib/python3.10/dist-packages (from torch) (11.0.2.54)\n", "Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /usr/local/lib/python3.10/dist-packages (from torch) (10.3.2.106)\n", "Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /usr/local/lib/python3.10/dist-packages (from torch) (11.4.5.107)\n", "Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in 
/usr/local/lib/python3.10/dist-packages (from torch) (12.1.0.106)\n", "Requirement already satisfied: nvidia-nccl-cu12==2.20.5 in /usr/local/lib/python3.10/dist-packages (from torch) (2.20.5)\n", "Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch) (12.1.105)\n", "Requirement already satisfied: triton==2.3.1 in /usr/local/lib/python3.10/dist-packages (from torch) (2.3.1)\n", "Requirement already satisfied: nvidia-nvjitlink-cu12 in /usr/local/lib/python3.10/dist-packages (from nvidia-cusolver-cu12==11.4.5.107->torch) (12.6.20)\n", "Requirement already satisfied: aiohappyeyeballs>=2.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (2.3.5)\n", "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.3.1)\n", "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (24.2.0)\n", "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.4.1)\n", "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (6.0.5)\n", "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.9.4)\n", "Requirement already satisfied: async-timeout<5.0,>=4.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (4.0.3)\n", "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.3.2)\n", "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.7)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2.0.7)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2024.7.4)\n", "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch) (2.1.5)\n", "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2.8.2)\n", "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2024.1)\n", "Requirement already satisfied: tzdata>=2022.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2024.1)\n", "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy->torch) (1.3.0)\n", "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.2->pandas->datasets) (1.16.0)\n" ] } ], "source": [ "!pip install transformers datasets torch\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "5tiNer2wNmKd", "outputId": "9cb6a807-b20b-4807-c64e-71bd89b09f3a" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Requirement already satisfied: kaggle in /usr/local/lib/python3.10/dist-packages (1.6.17)\n", "Requirement already satisfied: six>=1.10 in /usr/local/lib/python3.10/dist-packages (from kaggle) (1.16.0)\n", "Requirement already satisfied: certifi>=2023.7.22 in /usr/local/lib/python3.10/dist-packages (from kaggle) (2024.7.4)\n", "Requirement already satisfied: python-dateutil in 
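{ "cell_type": "markdown", "metadata": {}, "source": [ "The install above upgrades `pyarrow` past what Colab's preinstalled `cudf-cu12` and `ibis-framework` pin, hence the dependency-conflict error. Neither package is used in this notebook, so the conflict is harmless here, but it is cheap insurance to confirm which versions actually import and that a GPU is visible. The cell below is a minimal sanity-check sketch and assumes a Colab GPU runtime is attached." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Minimal environment sanity check (assumes a Colab GPU runtime).\n", "import pyarrow, torch, transformers, datasets\n", "\n", "print('pyarrow:', pyarrow.__version__)  # 17.x after the upgrade above\n", "print('transformers:', transformers.__version__)\n", "print('datasets:', datasets.__version__)\n", "print('CUDA available:', torch.cuda.is_available())\n" ] },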
{ "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "5tiNer2wNmKd", "outputId": "9cb6a807-b20b-4807-c64e-71bd89b09f3a" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "cp: cannot stat 'kaggle.json': No such file or directory\n", "chmod: cannot access '/root/.kaggle/kaggle.json': No such file or directory\n", "Dataset URL: https://www.kaggle.com/datasets/rmisra/imdb-spoiler-dataset\n", "License(s): Attribution 4.0 International (CC BY 4.0)\n", "Downloading imdb-spoiler-dataset.zip to /content\n", "100% 331M/331M [00:02<00:00, 138MB/s]\n", "Archive:  imdb-spoiler-dataset.zip\n", "  inflating: IMDB_movie_details.json  \n", "  inflating: IMDB_reviews.json  \n", "IMDB_movie_details.json  IMDB_reviews.json  imdb-spoiler-dataset.zip  sample_data\n" ] } ], "source": [ "!pip install kaggle\n", "\n", "# Upload your kaggle.json (Kaggle > Settings > Create New API Token) to the\n", "# working directory first, otherwise the cp below fails with 'cannot stat'.\n", "!mkdir -p ~/.kaggle\n", "!cp kaggle.json ~/.kaggle/\n", "!chmod 600 ~/.kaggle/kaggle.json\n", "\n", "!kaggle datasets download -d rmisra/imdb-spoiler-dataset\n", "\n", "!unzip imdb-spoiler-dataset.zip\n", "\n", "!ls" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "uXLdvqgVOL99", "outputId": "d6231cf9-2178-4c04-b880-53ef40842eea" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "    movie_id                                       plot_summary  duration  ...\n", "0  tt0105112  Former CIA analyst, Jack Ryan is in England wi...  1h 57min  ...\n", "...\n", "        review_date   movie_id    user_id  is_spoiler  ...\n", "0  10 February 2006  tt0111161  ur1898687        True  ...\n", "...\n" ] } ], "source": [ "import pandas as pd\n", "\n", "# The files are stored in JSON Lines format, hence lines=True\n", "movie_details = pd.read_json('IMDB_movie_details.json', lines=True)\n", "reviews = pd.read_json('IMDB_reviews.json', lines=True)\n", "\n", "# Check the first few entries to ensure the data loaded correctly\n", "print(movie_details.head())\n", "print(reviews.head())\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "44kpxJ0iOjLz" }, "outputs": [], "source": [ "# Drop rows where 'plot_synopsis' or 'plot_summary' is missing\n", "movie_details.dropna(subset=['plot_synopsis', 'plot_summary'], inplace=True)\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "qaAPgonAPRGx" }, "outputs": [], "source": [ "from sklearn.model_selection import train_test_split\n", "\n", "# First split: 70% train, 30% temp (to become the validation and test sets)\n", "train_data, temp_data = train_test_split(movie_details, test_size=0.3, random_state=42)\n", "\n", "# Second split: divide the temp set evenly into 15% validation and 15% test\n", "validation_data, test_data = train_test_split(temp_data, test_size=0.5, random_state=42)\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "gewp_maVVU6g", "outputId": "281cc145-2564-466c-b526-1ad7f58957c1" }, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "(1100, 7)" ] }, "metadata": {}, "execution_count": 7 } ], "source": [ "train_data.shape" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "6Y-syQ3ZPVDR", "outputId": "544030c8-9652-4ec6-96f3-b3dbd03a82a4" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "count     1572.000000\n", "mean      1439.085242\n", "std       1496.392929\n", "min          0.000000\n", "25%        493.750000\n", "50%       1073.500000\n", "75%       1920.000000\n", "max      11396.000000\n", "Name: synopsis_length, dtype: float64\n" ] } ], "source": [ "# Example exploration: word-count distribution of plot_synopsis\n", "movie_details['synopsis_length'] = movie_details['plot_synopsis'].apply(lambda x: len(x.split()))\n", "print(movie_details['synopsis_length'].describe())\n" ] },
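{ "cell_type": "markdown", "metadata": {}, "source": [ "The two-stage split above should leave roughly 70/15/15 of the 1,572 titles in train/validation/test. A quick check makes that intent explicit and catches an accidental re-run with different parameters; this is a small sketch using the DataFrames defined above." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Verify the 70/15/15 split proportions (uses the DataFrames created above).\n", "n_total = len(train_data) + len(validation_data) + len(test_data)\n", "for name, df in [('train', train_data), ('validation', validation_data), ('test', test_data)]:\n", "    print(f'{name}: {len(df)} rows ({len(df) / n_total:.1%})')\n" ] },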
"6lRJv7A7PX8x" }, "outputs": [], "source": [ "train_data.to_json('/content/train_data.json', orient='records', lines=True)\n", "validation_data.to_json('/content/val_data.json', orient='records', lines=True)\n", "test_data.to_json('/content/test_data.json', orient='records', lines=True)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 113, "referenced_widgets": [ "7e20cd41d4d7401284eae05380531c86", "1f7a1424873e47f9882ce9aadd360be9", "3ba09083cc8f403c8231adb1dddeffcd", "bf60ee240ad141e5b7c14a31caf459f6", "46535b3c982d45aeadcaa00f2e024ad1", "e19c05d1e02546ddad481a7d5a5917cd", "ab60068105e149eb9318b3a8b3662075", "6ec9693194af46ffbd8310d83917803b", "5f88faea4b894a6e8cb44ce00084e2a2", "eeecd29be4284e0cb9b3ddf9d83de029", "43830d9d88dd47bfb4ce27020a2be9f3", "481fda13f40a4ea19853b2b0cf0fa99e", "a16b6c84f2c74226a46b1e9724965cca", "2e55a4a146454d2faacfbfe352ecd917", "c8b01fd70f364abdb2da85e135dfc5d2", "974c5fa2dc1c47ae95b8debec2640825", "e98b2b076dba4a8ab1942dc0e23984d8", "2feee173fbe44801bbc9dc05952e1038", "ffb79557f11948d69d0b67f2a71121e8", "91ef540aa06b43d28c6bd371665dcb68", "e76db42f1f9e4edda5ff102e8cdd7363", "9330377c22224352b99bba7cd4d635b2", "3de6d461f83a4c5f8924eb02b9f13d84", "d4adf688973c48a693a7a4aee8e0b23e", "756a26793f594047b99f5642a59b6dd9", "924058735d454e47a76f36c558a7ba15", "edfc526ced1b4f799f4609a4289e369b", "b2ebfcadca8f4017ac1606caeba957ed", "ec119eb9de4742d180db9862b5369b70", "8a32a28a12464bb683c62d10a9663f1c", "5317f74358914b1783f2bb812006cd8d", "9293e46df3f14358b87af38eeef7f109", "24c79207d28d495eaeabd38db4f6007b" ] }, "id": "yj5INXEXPhsr", "outputId": "c1d93ee4-7dbc-45ed-c12b-f8cca44398ed" }, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "Generating train split: 0 examples [00:00, ? examples/s]" ], "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, "model_id": "7e20cd41d4d7401284eae05380531c86" } }, "metadata": {} }, { "output_type": "display_data", "data": { "text/plain": [ "Generating test split: 0 examples [00:00, ? examples/s]" ], "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, "model_id": "481fda13f40a4ea19853b2b0cf0fa99e" } }, "metadata": {} }, { "output_type": "display_data", "data": { "text/plain": [ "Generating val split: 0 examples [00:00, ? 
examples/s]" ], "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, "model_id": "3de6d461f83a4c5f8924eb02b9f13d84" } }, "metadata": {} } ], "source": [ "from datasets import load_dataset\n", "\n", "data_files = {\n", " 'train': '/content/train_data.json',\n", " 'test': '/content/test_data.json',\n", " 'val': '/content/val_data.json'\n", "\n", "\n", "}\n", "\n", "# Loading the dataset from the JSON files\n", "dataset = load_dataset('json', data_files=data_files, split={'train': 'train', 'test': 'test', 'val': 'val'})\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 428, "referenced_widgets": [ "57fcbc7a221e4441a1908c756e3b2ca0", "b733e988bb9542759d8d834ff6b1d5f3", "a7e4c7ab193e4300acc49050c2004b1b", "ae8003c16eca4d59962125d599f79f31", "ef6d361ff5e34115b55443e499685937", "f73ebe2ab28242efbb4450a6e329086a", "2181e80c121d4da8bd422e1d2be8d5b8", "beac2ecdbef440e799f4a1364f7e8856", "bd75bd7a93084f39b1cbe9d2b789c3fc", "43a6048276494abb814a078fe05c439f", "80cdedbc980d490fae438a6fb92e0e66", "f542f154196d4ef2bf67ffbdc488b8ad", "83772cf85c604db19217acb073ff4f78", "d42835fe1966412cbaa21333ecd361cc", "222c0334528541d5b4470801a75b8d0e", "4107eaad71e04e968c9d4123b7e70312", "5274a20615a949c68eee15c066f845c1", "5193d8adb6124549a4d9471c2475f1ec", "11dc881125294407b118b31d01b92d70", "14435f717b2c43e59861a7b404d11b53", "77f710ea78b04ba3831b6f95f156fe31", "9ede3c68111047d484edfbd67c82c50d", "1d5b88a9e165423a815479dd4138bf2b", "6b93a696e83e401787113078df83baa7", "654e846c19a441d18ea81a12350c9c51", "0f751527e05e4293be2b2312cc52d754", "a54d5c1d98cb46a68a38da67d9118658", "9c374f90a9b74b208c54f4958d40a12c", "21c8478187654918ab9c689fb31ee16e", "63ede400669443c78551c04691c422e6", "fc64e38fac154cc3b48989a126d73699", "f285da915ded4694aa9b41c03f3bd289", "be05ebd4527c4bb08da9e0ff2e7b892e" ] }, "id": "GOe61CxiYkgv", "outputId": "52095422-f4f9-4053-caf3-40ac92c981ae" }, "outputs": [ { "output_type": "stream", "name": "stderr", "text": [ "/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_token.py:89: UserWarning: \n", "The secret `HF_TOKEN` does not exist in your Colab secrets.\n", "To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n", "You will be able to reuse this secret in all of your notebooks.\n", "Please note that authentication is recommended but still optional to access public models or datasets.\n", " warnings.warn(\n" ] }, { "output_type": "display_data", "data": { "text/plain": [ "tokenizer_config.json: 0%| | 0.00/2.32k [00:00. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. 
This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565\n", "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n" ] }, { "output_type": "stream", "name": "stdout", "text": [ "count 1572.000000\n", "mean 2116.767812\n", "std 2197.103130\n", "min 0.000000\n", "25% 729.000000\n", "50% 1581.000000\n", "75% 2814.750000\n", "max 18103.000000\n", "Name: token_length, dtype: float64\n" ] } ], "source": [ "from transformers import T5Tokenizer\n", "\n", "\n", "import pandas as pd\n", "\n", "movie_details = pd.read_json('/content/IMDB_movie_details.json', lines=True)\n", "\n", "# Initialize the tokenizer\n", "tokenizer = T5Tokenizer.from_pretrained('t5-small')\n", "\n", "# Function to calculate token length\n", "def calculate_token_length(text):\n", " return len(tokenizer.tokenize(text))\n", "\n", "# Apply the function to the plot_synopsis column\n", "movie_details['token_length'] = movie_details['plot_synopsis'].apply(calculate_token_length)\n", "\n", "# Display statistics about the token lengths\n", "stats = movie_details['token_length'].describe()\n", "print(stats)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "QHSO1Tz4Z5kc", "outputId": "d64d5657-b3c9-42b1-e5bb-c47067593d30" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "count 1572.000000\n", "mean 151.191476\n", "std 60.718672\n", "min 20.000000\n", "25% 103.000000\n", "50% 142.000000\n", "75% 195.250000\n", "max 315.000000\n", "Name: token_length, dtype: float64\n" ] } ], "source": [ "# Apply the function to the plot_synopsis column\n", "movie_details['token_length'] = movie_details['plot_summary'].apply(calculate_token_length)\n", "\n", "# Display statistics about the token lengths\n", "stats = movie_details['token_length'].describe()\n", "print(stats)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "VS4ROiIKnvLy" }, "outputs": [], "source": [ "device = 'cuda'\n" ] }, { "cell_type": "markdown", "metadata": { "id": "dTcbyQKVQFFA" }, "source": [ "##**Preprocess the Data**" ] }, { "cell_type": "code", "source": [ "import re\n", "import torch\n", "from transformers import LEDForConditionalGeneration, LEDTokenizer\n", "\n", "# Load tokenizer and model\n", "tokenizer = LEDTokenizer.from_pretrained('allenai/led-base-16384')\n", "model = LEDForConditionalGeneration.from_pretrained('allenai/led-base-16384')\n", "\n", "model = model.to(device)\n", "\n", "# Function to normalize text\n", "def normalize_text(text):\n", " text = text.lower() # Lowercase the text\n", " text = re.sub(r'\\s+', ' ', text).strip() # Remove extra spaces and newlines\n", " text = re.sub(r'[^\\w\\s]', '', text) # Remove non-alphanumeric characters\n", " return text\n", "\n", "# Preprocess function with normalization\n", "def preprocess_function(examples):\n", " # Normalize the plot_synopsis and plot_summary\n", " inputs = [\"summarize: \" + normalize_text(doc) for doc in examples[\"plot_synopsis\"]]\n", " model_inputs = tokenizer(inputs, max_length=3000, truncation=True, padding=\"max_length\", return_tensors=\"pt\")\n", "\n", " # Normalize labels (plot_summary)\n", " with tokenizer.as_target_tokenizer():\n", " labels = tokenizer([normalize_text(doc) for doc in examples[\"plot_summary\"]], max_length=1024, truncation=True, padding=\"max_length\", 
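{ "cell_type": "markdown", "metadata": {}, "source": [ "The length statistics above motivate the truncation limits used below: synopses run to roughly 18k T5 tokens while summaries top out at 315. Before committing to an input limit it is worth knowing how much of the corpus a given cutoff covers; the sketch below checks a few candidate limits, including the 3000-token input limit chosen in the preprocessing cell that follows." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Fraction of plot synopses that fit under candidate input-token limits.\n", "# Recomputes synopsis lengths, since token_length currently holds summary lengths.\n", "synopsis_tokens = movie_details['plot_synopsis'].apply(calculate_token_length)\n", "for limit in (1024, 2048, 3000, 4096):\n", "    covered = (synopsis_tokens <= limit).mean()\n", "    print(f'max_length={limit}: {covered:.1%} of synopses fit without truncation')\n" ] },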
{ "cell_type": "code", "source": [ "import re\n", "import torch\n", "from transformers import LEDForConditionalGeneration, LEDTokenizer\n", "\n", "# Load tokenizer and model\n", "tokenizer = LEDTokenizer.from_pretrained('allenai/led-base-16384')\n", "model = LEDForConditionalGeneration.from_pretrained('allenai/led-base-16384')\n", "\n", "model = model.to(device)\n", "\n", "# Function to normalize text\n", "def normalize_text(text):\n", "    text = text.lower()                       # Lowercase the text\n", "    text = re.sub(r'\\s+', ' ', text).strip()  # Collapse extra spaces and newlines\n", "    text = re.sub(r'[^\\w\\s]', '', text)       # Remove non-alphanumeric characters\n", "    return text\n", "\n", "# Preprocess function with normalization\n", "def preprocess_function(examples):\n", "    # Normalize the plot_synopsis inputs\n", "    inputs = ['summarize: ' + normalize_text(doc) for doc in examples['plot_synopsis']]\n", "    model_inputs = tokenizer(inputs, max_length=3000, truncation=True, padding='max_length')\n", "\n", "    # Normalize and tokenize the labels (plot_summary); text_target replaces\n", "    # the deprecated as_target_tokenizer() context manager\n", "    labels = tokenizer(text_target=[normalize_text(doc) for doc in examples['plot_summary']], max_length=1024, truncation=True, padding='max_length')\n", "\n", "    # Replace padding token ids in the labels with -100 so the loss ignores them\n", "    labels['input_ids'] = [\n", "        [(label if label != tokenizer.pad_token_id else -100) for label in lab]\n", "        for lab in labels['input_ids']\n", "    ]\n", "\n", "    model_inputs['labels'] = labels['input_ids']\n", "    return model_inputs" ], "metadata": { "id": "eBxLkLaUHJLI" }, "execution_count": null, "outputs": [] },
{ "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "NUPy3CuAQBGR", "outputId": "d98eb8e3-3eb7-42ef-86b5-f75c0723d804" }, "outputs": [], "source": [ "# Tokenize every split with the preprocessing function above\n", "tokenized_datasets = dataset.map(preprocess_function, batched=True)\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "source": [ "from transformers import Trainer, TrainingArguments\n", "\n", "training_args = TrainingArguments(\n", "    output_dir='./results',\n", "    eval_strategy='epoch',\n", "    save_strategy='epoch',\n", "    learning_rate=2e-5,\n", "    per_device_train_batch_size=1,  # LED is memory-intensive\n", "    per_device_eval_batch_size=1,\n", "    weight_decay=0.01,\n", "    save_total_limit=3,\n", "    num_train_epochs=3,\n", "    report_to='none',\n", "    logging_dir='./logs',\n", "    logging_steps=500,\n", "    load_best_model_at_end=True,\n", ")\n", "\n", "trainer = Trainer(\n", "    model=model,\n", "    args=training_args,\n", "    train_dataset=tokenized_datasets['train'],\n", "    eval_dataset=tokenized_datasets['val'],\n", "    tokenizer=tokenizer,\n", ")\n", "\n", "trainer.train()\n" ], "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "[3300/3300 1:23:26, Epoch 3/3]\n", "Epoch | Training Loss | Validation Loss\n", "    1 |      2.925300 |        2.873685\n", "    2 |      2.409600 |        2.883492\n", "    3 |      2.165300 |        2.920285\n" ] }, "metadata": {} }, { "output_type": "execute_result", "data": { "text/plain": [ "TrainOutput(global_step=3300, training_loss=2.429696747750947, metrics={'train_runtime': 5007.437, 'train_samples_per_second': 0.659, 'train_steps_per_second': 0.659, 'total_flos': 6526373990400000.0, 'train_loss': 2.429696747750947, 'epoch': 3.0})" ] }, "metadata": {}, "execution_count": 21 } ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "gt4LnnFMTTsV", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "e513a352-837f-4d06-c635-fac1596c342e" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Mounted at /content/drive\n" ] } ], "source": [ "from google.colab import drive\n", "drive.mount('/content/drive')\n", "\n", "# Persist the fine-tuned model to Drive so it survives the Colab session\n", "model_save_path = '/content/drive/MyDrive/summary_generation'  # Replace with your desired path in Drive\n", "trainer.save_model(model_save_path)\n" ] },
{ "cell_type": "code", "source": [ "# Save a second copy of the checkpoint under a distinct name\n", "model_save_path = '/content/drive/MyDrive/summary_generation_Led_3'  # Replace with your desired path in Drive\n", "trainer.save_model(model_save_path)" ], "metadata": { "id": "kFuNP98wGaa2" }, "execution_count": null, "outputs": [] },
{ "cell_type": "code", "source": [ "from google.colab import drive\n", "drive.mount('/content/drive')\n", "\n", "# Load the trained model back from Google Drive\n", "model_save_path = '/content/drive/MyDrive/summary_generation'  # Replace with your saved path\n", "model = LEDForConditionalGeneration.from_pretrained(model_save_path)\n", "\n", "# Ensure the model is on the right device (GPU if available)\n", "model = model.to(device)\n" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "yopwoSEJgNkT", "outputId": "e5cf03ef-5d32-4447-9b6d-237c90b4b3c0" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Mounted at /content/drive\n" ] } ] },
{ "cell_type": "code", "source": [ "from transformers import Trainer, TrainingArguments\n", "\n", "# Reload the model from the saved checkpoint in Google Drive\n", "model_save_path = '/content/drive/MyDrive/summary_generation'  # Your saved path\n", "model = LEDForConditionalGeneration.from_pretrained(model_save_path)\n", "model = model.to(device)\n", "\n", "# Training arguments for one additional epoch\n", "training_args = TrainingArguments(\n", "    output_dir='./results',\n", "    eval_strategy='epoch',\n", "    save_strategy='epoch',  # must match eval_strategy when load_best_model_at_end=True\n", "    learning_rate=2e-5,\n", "    per_device_train_batch_size=1,  # LED is memory-intensive\n", "    per_device_eval_batch_size=1,\n", "    weight_decay=0.01,\n", "    save_total_limit=3,  # Only keep the last 3 checkpoints\n", "    num_train_epochs=1,  # Continue training for 1 additional epoch\n", "    report_to='none',\n", "    logging_dir='./logs',\n", "    logging_steps=500,\n", "    load_best_model_at_end=True,\n", ")\n", "\n", "# Re-initialize the optimizer; the scheduler slot is None so the Trainer\n", "# creates its default schedule\n", "optimizer = torch.optim.AdamW(model.parameters(), lr=training_args.learning_rate)\n", "\n", "trainer = Trainer(\n", "    model=model,\n", "    args=training_args,\n", "    train_dataset=tokenized_datasets['train'],\n", "    eval_dataset=tokenized_datasets['val'],\n", "    tokenizer=tokenizer,\n", "    optimizers=(optimizer, None),\n", ")\n", "\n", "# Resume training for another epoch\n", "trainer.train()\n", "\n", "# Save the updated checkpoint\n", "trainer.save_model('/content/drive/MyDrive/summary_generation_Led_4')\n" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 144 }, "id": "DnDXbarJm1YG", "outputId": "90c5c5a8-bf06-4631-e868-17c6e7dc2f4a" }, "execution_count": null, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "[1100/1100 26:32, Epoch 1/1]\n", "Epoch | Training Loss | Validation Loss\n", "    1 |      2.140500 |        2.964694\n" ] }, "metadata": {} },
" ] }, "metadata": {} }, { "output_type": "stream", "name": "stderr", "text": [ "There were missing keys in the checkpoint model loaded: ['led.encoder.embed_tokens.weight', 'led.decoder.embed_tokens.weight', 'lm_head.weight'].\n" ] } ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "7S7rMmsRzUDK", "outputId": "207767ac-c55d-433b-989f-6f9f15098d16" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Collecting rouge_score\n", " Downloading rouge_score-0.1.2.tar.gz (17 kB)\n", " Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n", "Requirement already satisfied: absl-py in /usr/local/lib/python3.10/dist-packages (from rouge_score) (1.4.0)\n", "Requirement already satisfied: nltk in /usr/local/lib/python3.10/dist-packages (from rouge_score) (3.8.1)\n", "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from rouge_score) (1.26.4)\n", "Requirement already satisfied: six>=1.14.0 in /usr/local/lib/python3.10/dist-packages (from rouge_score) (1.16.0)\n", "Requirement already satisfied: click in /usr/local/lib/python3.10/dist-packages (from nltk->rouge_score) (8.1.7)\n", "Requirement already satisfied: joblib in /usr/local/lib/python3.10/dist-packages (from nltk->rouge_score) (1.4.2)\n", "Requirement already satisfied: regex>=2021.8.3 in /usr/local/lib/python3.10/dist-packages (from nltk->rouge_score) (2024.5.15)\n", "Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from nltk->rouge_score) (4.66.5)\n", "Building wheels for collected packages: rouge_score\n", " Building wheel for rouge_score (setup.py) ... \u001b[?25l\u001b[?25hdone\n", " Created wheel for rouge_score: filename=rouge_score-0.1.2-py3-none-any.whl size=24935 sha256=99631a27f739be8626953540b4e8f0dccd306bb718d08af5f91069d975cf0f26\n", " Stored in directory: /root/.cache/pip/wheels/5f/dd/89/461065a73be61a532ff8599a28e9beef17985c9e9c31e541b4\n", "Successfully built rouge_score\n", "Installing collected packages: rouge_score\n", "Successfully installed rouge_score-0.1.2\n" ] } ], "source": [ "pip install rouge_score" ] }, { "cell_type": "code", "source": [ "pip install nltk\n" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "OFJmH2KUtdod", "outputId": "8b551282-a6ac-481f-f5fe-d6c702e2ba70" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Requirement already satisfied: nltk in /usr/local/lib/python3.10/dist-packages (3.8.1)\n", "Requirement already satisfied: click in /usr/local/lib/python3.10/dist-packages (from nltk) (8.1.7)\n", "Requirement already satisfied: joblib in /usr/local/lib/python3.10/dist-packages (from nltk) (1.4.2)\n", "Requirement already satisfied: regex>=2021.8.3 in /usr/local/lib/python3.10/dist-packages (from nltk) (2024.5.15)\n", "Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from nltk) (4.66.5)\n" ] } ] }, { "cell_type": "code", "source": [ "import torch, gc\n", "\n", "gc.collect()\n", "torch.cuda.empty_cache()\n" ], "metadata": { "id": "SEGn9IQ4to0c" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "import nltk\n", "\n", "nltk.download('punkt')\n" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "FliiA5j_-bK6", "outputId": "a1784ce0-e243-4f06-b2dc-554300ec8726" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stderr", "text": [ "[nltk_data] Downloading package 
{ "cell_type": "code", "source": [ "from datasets import load_metric\n", "import nltk\n", "\n", "# load_metric is deprecated in datasets 2.x; the standalone `evaluate`\n", "# library (evaluate.load('rouge')) is its drop-in replacement.\n", "metric_rouge = load_metric('rouge')\n", "\n", "def generate_summary(batch):\n", "    # Apply the same 'summarize: ' prefix and normalization used in training,\n", "    # so evaluation inputs match the fine-tuning distribution\n", "    texts = ['summarize: ' + normalize_text(doc) for doc in batch['plot_synopsis']]\n", "    inputs = tokenizer(texts, max_length=3000, truncation=True, padding='max_length', return_tensors='pt')\n", "    inputs = inputs.to(device)\n", "    outputs = model.generate(inputs['input_ids'], attention_mask=inputs['attention_mask'], max_length=315, min_length=20, length_penalty=2.0, num_beams=4, early_stopping=True)\n", "\n", "    batch['pred_summary'] = tokenizer.batch_decode(outputs, skip_special_tokens=True)\n", "    return batch\n", "\n", "results = dataset['test'].map(generate_summary, batched=True, batch_size=8)\n", "rouge_score = metric_rouge.compute(predictions=results['pred_summary'], references=results['plot_summary'])\n", "print('ROUGE scores:')\n", "print(rouge_score)" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "I1wd38l3sYAy", "outputId": "4a6e0d20-74b4-4916-edf3-c7b0ef4b5117" }, "execution_count": null, "outputs": [ { "data": { "text/plain": [ "Map:   0%|          | 0/236 [00:00