{ "cells": [ { "cell_type": "code", "source": [ "# Credit\n", "## Base code https://devocean.sk.com/blog/techBoardDetail.do?ID=165703\n", "## Updated by Yunho Maeng, yunhomaeng@yonsei.ac.kr" ], "metadata": { "id": "ABqAyuxVJ81s" }, "id": "ABqAyuxVJ81s", "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "id": "9e02c8d1-e653-41a5-a94f-e44c176dbcc5", "metadata": { "id": "9e02c8d1-e653-41a5-a94f-e44c176dbcc5" }, "source": [ "# 1. 개발 환경 설정" ] }, { "cell_type": "markdown", "id": "9fa242e1-7689-4397-b410-d550e79246c3", "metadata": { "id": "9fa242e1-7689-4397-b410-d550e79246c3" }, "source": [ "### 1.1 필수 라이브러리 설치하기" ] }, { "cell_type": "code", "execution_count": 2, "id": "3d405d7a-f2c9-4416-bf88-880812a2b8b5", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "3d405d7a-f2c9-4416-bf88-880812a2b8b5", "outputId": "05d4afab-2d0c-49c3-f85f-b00644997ac6" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m8.5/8.5 MB\u001b[0m \u001b[31m75.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.6/3.6 MB\u001b[0m \u001b[31m52.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m510.5/510.5 kB\u001b[0m \u001b[31m2.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m116.3/116.3 kB\u001b[0m \u001b[31m15.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m194.1/194.1 kB\u001b[0m \u001b[31m27.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.8/134.8 kB\u001b[0m \u001b[31m22.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K 
\u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m105.0/105.0 MB\u001b[0m \u001b[31m16.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m190.9/190.9 kB\u001b[0m \u001b[31m5.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m314.1/314.1 kB\u001b[0m \u001b[31m24.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m21.3/21.3 MB\u001b[0m \u001b[31m7.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m155.3/155.3 kB\u001b[0m \u001b[31m4.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m103.4/103.4 kB\u001b[0m \u001b[31m10.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m280.0/280.0 kB\u001b[0m \u001b[31m8.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25h" ] } ], "source": [ "!pip3 install -q -U transformers==4.38.2\n", "!pip3 install -q -U datasets==2.18.0\n", "!pip3 install -q -U bitsandbytes==0.42.0\n", "!pip3 install -q -U peft==0.9.0\n", "!pip3 install -q -U trl==0.7.11\n", "!pip3 install -q -U accelerate==0.27.2" ] }, { "cell_type": "code", "source": [ "# 구글 드라이브 마운트\n", "\n", "from google.colab import drive\n", "import shutil\n", "import os\n", "\n", "drive.mount('/content/drive')" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "SjWkITPKOX0a", "outputId": "4909a8e4-52dd-4e2c-da82-2a338f4342e0" }, "id": "SjWkITPKOX0a", "execution_count": 41, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount(\"/content/drive\", 
force_remount=True).\n" ] } ] }, { "cell_type": "markdown", "id": "13fa79b6-4720-43d1-baae-41d834011c2c", "metadata": { "id": "13fa79b6-4720-43d1-baae-41d834011c2c" }, "source": [ "### 1.2 Import modules" ] }, { "cell_type": "code", "execution_count": 3, "id": "1d7a17e3-b9a1-4a46-8f6e-7710a37a93bf", "metadata": { "id": "1d7a17e3-b9a1-4a46-8f6e-7710a37a93bf" }, "outputs": [], "source": [ "import torch\n", "from datasets import Dataset, load_dataset\n", "from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline, TrainingArguments\n", "from peft import LoraConfig, PeftModel\n", "from trl import SFTTrainer" ] }, { "cell_type": "markdown", "id": "5b7f30d7-bfdf-49c5-8c2c-701ad6f15a80", "metadata": { "id": "5b7f30d7-bfdf-49c5-8c2c-701ad6f15a80" }, "source": [ "### 1.3 Huggingface 로그인" ] }, { "cell_type": "code", "execution_count": 4, "id": "6aa22976-7bdf-479d-8c5c-8ab890be537f", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 145, "referenced_widgets": [ "d78fb194ea584519899fca75316935a5", "eeaef7c987eb495ebbd676284b84bc31", "b1bf1acfbeec4d27b2718f155a4f60ce", "0b92582f5a0e43c58919395a3790e986", "caac082c36f0432f85cb061f0e4cf203", "60f18007db4c40cabbfe41b2a5c90b9c", "2ebba5dee6e24c2f8a5e7e91111adf26", "76861f23355c422183168aa4a5660b9b", "9b4b9867824443179df0472b564e3a97", "931c6d74a0564062a6420af5298b1422", "2772fecde87842c2a80ef4678fcbd4c5", "7d06a99eb059484896f93c19610feb8d", "35a8980320c24fd59db1bf2a2df284f8", "ce789e054ffb4088be8cc036fcb86c48", "c548f62b7cf64fce8c020b573e2df7b8", "f95a739b5b7c4b948b2875c349eeb3cc", "5daa3d4168854c6ebf941921a5e4f936", "2e0908343d34433297e9e2bbc2436fe9", "4cdeb606d9464690b8f4e5a1441410dc", "dd6eca03bb9d4abd99ee35abfdd19a94", "8ba757205acd45acaddd823734082247", "d14dbf68e27944f6bdd24e7782d4ef29", "347b4975b84d428b937a4dc68597193c", "3c6879dbb4ed48a8977253bd04b9eea0", "d2120616526743d6920a2c29013205b7", "2c7b2b17b2d14e10a72563e312057c8d", "d1d6fd83db1d44578d0ea4ef8e9e0a3c", 
"ac8a46db70004de8beec5005a05259d0", "fa4e79867ed34425a728767878347304", "3acecabb06cd41b5b5da81745ef5d4f4", "0e9990a90c75425b84e473a24e97c98a", "2826f483dd9c40f5bc47702b4d19a94f" ] }, "id": "6aa22976-7bdf-479d-8c5c-8ab890be537f", "outputId": "d92eff69-6d8a-419b-bc9b-505bc0bc3194" }, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "VBox(children=(HTML(value='
Step | \n", "Training Loss | \n", "
---|---|
100 | \n", "2.439600 | \n", "
200 | \n", "2.108500 | \n", "
300 | \n", "2.051300 | \n", "
400 | \n", "1.939200 | \n", "
500 | \n", "1.931500 | \n", "
600 | \n", "1.925600 | \n", "
700 | \n", "1.859700 | \n", "
800 | \n", "1.852200 | \n", "
900 | \n", "1.860200 | \n", "
1000 | \n", "1.864000 | \n", "
1100 | \n", "1.817400 | \n", "
1200 | \n", "1.776700 | \n", "
1300 | \n", "1.731000 | \n", "
1400 | \n", "1.721500 | \n", "
1500 | \n", "1.732700 | \n", "
1600 | \n", "1.779000 | \n", "
1700 | \n", "1.699800 | \n", "
1800 | \n", "1.664400 | \n", "
1900 | \n", "1.654700 | \n", "
2000 | \n", "1.695800 | \n", "
2100 | \n", "1.717600 | \n", "
2200 | \n", "1.697600 | \n", "
2300 | \n", "1.677300 | \n", "
2400 | \n", "1.627900 | \n", "
2500 | \n", "1.662800 | \n", "
2600 | \n", "1.630000 | \n", "
2700 | \n", "1.642800 | \n", "
2800 | \n", "1.673700 | \n", "
2900 | \n", "1.632200 | \n", "
3000 | \n", "1.619400 | \n", "
"
]
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"/usr/local/lib/python3.10/dist-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
" warnings.warn(\n",
"/usr/local/lib/python3.10/dist-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
" warnings.warn(\n",
"/usr/local/lib/python3.10/dist-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
" warnings.warn(\n",
"/usr/local/lib/python3.10/dist-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
" warnings.warn(\n",
"/usr/local/lib/python3.10/dist-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
" warnings.warn(\n"
]
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"TrainOutput(global_step=3000, training_loss=1.7895308176676432, metrics={'train_runtime': 2534.1773, 'train_samples_per_second': 4.735, 'train_steps_per_second': 1.184, 'total_flos': 6.742903609432474e+16, 'train_loss': 1.7895308176676432, 'epoch': 0.54})"
]
},
"metadata": {},
"execution_count": 21
}
],
"source": [
"trainer.train()"
]
},