{ "cells": [ { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "3155f84dc9cb452f993d5535dc11f344", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Fetching 7 files: 0%| | 0/7 [00:00 6\u001b[0m \u001b[43mlora\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain_model\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 7\u001b[0m \u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 8\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrain_set\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mfine_tune_train.jsonl\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 9\u001b[0m \u001b[43m \u001b[49m\u001b[43mlora_params\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlora_params\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 10\u001b[0m \u001b[43m \u001b[49m\u001b[43mepochs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m3\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# Number of training epochs\u001b[39;49;00m\n\u001b[1;32m 11\u001b[0m \u001b[43m \u001b[49m\u001b[43mbatch_size\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m32\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# Batch size\u001b[39;49;00m\n\u001b[1;32m 12\u001b[0m \u001b[43m \u001b[49m\u001b[43mlearning_rate\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m1e-5\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# Learning rate\u001b[39;49;00m\n\u001b[1;32m 13\u001b[0m \u001b[43m)\u001b[49m\n", "\u001b[0;31mTypeError\u001b[0m: train_model() got an unexpected keyword argument 'lora_params'" ] } ], "source": [ "lora_params = {\n", " 'lora_rank': 8, # Rank of the LoRA adapters\n", " # Add other parameters as needed\n", "}\n", "\n", "lora.train_model(\n", " model=model,\n", " train_set=\"fine_tune_train.jsonl\",\n", " lora_params=lora_params,\n", " epochs=3, # Number of training epochs\n", " batch_size=32, # Batch size\n", " learning_rate=1e-5, # Learning rate\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Fine tuning the model using LORA techinique" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/opt/anaconda3/envs/f1llama/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n", "Fetching 7 files: 100%|██████████| 7/7 [00:00<00:00, 93503.59it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Trainable parameters: 0.085% (6.816M/8030.261M)\n", "Starting training..., iters: 10\n", "Iter 1: Val loss 14.203, Val took 14.567s\n", "Iter 10: Val loss 7.762, Val took 2.556s\n", "Iter 10: Train loss 10.280, Learning Rate 1.000e-05, It/sec 8.496, Tokens/sec 67.967, Trained Tokens 80, Peak mem 9.347 GB\n", "Saved final weights to adapters.safetensors.\n" ] } ], "source": [ "from dataclasses import dataclass\n", "import mlx.optimizers as optim\n", "from mlx_lm import lora\n", "from mlx_lm import load, generate\n", "\n", "# Create a dataclass to convert dictionary to an object\n", "@dataclass\n", "class TrainArgs:\n", " train: bool = False\n", " fine_tune_type: str = 'lora'\n", " seed: int = 0\n", " num_layers: int = 16\n", " batch_size: int = 4\n", " iters: int = 10\n", " val_batches: int = 25\n", " learning_rate: float = 1e-05\n", " steps_per_report: int = 10\n", " steps_per_eval: int = 200\n", " resume_adapter_file: str = None\n", " adapter_path: str = './'\n", " save_every: int = 100\n", " test: bool = False\n", " test_batches: int = 500\n", " max_seq_length: int = 2048\n", " lr_schedule: str = None\n", " lora_parameters: dict = None\n", " grad_checkpoint: bool = False\n", "\n", "# Create an instance of TrainArgs\n", "train_args = TrainArgs(lora_parameters={'rank': 16, 'alpha': 16, 'dropout': 0.0, 'scale': 10.0})\n", "\n", "model, tokenizer = load(\"mlx-community/Meta-Llama-3-8B-Instruct-8bit\")\n", "\n", "# optimizer = optim.Adam(learning_rate=1e-3)\n", "\n", "lora.train_model(\n", " args=train_args,\n", " model=model, \n", " tokenizer=tokenizer,\n", " train_set=\"fine_tune_train.jsonl\",\n", " valid_set=\"fine_tune_test.jsonl\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Integrating fine tuned LORA weights with actual model weights and upload to hugging face" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "ename": "SyntaxError", "evalue": "invalid decimal literal (3768910078.py, line 2)", "output_type": "error", "traceback": [ "\u001b[0;36m Cell \u001b[0;32mIn[7], line 2\u001b[0;36m\u001b[0m\n\u001b[0;31m --model mlx-community/Meta-Llama-3-8B-Instruct-8bit \\\u001b[0m\n\u001b[0m ^\u001b[0m\n\u001b[0;31mSyntaxError\u001b[0m\u001b[0;31m:\u001b[0m invalid decimal literal\n" ] } ], "source": [ "# In terminal\n", "\n", "mlx_lm.fuse \\\n", " --model mlx-community/Meta-Llama-3-8B-Instruct-8bit \\\n", " --upload-repo Rafii/f1llama \\\n", " --hf-path mlx-community/Meta-Llama-3-8B-Instruct-8bit \\\n", " --adapter-path /Users/rafa/f1llama/ \\\n", " --save-path ./fine_tuned/" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Using my model from hugging face" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "f214f5af0a304c2e8f09cc7fa20d424b", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Fetching 7 files: 0%| | 0/7 [00:00<|start_header_id|>user<|end_header_id|>\n", "\n", "hello<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n", "\n", "\n", "Hello! It's nice to meet you. Is there something I can help you with, or would you like to chat?\n", "==========\n", "Prompt: 11 tokens, 0.961 tokens-per-sec\n", "Generation: 26 tokens, 8.244 tokens-per-sec\n", "Peak memory: 9.066 GB\n" ] } ], "source": [ "from mlx_lm import load, generate\n", "\n", "model, tokenizer = load(\"Rafii/f1llama\")\n", "\n", "prompt=\"hello\"\n", "\n", "if hasattr(tokenizer, \"apply_chat_template\") and tokenizer.chat_template is not None:\n", " messages = [{\"role\": \"user\", \"content\": prompt}]\n", " prompt = tokenizer.apply_chat_template(\n", " messages, tokenize=False, add_generation_prompt=True\n", " )\n", "\n", "response = generate(model, tokenizer, prompt=prompt, verbose=True)\n" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "==========\n", "Prompt: How many r in strawberry\n", "?\n", "Answer: There are 2 r's in the word \"strawberry\"....more\n", "How many s in strawberry?\n", "Answer: There is 1 s in the word \"strawberry\"....more\n", "How many t in strawberry?\n", "Answer: There is 1 t in the word \"strawberry\"....more\n", "How many w in strawberry?\n", "Answer: There is 1 w in the word \"strawberry\"....more\n", "How many a in strawberry?\n", "Answer:\n", "==========\n", "Prompt: 5 tokens, 34.517 tokens-per-sec\n", "Generation: 100 tokens, 11.103 tokens-per-sec\n", "Peak memory: 9.066 GB\n" ] }, { "data": { "text/plain": [ "'?\\nAnswer: There are 2 r\\'s in the word \"strawberry\"....more\\nHow many s in strawberry?\\nAnswer: There is 1 s in the word \"strawberry\"....more\\nHow many t in strawberry?\\nAnswer: There is 1 t in the word \"strawberry\"....more\\nHow many w in strawberry?\\nAnswer: There is 1 w in the word \"strawberry\"....more\\nHow many a in strawberry?\\nAnswer:'" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "generate(model, tokenizer, prompt = \"How many r in strawberry\", verbose=True)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "32fb768e7a124f24a8c81dbe9676e637", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Fetching 7 files: 0%| | 0/7 [00:00<|start_header_id|>user<|end_header_id|>\n", "\n", "hello<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n", "\n", "\n", "Hello! It's nice to meet you. Is there something I can help you with, or would you like to chat?\n", "==========\n", "Prompt: 11 tokens, 3.317 tokens-per-sec\n", "Generation: 26 tokens, 10.258 tokens-per-sec\n", "Peak memory: 18.050 GB\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "2024-12-08 20:54:55.898 WARNING streamlit.runtime.scriptrunner_utils.script_run_context: Thread 'MainThread': missing ScriptRunContext! This warning can be ignored when running in bare mode.\n", "2024-12-08 20:54:55.973 \n", " \u001b[33m\u001b[1mWarning:\u001b[0m to view this Streamlit app on a browser, run it with the following\n", " command:\n", "\n", " streamlit run /opt/anaconda3/envs/f1llama/lib/python3.10/site-packages/ipykernel_launcher.py [ARGUMENTS]\n", "2024-12-08 20:54:55.974 Thread 'MainThread': missing ScriptRunContext! This warning can be ignored when running in bare mode.\n", "2024-12-08 20:54:55.975 Thread 'MainThread': missing ScriptRunContext! This warning can be ignored when running in bare mode.\n", "2024-12-08 20:54:55.975 Thread 'MainThread': missing ScriptRunContext! This warning can be ignored when running in bare mode.\n", "2024-12-08 20:54:55.976 Thread 'MainThread': missing ScriptRunContext! This warning can be ignored when running in bare mode.\n", "2024-12-08 20:54:55.977 Thread 'MainThread': missing ScriptRunContext! This warning can be ignored when running in bare mode.\n", "2024-12-08 20:54:55.977 Session state does not function when running a script without `streamlit run`\n", "2024-12-08 20:54:55.978 Thread 'MainThread': missing ScriptRunContext! This warning can be ignored when running in bare mode.\n", "2024-12-08 20:54:55.978 Thread 'MainThread': missing ScriptRunContext! This warning can be ignored when running in bare mode.\n", "2024-12-08 20:54:55.979 Thread 'MainThread': missing ScriptRunContext! This warning can be ignored when running in bare mode.\n", "2024-12-08 20:54:55.979 Thread 'MainThread': missing ScriptRunContext! This warning can be ignored when running in bare mode.\n", "2024-12-08 20:54:55.979 Thread 'MainThread': missing ScriptRunContext! This warning can be ignored when running in bare mode.\n", "2024-12-08 20:54:55.979 Thread 'MainThread': missing ScriptRunContext! This warning can be ignored when running in bare mode.\n", "2024-12-08 20:54:55.980 Thread 'MainThread': missing ScriptRunContext! This warning can be ignored when running in bare mode.\n" ] } ], "source": [ "import streamlit as st\n", "from transformers import AutoModelForSequenceClassification, AutoTokenizer\n", "\n", "# Load your model and tokenizer\n", "\n", "model, tokenizer = load(\"Rafii/f1llama\")\n", "\n", "prompt=\"hello\"\n", "\n", "if hasattr(tokenizer, \"apply_chat_template\") and tokenizer.chat_template is not None:\n", " messages = [{\"role\": \"user\", \"content\": prompt}]\n", " prompt = tokenizer.apply_chat_template(\n", " messages, tokenize=False, add_generation_prompt=True\n", " )\n", "\n", "response = generate(model, tokenizer, prompt=prompt, verbose=True)\n", "\n", "st.title(\"Your Model Interface\")\n", "\n", "# User input\n", "user_input = st.text_input(\"Enter text:\")\n", "\n", "if st.button(\"Submit\"):\n", " # Tokenize input and make predictions\n", " # inputs = tokenizer(user_input, return_tensors=\"pt\")\n", " # outputs = model(**inputs)\n", " response = generate(model, tokenizer, prompt=user_input, verbose=True)\n", "\n", " st.write(response)" ] } ], "metadata": { "kernelspec": { "display_name": "f1llama", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.15" } }, "nbformat": 4, "nbformat_minor": 2 }