{ "cells": [ { "cell_type": "markdown", "id": "6eb94b72", "metadata": {}, "source": [ "#### Set environment variables in [.env](.env) for LLM API calling" ] }, { "cell_type": "markdown", "id": "388020c6", "metadata": {}, "source": [ "### Import Dependencies" ] }, { "cell_type": "code", "execution_count": null, "id": "11efa138", "metadata": {}, "outputs": [], "source": [ "import sys\n", "sys.path.insert(0, \"../../\")\n", "import promptwizard\n", "from promptwizard.glue.promptopt.instantiate import GluePromptOpt\n", "from promptwizard.glue.promptopt.techniques.common_logic import DatasetSpecificProcessing\n", "from promptwizard.glue.common.utils.file import save_jsonlist\n", "from typing import Any\n", "from tqdm import tqdm\n", "from re import compile, findall\n", "import os\n", "from datasets import load_dataset\n", "\n", "from dotenv import load_dotenv\n", "load_dotenv(override = True)" ] }, { "cell_type": "markdown", "id": "beb14821", "metadata": {}, "source": [ "### Create a dataset specific class and define the required functions " ] }, { "cell_type": "code", "execution_count": 2, "id": "5f325d33", "metadata": {}, "outputs": [], "source": [ "class GSM8k(DatasetSpecificProcessing):\n", "\n", " def dataset_to_jsonl(self, dataset_jsonl: str, **kwargs: Any) -> None:\n", " def extract_answer_from_output(completion):\n", " # Your functions for metrics and prompt building\n", " ans_re = compile(r\"#### (\\-?[0-9\\.\\,]+)\")\n", " self.INVALID_ANS = \"[invalid]\"\n", "\n", " match = ans_re.search(completion)\n", " if match:\n", " match_str = match.group(1).strip()\n", " match_str = match_str.replace(\",\", \"\")\n", " return match_str\n", " else:\n", " return self.INVALID_ANS\n", "\n", " examples_set = []\n", "\n", " for _, sample in tqdm(enumerate(kwargs[\"dataset\"]), desc=\"Evaluating samples\"):\n", " example = {\n", " DatasetSpecificProcessing.QUESTION_LITERAL: sample['question'],\n", " DatasetSpecificProcessing.ANSWER_WITH_REASON_LITERAL: sample['answer'],\n", " DatasetSpecificProcessing.FINAL_ANSWER_LITERAL: extract_answer_from_output(sample[\"answer\"])\n", " }\n", " examples_set.append(example)\n", "\n", " save_jsonlist(dataset_jsonl, examples_set, \"w\")\n", "\n", " def extract_final_answer(self, answer: str):\n", " \n", " if not answer:\n", " return self.INVALID_ANS\n", "\n", " model_pred = answer.lower()\n", " preds = model_pred.split(self.ANSWER_START.lower())\n", " answer_flag = True if len(preds) > 1 else False\n", "\n", " pred = preds[-1].replace(\",\", \"\")\n", " pred = [s for s in findall(r'-?\\d+\\.?\\d*', pred)]\n", "\n", " if len(pred) == 0:\n", " return self.INVALID_ANS\n", "\n", " if answer_flag:\n", " # choose the first element in list\n", " pred = pred[0]\n", " else:\n", " # choose the last element in list\n", " pred = pred[-1]\n", "\n", " # (For arithmetic tasks) if a word ends with period, it will be omitted ...\n", " if pred[-1] == \".\":\n", " pred = pred[:-1]\n", " return pred" ] }, { "cell_type": "code", "execution_count": 3, "id": "f384eb57", "metadata": {}, "outputs": [], "source": [ "gsm8k_processor = GSM8k()" ] }, { "cell_type": "markdown", "id": "11d2de75", "metadata": {}, "source": [ "### Load and save the dataset " ] }, { "cell_type": "code", "execution_count": null, "id": "976681bd-4f43-4dbc-947e-cdb94d4824f0", "metadata": {}, "outputs": [], "source": [ "if not os.path.exists(\"data\"):\n", " os.mkdir(\"data\")\n", " \n", "dataset = load_dataset(\"openai/gsm8k\", \"main\")\n", "num_samples = 0\n", "for dataset_type in ['train','test']:\n", " data_list = 
[]\n", " for data in dataset[dataset_type]:\n", " data_list.append({\"question\": data['question'], \"answer\": data['answer']})\n", " if num_samples == 100 and dataset_type == 'train': # We sample only 100 train examples and use 25 out them for training randomly\n", " break\n", " num_samples += 1\n", " gsm8k_processor.dataset_to_jsonl(\"data/\"+ dataset_type+'.jsonl', dataset=data_list)" ] }, { "cell_type": "markdown", "id": "ac30e74f", "metadata": {}, "source": [ "### Set paths" ] }, { "cell_type": "code", "execution_count": 5, "id": "f43482f1-3e10-4cf7-8ea6-ff42c04067a6", "metadata": {}, "outputs": [], "source": [ "train_file_name = os.path.join(\"data\", \"train.jsonl\")\n", "test_file_name = os.path.join(\"data\", \"test.jsonl\")\n", "path_to_config = \"configs\"\n", "promptopt_config_path = os.path.join(path_to_config, \"promptopt_config.yaml\")\n", "setup_config_path = os.path.join(path_to_config, \"setup_config.yaml\")" ] }, { "cell_type": "markdown", "id": "3392594d", "metadata": {}, "source": [ "### Create an object for calling prompt optimization and inference functionalities" ] }, { "cell_type": "code", "execution_count": null, "id": "8af4246f-db32-4b37-a73a-f9e2e5125d09", "metadata": {}, "outputs": [], "source": [ "gp = GluePromptOpt(promptopt_config_path,\n", " setup_config_path,\n", " train_file_name,\n", " gsm8k_processor)" ] }, { "cell_type": "markdown", "id": "1784648c", "metadata": {}, "source": [ "### Call prompt optmization function\n", "1. ```use_examples``` can be used when there are training samples and a mixture of real and synthetic in-context examples are required in the final prompt. When set to ```False``` all the in-context examples will be real\n", "2. ```generate_synthetic_examples``` can be used when there are no training samples and we want to generate synthetic examples \n", "3. 
  {
   "cell_type": "markdown",
   "id": "ac30e74f",
   "metadata": {},
   "source": [
    "### Set paths"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "f43482f1-3e10-4cf7-8ea6-ff42c04067a6",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_file_name = os.path.join(\"data\", \"train.jsonl\")\n",
    "test_file_name = os.path.join(\"data\", \"test.jsonl\")\n",
    "path_to_config = \"configs\"\n",
    "promptopt_config_path = os.path.join(path_to_config, \"promptopt_config.yaml\")\n",
    "setup_config_path = os.path.join(path_to_config, \"setup_config.yaml\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3392594d",
   "metadata": {},
   "source": [
    "### Create an object for calling prompt optimization and inference functionalities"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8af4246f-db32-4b37-a73a-f9e2e5125d09",
   "metadata": {},
   "outputs": [],
   "source": [
    "gp = GluePromptOpt(promptopt_config_path,\n",
    "                   setup_config_path,\n",
    "                   train_file_name,\n",
    "                   gsm8k_processor)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1784648c",
   "metadata": {},
   "source": [
    "### Call prompt optimization function\n",
    "1. ```use_examples``` can be used when there are training samples and a mixture of real and synthetic in-context examples is required in the final prompt. When set to ```False```, all the in-context examples will be real.\n",
    "2. ```generate_synthetic_examples``` can be used when there are no training samples and we want to generate synthetic examples.\n",
    "3. ```run_without_train_examples``` can be used when there are no training samples and in-context examples are not required in the final prompt."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "573c6151-2c03-45d9-9904-1724a1e20f1b",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Function call to generate the optimal prompt and expert profile\n",
    "best_prompt, expert_profile = gp.get_best_prompt(use_examples=True,\n",
    "                                                 run_without_train_examples=False,\n",
    "                                                 generate_synthetic_examples=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1ee1aa99",
   "metadata": {},
   "source": [
    "### Save the optimized prompt and expert profile"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "34a716af-0d77-4c7d-b1c2-6438d66096ce",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "os.makedirs(\"results\", exist_ok=True)\n",
    "\n",
    "with open(\"results/best_prompt.pkl\", 'wb') as f:\n",
    "    pickle.dump(best_prompt, f)\n",
    "with open(\"results/expert_profile.pkl\", 'wb') as f:\n",
    "    pickle.dump(expert_profile, f)\n",
    "\n",
    "print(f\"Best prompt: {best_prompt}\\\\nExpert profile: {expert_profile}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aac42eed",
   "metadata": {},
   "source": [
    "### Evaluate the optimized prompt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c49b5711-82dd-4d18-8cd4-ee447cf8d74c",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "gp.EXPERT_PROFILE = expert_profile\n",
    "gp.BEST_PROMPT = best_prompt\n",
    "\n",
    "# Function call to evaluate the prompt\n",
    "accuracy = gp.evaluate(test_file_name)\n",
    "\n",
    "print(f\"Final Accuracy: {accuracy}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "general",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}