
{
"cell_type": "markdown",
"id": "dca74e51-15ec-403a-90f1-4b7eeb2c723b",
"metadata": {
"id": "dca74e51-15ec-403a-90f1-4b7eeb2c723b"
},
"source": [
"### 4.4 Finetuned Model 저장"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "f2bba87d-d95c-4a57-9eb1-c02d81ad7bfb",
"metadata": {
"id": "f2bba87d-d95c-4a57-9eb1-c02d81ad7bfb",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "034a5c88-6651-467e-d2da-f23c796f6b57"
},
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"/usr/local/lib/python3.10/dist-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
" warnings.warn(\n"
]
}
],
"source": [
"# Write the trained PEFT (LoRA) adapter to a local directory; merging into the base model happens in a later cell.\n",
"ADAPTER_MODEL = \"lora_adapter\"\n",
"trainer.model.save_pretrained(ADAPTER_MODEL)"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "6a9fcda0-1d7a-4443-9b1c-7d45490daafb",
"metadata": {
"id": "6a9fcda0-1d7a-4443-9b1c-7d45490daafb",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "f77eea88-1a14-4b94-caf8-2a2376fa0fa7"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"adapter_model.safetensors\t Size: 29450584 bytes\n",
"README.md\t Size: 5091 bytes\n",
"adapter_config.json\t Size: 689 bytes\n"
]
}
],
"source": [
"import os\n",
"\n",
"\n",
"def list_directory_contents(directory):\n",
"    \"\"\"Print the name and size in bytes of every entry directly inside `directory`.\n",
"\n",
"    Entries are printed in name order so repeated runs produce the same listing.\n",
"    A missing path, or a path that is not a directory, is reported with a message\n",
"    instead of raising.\n",
"    \"\"\"\n",
"    try:\n",
"        with os.scandir(directory) as entries:\n",
"            # os.scandir yields entries in arbitrary order; sort for a stable listing.\n",
"            for entry in sorted(entries, key=lambda e: e.name):\n",
"                info = entry.stat()\n",
"                print(f\"{entry.name}\\t Size: {info.st_size} bytes\")\n",
"    except FileNotFoundError:\n",
"        print(f\"Directory '{directory}' not found.\")\n",
"    except NotADirectoryError:\n",
"        print(f\"'{directory}' is not a directory.\")\n",
"\n",
"\n",
"ADAPTER_MODEL = \"lora_adapter\"\n",
"list_directory_contents(ADAPTER_MODEL)"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "a9a2a6d7-ece4-472a-981f-fb6599d1d307",
"metadata": {
"id": "a9a2a6d7-ece4-472a-981f-fb6599d1d307",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 49,
"referenced_widgets": [
"d83f27176aac4ce1a0ce373c0521f448",
"3432e44ff4e84e8db6439d46a299b1dd",
"74fb934e195b4535be099894f1d1a016",
"973c9e71beb045f0b30739a3f4da06fd",
"87e66fe990c841b0af70e3072b6d531f",
"10c474866a3c46989753a3b594873ff4",
"430d8d46284a422a887de79a179d2ffd",
"b7fc7f2d31b54b8f8341abb83a1f6821",
"856f1c4dae8b4566bc09679c40148e0a",
"76c2cc35d7934c1c8d69e3d62cb32b1f",
"72e2e556c0854cbab198ae1d2a527dab"
]
},
"outputId": "df2c540e-4ff1-48c7-ded8-c436f78d16e1"
},
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"Loading checkpoint shards: 0%| | 0/2 [00:00, ?it/s]"
],
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "d83f27176aac4ce1a0ce373c0521f448"
}
},
"metadata": {}
}
],
"source": [
"# Reload the base model in float16 and attach the saved adapter to it.\n",
"model = AutoModelForCausalLM.from_pretrained(\n",
"    BASE_MODEL,\n",
"    device_map='auto',\n",
"    torch_dtype=torch.float16,\n",
")\n",
"model = PeftModel.from_pretrained(\n",
"    model,\n",
"    ADAPTER_MODEL,\n",
"    device_map='auto',\n",
"    torch_dtype=torch.float16,\n",
")\n",
"\n",
"# Merge the adapter into the base model and save the merged result to disk.\n",
"model = model.merge_and_unload()\n",
"model.save_pretrained('gemma-2b-it-sum-ko')"
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "1a764bbc-069d-400c-bca4-09e799bf0fb0",
"metadata": {
"id": "1a764bbc-069d-400c-bca4-09e799bf0fb0",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "374635a5-0e3c-4c2f-fba6-5ee48f082621"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"config.json\t Size: 662 bytes\n",
"model.safetensors.index.json\t Size: 13489 bytes\n",
"generation_config.json\t Size: 132 bytes\n",
"model-00002-of-00002.safetensors\t Size: 67121600 bytes\n",
"model-00001-of-00002.safetensors\t Size: 4945242104 bytes\n"
]
}
],
"source": [
"# List the merged model files with list_directory_contents, which is already defined in an earlier cell of this section.\n",
"directory = \"./gemma-2b-it-sum-ko\"\n",
"list_directory_contents(directory)"
]
},
{
"cell_type": "markdown",
"id": "84f2c237-71f4-47c2-bad4-181dadb6cc98",
"metadata": {
"id": "84f2c237-71f4-47c2-bad4-181dadb6cc98"
},
"source": [
"# 5. Gemma 한국어 요약 모델 추론"
]
},
{
"cell_type": "markdown",
"id": "8587dfc7-cf7c-4072-a8f7-6ceb1e90a532",
"metadata": {
"id": "8587dfc7-cf7c-4072-a8f7-6ceb1e90a532"
},
"source": [
"#### 주의: 마찬가지로 Colab GPU 메모리 한계로 학습 시 사용했던 메모리를 비워 줘야 추론을 진행할 수 있습니다.<br>notebook 런타임 세션을 재시작한 후 1번과 2번의 2.1 항목까지 다시 실행하여 로드한 다음 아래 과정을 진행합니다."
]
},
{
"cell_type": "code",
"execution_count": 31,
"id": "906ed4dd-270f-4000-84de-ede6885c0be5",
"metadata": {
"id": "906ed4dd-270f-4000-84de-ede6885c0be5",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "9fc93e12-c20b-4df0-f642-10d8ff4ef737"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Sat Jul 13 08:11:51 2024 \n",
"+---------------------------------------------------------------------------------------+\n",
"| NVIDIA-SMI 535.104.05 Driver Version: 535.104.05 CUDA Version: 12.2 |\n",
"|-----------------------------------------+----------------------+----------------------+\n",
"| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
"| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |\n",
"| | | MIG M. |\n",
"|=========================================+======================+======================|\n",
"| 0 NVIDIA A100-SXM4-40GB Off | 00000000:00:04.0 Off | 0 |\n",
"| N/A 39C P0 53W / 400W | 18449MiB / 40960MiB | 0% Default |\n",
"| | | Disabled |\n",
"+-----------------------------------------+----------------------+----------------------+\n",
" \n",
"+---------------------------------------------------------------------------------------+\n",
"| Processes: |\n",
"| GPU GI CI PID Type Process name GPU Memory |\n",
"| ID ID Usage |\n",
"|=======================================================================================|\n",
"+---------------------------------------------------------------------------------------+\n",
"\n"
]
}
],
"source": [
"import subprocess\n",
"\n",
"\n",
"def check_gpu():\n",
"    \"\"\"Run `nvidia-smi` and print its report.\n",
"\n",
"    stderr is merged into the captured output so error messages from the tool show\n",
"    up in the notebook, and a non-zero exit status is reported explicitly. If the\n",
"    `nvidia-smi` executable is missing, a hint is printed instead of raising.\n",
"    \"\"\"\n",
"    try:\n",
"        result = subprocess.run(\n",
"            ['nvidia-smi'],\n",
"            stdout=subprocess.PIPE,\n",
"            stderr=subprocess.STDOUT,\n",
"        )\n",
"    except FileNotFoundError:\n",
"        print(\"nvidia-smi not found, ensure you have a GPU instance.\")\n",
"        return\n",
"\n",
"    # errors='replace' keeps a stray non-UTF-8 byte from raising while printing.\n",
"    print(result.stdout.decode('utf-8', errors='replace'))\n",
"    if result.returncode != 0:\n",
"        print(f\"nvidia-smi exited with status {result.returncode}.\")\n",
"\n",
"\n",
"check_gpu()"
]
},
{
"cell_type": "markdown",
"id": "78399236-63b5-41af-9cee-a7233e23a9db",
"metadata": {
"id": "78399236-63b5-41af-9cee-a7233e23a9db"
},
"source": [
"### 5.1 Fine-tuned 모델 로드"
]
},
{
"cell_type": "code",
"execution_count": 32,
"id": "76d5ba97-91ca-48c3-b9a2-ba9bea6d7b09",
"metadata": {
"id": "76d5ba97-91ca-48c3-b9a2-ba9bea6d7b09",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 104,
"referenced_widgets": [
"1fae26b44ce144c1afe26fc6d87bc4c9",
"cdb6b8a160974912b7f5646b989d2142",
"cc3f97a2e623494cad8591f63adfde90",
"7f2c674bef754cf6ad867efec44abae4",
"579bb488cea94912a128a53c6103de8f",
"3faa57373d3648b7b30405dfded60a5b",
"5b392e95468441c5b1f7d7696f92cbcb",
"8dc06b4f1b9741c68bf63624476a09f1",
"b1b3817be3e146189548aafd8b908d5f",
"5d54bb59f5f148209ca0c20078f21262",
"3db996ed5cae4350b9115b21a6c040e6"
]
},
"outputId": "cd058a27-4a2d-4b1d-d68e-9b7a41a2e4d3"
},
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"Loading checkpoint shards: 0%| | 0/2 [00:00, ?it/s]"
],
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "1fae26b44ce144c1afe26fc6d87bc4c9"
}
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"/usr/local/lib/python3.10/dist-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
" warnings.warn(\n"
]
}
],
"source": [
"# The merged weights are read from the local directory; the tokenizer is loaded from the base model id.\n",
"BASE_MODEL = \"google/gemma-2b-it\"\n",
"FINETUNE_MODEL = \"./gemma-2b-it-sum-ko\"\n",
"\n",
"finetune_model = AutoModelForCausalLM.from_pretrained(\n",
"    FINETUNE_MODEL,\n",
"    device_map={\"\": 0},\n",
")\n",
"tokenizer = AutoTokenizer.from_pretrained(\n",
"    BASE_MODEL,\n",
"    add_special_tokens=True,\n",
")"
]
},
{
"cell_type": "markdown",
"id": "5c34718c-ce52-4d68-ac8c-c18b6483b15b",
"metadata": {
"id": "5c34718c-ce52-4d68-ac8c-c18b6483b15b"
},
"source": [
"### 5.2 Fine-tuned 모델 추론"
]
},
{
"cell_type": "code",
"execution_count": 33,
"id": "a0f0fc82-abaf-49df-9254-7ccee2e74d96",
"metadata": {
"scrolled": true,
"id": "a0f0fc82-abaf-49df-9254-7ccee2e74d96"
},
"outputs": [],
"source": [
"# Text-generation pipeline over the fine-tuned model, configured with max_new_tokens=512.\n",
"pipe_finetuned = pipeline(\n",
"    \"text-generation\",\n",
"    model=finetune_model,\n",
"    tokenizer=tokenizer,\n",
"    max_new_tokens=512,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 34,
"id": "2f915638-d859-446f-bc78-070650421ece",
"metadata": {
"id": "2f915638-d859-446f-bc78-070650421ece"
},
"outputs": [],
"source": [
"# Pick test document #10. Index the row first, then the column: dataset['test']['document']\n",
"# would load the whole column before indexing (assumes dataset['test'] is a datasets.Dataset).\n",
"doc = dataset['test'][10]['document']"
]
},
{
"cell_type": "code",
"execution_count": 35,
"id": "396788e7-4b80-46d7-980f-38fcb892a94f",
"metadata": {
"id": "396788e7-4b80-46d7-980f-38fcb892a94f"
},
"outputs": [],
"source": [
"# Wrap the document in a single user turn, then render it with the tokenizer's chat template.\n",
"messages = [{\"role\": \"user\", \"content\": f\"다음 글을 요약해주세요:\\n\\n{doc}\"}]\n",
"prompt = pipe_finetuned.tokenizer.apply_chat_template(\n",
"    messages,\n",
"    tokenize=False,\n",
"    add_generation_prompt=True,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 36,
"id": "03f1f711-0ba7-4087-8317-b0e7f4246aee",
"metadata": {
"id": "03f1f711-0ba7-4087-8317-b0e7f4246aee",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "836b4040-4ebd-44e9-ce1a-69c9c2cb1573"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"한국동서발전은 다음달 31일까지 울산시민과 함께하는 생활 속 걷기 챌린지 탄소중립 건강걷기 누비GO 초록발자국 챌린지 를 진행한다고 1일 밝혔으며 울산숲사랑운동과 함께 하는 이번 챌린지는 코로나19로 지친 시민들이 일상 속 걷기 운동을 통해 건강을 증진하고 자가용 대신 대중교통을 이용해 탄소중립 사회를 실현하기 위해 마련됐다.\n"
]
}
],
"source": [
"# Seed torch's RNG so the sampled summary is repeatable when this cell is re-run.\n",
"torch.manual_seed(42)\n",
"\n",
"outputs = pipe_finetuned(\n",
"    prompt,\n",
"    do_sample=True,\n",
"    temperature=0.2,\n",
"    top_k=50,\n",
"    top_p=0.95,\n",
"    add_special_tokens=True,\n",
")\n",
"# Drop the first len(prompt) characters so only the text after the prompt is shown.\n",
"print(outputs[0][\"generated_text\"][len(prompt):])"
]
},
{
"cell_type": "code",
"source": [
"# Folders to back up and the Google Drive directory that receives them.\n",
"folders_to_copy = ['gemma-2b-it-sum-ko', 'lora_adapter', 'outputs']\n",
"destination_base_folder = '/content/drive/My Drive/gemma-2b-finetuning'\n",
"\n",
"\n",
"def copy_folder(src, dst):\n",
"    \"\"\"Recursively copy the directory `src` to `dst`, reporting the outcome on stdout.\n",
"\n",
"    An existing `dst` is merged into (files with the same name are overwritten), so\n",
"    re-running the cell refreshes the backup instead of failing with FileExistsError.\n",
"    Filesystem errors (for example a missing `src`) are printed rather than raised so\n",
"    that the remaining folders are still attempted.\n",
"    \"\"\"\n",
"    try:\n",
"        # dirs_exist_ok requires Python 3.8+.\n",
"        shutil.copytree(src, dst, dirs_exist_ok=True)\n",
"        print(f\"Folder copied to Google Drive at: {dst}\")\n",
"    except OSError as e:\n",
"        # shutil.Error (raised by copytree for per-file failures) is an OSError subclass.\n",
"        print(f\"Error: {e}\")\n",
"\n",
"\n",
"# Make sure the destination root exists, then copy each folder into it.\n",
"os.makedirs(destination_base_folder, exist_ok=True)\n",
"\n",
"for folder in folders_to_copy:\n",
"    source_folder = f'./{folder}'\n",
"    destination_folder = f'{destination_base_folder}/{folder}'\n",
"    copy_folder(source_folder, destination_folder)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "OsYFrH83KmaG",
"outputId": "95152ad2-1e3e-4099-e45c-71f144d193b0"
},
"id": "OsYFrH83KmaG",
"execution_count": 39,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount(\"/content/drive\", force_remount=True).\n",
"Folder copied to Google Drive at: /content/drive/My Drive/gemma-2b-finetuning/gemma-2b-it-sum-ko\n",
"Folder copied to Google Drive at: /content/drive/My Drive/gemma-2b-finetuning/lora_adapter\n",
"Folder copied to Google Drive at: /content/drive/My Drive/gemma-2b-finetuning/outputs\n"
]
}
]
},
{
"cell_type": "code",
"source": [],
"metadata": {
"id": "-jIEbIKeLUdE"
},
"id": "-jIEbIKeLUdE",
"execution_count": null,
"outputs": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
},
"colab": {
"provenance": [],
"machine_shape": "hm",
"gpuType": "A100"
},
"accelerator": "GPU",
"widgets": {
"application/vnd.jupyter.widget-state+json": {
"d78fb194ea584519899fca75316935a5": {
"model_module": "@jupyter-widgets/controls",
"model_name": "VBoxModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "VBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "VBoxView",
"box_style": "",
"children": [
"IPY_MODEL_8ba757205acd45acaddd823734082247",
"IPY_MODEL_d14dbf68e27944f6bdd24e7782d4ef29",
"IPY_MODEL_347b4975b84d428b937a4dc68597193c",
"IPY_MODEL_3c6879dbb4ed48a8977253bd04b9eea0"
],
"layout": "IPY_MODEL_2ebba5dee6e24c2f8a5e7e91111adf26"
}
},
"eeaef7c987eb495ebbd676284b84bc31": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_76861f23355c422183168aa4a5660b9b",
"placeholder": "",
"style": "IPY_MODEL_9b4b9867824443179df0472b564e3a97",
"value": "