{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "4cf3ca7c-2c43-495b-a1ee-24c770f0ad1e",
   "metadata": {},
   "source": [
    "### Simple OpenAI agent with tool"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a9d74b11-4049-4e3c-839e-7d13d7c0dadc",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import sqlite3\n",
    "from typing import Sequence, List\n",
    "from pydantic import BaseModel, Field\n",
    "\n",
    "from llama_index.core.llms import ChatMessage\n",
    "from llama_index.core.tools import BaseTool, FunctionTool\n",
    "from llama_index.llms.openai import OpenAI\n",
    "from llama_index.agent.openai import OpenAIAgent\n",
    "\n",
    "import nest_asyncio\n",
    "\n",
    "nest_asyncio.apply()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "33da3ba3-7d5d-4528-8e31-cabf85a88886",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b27b2c3b-1876-4a20-8fb9-cddd6df51ef3",
   "metadata": {},
   "outputs": [],
   "source": [
    "db_path = \"../database/mock_qna.sqlite\"\n",
    "con = sqlite3.connect(db_path)\n",
    "cur = con.cursor()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b69adba2-9f98-460b-b0a3-e759d6ac1b88",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "39ca840b-0ec7-4279-9654-e4ddfda6137f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# define sample Tool\n",
    "def multiply(a: int, b: int) -> int:\n",
    "    \"\"\"Multiply two integers and returns the result integer\"\"\"\n",
    "    return a * b\n",
    "\n",
    "multiply_tool = FunctionTool.from_defaults(fn=multiply)\n",
    "\n",
    "def add(a: int, b: int) -> int:\n",
    "    \"\"\"Add two integers and returns the result integer\"\"\"\n",
    "    return a + b\n",
    "\n",
    "add_tool = FunctionTool.from_defaults(fn=add)\n",
    "\n",
    "class Song(BaseModel):\n",
    "    \"\"\"A song with name and artist\"\"\"\n",
    "\n",
    "    name: str\n",
    "    artist: str\n",
    "\n",
    "song_fn = FunctionTool.from_defaults(fn=Song)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14b64612-1320-48b2-b5ff-91cde659cbf8",
   "metadata": {},
   "outputs": [],
   "source": [
    "class QnA_Model(BaseModel):\n",
    "    chapter_n: str = Field(..., \n",
    "                           pattern=r'^Chapter_\\d*$',\n",
    "                           description=(\n",
    "                                \"which chapter to extract, the format of this function argumet\"\n",
    "                                \"is with `Chapter_` as prefix concatenated with chapter number\"\n",
    "                                \"in integer. For example, `Chapter_2`, `Chapter_10`.\")\n",
    "                          )\n",
    "\n",
    "description = \"\"\"\n",
    "      Use this tool to extract the chapter information from the body of the input text, \n",
    "      the format looks as follow:\n",
    "        The output should be in the format with `Chapter_` as prefix.\n",
    "        Example 1: `Chapter_1` for first chapter\n",
    "        Example 2: For chapter 12 of the textbook, you should return `Chapter_12`\n",
    "        Example 3: `Chapter_5` for fifth chapter\n",
    "        Thereafter, the chapter_n argument will be passed to the function for Q&A question retrieval.\n",
    "\"\"\"\n",
    "\n",
    "def get_qna_question(chapter_n: str):\n",
    "    \"\"\"\n",
    "      Use this tool to extract the chapter information from the body of the input text, \n",
    "      the format looks as follow:\n",
    "        The output should be in the format with `Chapter_` as prefix.\n",
    "        Example 1: `Chapter_1` for first chapter\n",
    "        Example 2: For chapter 12 of the textbook, you should return `Chapter_12`\n",
    "        Example 3: `Chapter_5` for fifth chapter\n",
    "        Thereafter, the chapter_n argument will be passed to the function for Q&A question retrieval.\n",
    "    \"\"\"\n",
    "    sql_string = f\"\"\"SELECT id, question, option_1, option_2, option_3, option_4, correct_answer\n",
    "                     FROM qna_tbl\n",
    "                     WHERE chapter='{chapter_n}'\n",
    "                  \"\"\"\n",
    "    res = cur.execute(sql_string)\n",
    "    result = res.fetchone()\n",
    "\n",
    "    id       = result[0]\n",
    "    question = result[1]\n",
    "    option_1 = result[2]\n",
    "    option_2 = result[3]\n",
    "    option_3 = result[4]\n",
    "    option_4 = result[5]\n",
    "    c_answer = result[6]\n",
    "\n",
    "    qna_str  = \"Question: \\n\" + \\\n",
    "               \"========= \\n\" + \\\n",
    "                question.replace(\"\\\\n\", \"\\n\") + \"\\n\" + \\\n",
    "               \"A) \" + option_1 + \"\\n\" + \\\n",
    "               \"B) \" + option_2 + \"\\n\" + \\\n",
    "               \"C) \" + option_3 + \"\\n\" + \\\n",
    "               \"D) \" + option_4\n",
    "    \n",
    "    # return id, qna_str, c_answer\n",
    "    return qna_str\n",
    "\n",
    "get_qna_question_tool = FunctionTool.from_defaults(\n",
    "                            fn=get_qna_question,\n",
    "                            name=\"Extract_Question\",\n",
    "                            description=description,\n",
    "                            fn_schema=QnA_Model\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4df3e9f2-4a32-4449-b203-929dff9e7963",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bbeeea36-0bb0-4edb-9b8c-adb7c64c4cd3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# initialize openai agent\n",
    "llm = OpenAI(model=\"gpt-3.5-turbo-0613\")\n",
    "agent = OpenAIAgent.from_tools([multiply_tool, \n",
    "                                add_tool, \n",
    "                                song_fn, \n",
    "                                get_qna_question_tool], llm=llm, verbose=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0edafe7d-a835-4882-bd7d-1717a4cad462",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "65158ede-b99f-477d-9d17-3be40e57a629",
   "metadata": {},
   "outputs": [],
   "source": [
    "res = agent.chat(message=\"3 x 2 equals?\", tool_choice=\"auto\")\n",
    "res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d74b0e5c-47a2-4de4-acd2-d39a94941f2d",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d9787ed4-46a2-46aa-80e6-b317d9280b9e",
   "metadata": {},
   "outputs": [],
   "source": [
    "res = agent.chat(message=\"3 plus 2 equals?\", tool_choice=\"auto\")\n",
    "res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd3358c6-e0e4-4354-8a4e-97d70254f648",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "16e1db2c-dca3-4dc6-9cc5-c10644d5927c",
   "metadata": {},
   "outputs": [],
   "source": [
    "res = agent.chat(message=\"give me the lyrics of taylor swift's `you belong with me`\", tool_choice=\"auto\")\n",
    "res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "97a021f4-4701-4914-9ab8-0683b396f096",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f0b352d-8510-4b2a-a495-9f2e1fbfcddb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# res_stream = agent.stream_chat(message=\"3 x 2 equals?\")\n",
    "# for r in res_stream.response_gen:\n",
    "#     print(r, end=\"\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea0a6cd4-f204-4997-bdfb-cb9b5a9e1266",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "96c978e6-62e2-46e6-ae63-76841487f618",
   "metadata": {},
   "source": [
    "### OpenAI agent with embeddings, and function calling"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f0f5f02-c8e9-43a9-853d-12bb3c19dbe8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import chromadb\n",
    "\n",
    "from llama_index.core import (\n",
    "    SimpleDirectoryReader,\n",
    "    VectorStoreIndex,\n",
    "    ServiceContext,\n",
    "    StorageContext,\n",
    "    load_index_from_storage,\n",
    ")\n",
    "from llama_index.core.memory import ChatMemoryBuffer\n",
    "from llama_index.vector_stores.chroma.base import ChromaVectorStore\n",
    "from llama_index.core.tools import QueryEngineTool, ToolMetadata\n",
    "from llama_index.llms.openai import OpenAI\n",
    "from llama_index.agent.openai import OpenAIAgent\n",
    "from llama_index.core.tools import BaseTool, FunctionTool\n",
    "from llama_index.core import Settings\n",
    "\n",
    "from pydantic import BaseModel\n",
    "import nest_asyncio\n",
    "\n",
    "nest_asyncio.apply()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "54855aa4-dcad-404e-843f-c96d61046df3",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "4edb89f6-bb2f-46ff-8807-dfb03115fcd5",
   "metadata": {},
   "source": [
    "#### Parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "61ad7369-8fd4-434f-b687-0c649940bda1",
   "metadata": {},
   "outputs": [],
   "source": [
    "selected_model = \"gpt-3.5-turbo-1106\"\n",
    "temperature = 0.0\n",
    "\n",
    "persisted_vector_db = \"../models/chroma_db\"\n",
    "input_files = [\"../raw_documents/HI_Knowledge_Base.pdf\",\n",
    "               \"../raw_documents/HI Chapter Summary Version 1.3.pdf\",\n",
    "               \"../raw_documents/qna.txt\"]\n",
    "fine_tuned_path = \"local:../models/fine-tuned-embeddings\"\n",
    "system_content = (\"You are a helpful study assistant. \"\n",
    "                  \"You do not respond as 'User' or pretend to be 'User'. \"\n",
    "                  \"You only respond once as 'Assistant'.\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a18117f5-48a7-4e81-9b46-541f382caf9e",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "3210c837-9b40-4cd9-bb00-ead559deff6f",
   "metadata": {},
   "source": [
    "#### Load vector store"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9dfba0c-f27d-49d1-86c5-a1d95c11b844",
   "metadata": {},
   "outputs": [],
   "source": [
    "if os.path.exists(persisted_vector_db):\n",
    "    db = chromadb.PersistentClient(path=persisted_vector_db)\n",
    "    chroma_collection = db.get_or_create_collection(\"quickstart\")\n",
    "    \n",
    "    # assign chroma as the vector_store to the context\n",
    "    vector_store = ChromaVectorStore(chroma_collection=chroma_collection)\n",
    "    storage_context = StorageContext.from_defaults(vector_store=vector_store)\n",
    "\n",
    "else:\n",
    "    documents = SimpleDirectoryReader(input_files=input_files).load_data()\n",
    "    document = Document(text=\"\\n\\n\".join([doc.text for doc in documents]))\n",
    "    \n",
    "    # initialize client, setting path to save data\n",
    "    db = chromadb.PersistentClient(path=persisted_vector_db)\n",
    "    \n",
    "    # create collection\n",
    "    chroma_collection = db.get_or_create_collection(\"quickstart\")\n",
    "\n",
    "    # assign chroma as the vector_store to the context\n",
    "    vector_store = ChromaVectorStore(chroma_collection=chroma_collection)\n",
    "    storage_context = StorageContext.from_defaults(vector_store=vector_store)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "693c9808-efbe-47a6-a49c-7143c63d13e5",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c6bfae0b-7c97-4c2b-9996-f5e3ecf7a992",
   "metadata": {},
   "outputs": [],
   "source": [
    "# define sample Tool\n",
    "def multiply(a: int, b: int) -> int:\n",
    "    \"\"\"Multiply two integers and returns the result integer\"\"\"\n",
    "    return a * b\n",
    "\n",
    "def add(a: int, b: int) -> int:\n",
    "    \"\"\"Add two integers and returns the result integer\"\"\"\n",
    "    return a + b\n",
    "\n",
    "class Song(BaseModel):\n",
    "    \"\"\"A song with name and artist\"\"\"\n",
    "\n",
    "    name: str\n",
    "    artist: str\n",
    "\n",
    "add_tool = FunctionTool.from_defaults(fn=add)\n",
    "multiply_tool = FunctionTool.from_defaults(fn=multiply)\n",
    "song_fn = FunctionTool.from_defaults(fn=Song)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "16a80b2e-8e5f-462a-8616-042afe18be3a",
   "metadata": {},
   "outputs": [],
   "source": [
    "llm = OpenAI(model=selected_model, temperature=temperature)\n",
    "\n",
    "Settings.llm = llm\n",
    "Settings.embed_model = fine_tuned_path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "95d3a420-1ee4-45bd-a18b-b398d9531db4",
   "metadata": {},
   "outputs": [],
   "source": [
    "index = VectorStoreIndex.from_vector_store(\n",
    "    vector_store=vector_store,\n",
    "    storage_context=storage_context\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eac6d76d-059b-40e3-b67f-c736f1ce6baa",
   "metadata": {},
   "outputs": [],
   "source": [
    "memory = ChatMemoryBuffer.from_defaults(token_limit=15000)\n",
    "\n",
    "hi_engine = index.as_query_engine(\n",
    "                memory=memory,\n",
    "                system_prompt=system_content,\n",
    "                similarity_top_k=3,\n",
    "                streaming=True\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "18e38a8f-9b51-4675-a1d5-8aaa6c21694c",
   "metadata": {},
   "outputs": [],
   "source": [
    "res = hi_engine.query(\"what is the healthcare philosophy in singapore\")\n",
    "print(res)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "70dae6f7-682e-42d6-be59-3b807c10482c",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1eb5df65-c926-4b22-8071-449d645b339f",
   "metadata": {},
   "outputs": [],
   "source": [
    "hi_query_tool = QueryEngineTool.from_defaults(\n",
    "                    query_engine=hi_engine,\n",
    "                    name=\"vector_tool\",\n",
    "                    description=(\n",
    "                        \"Provides information about Health Insurance landscape in Singapore. \"\n",
    "                        \"Use a detailed plain text question as input to the tool.\"\n",
    "                    )\n",
    ")\n",
    "\n",
    "# hi_query_tool = QueryEngineTool(\n",
    "#                     query_engine=hi_engine,\n",
    "#                     metadata=ToolMetadata(\n",
    "#                         name=\"health_insurance_mentor\",\n",
    "#                         description=(\n",
    "#                             \"Provides information about Health Insurance landscape in Singapore. \"\n",
    "#                             \"Use a detailed plain text question as input to the tool.\"\n",
    "#                         )\n",
    "#                     )\n",
    "#                 )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a5822b1d-32ef-4b68-8629-a727ff51cd0a",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2a1235da-a379-4055-8bcf-4b21c91c9fb0",
   "metadata": {},
   "outputs": [],
   "source": [
    "agent = OpenAIAgent.from_tools([multiply_tool, \n",
    "                                add_tool, \n",
    "                                hi_query_tool, \n",
    "                                get_qna_question_tool], llm=llm, verbose=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "05b65cbd-d15c-4909-b383-50b13f64e535",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "63332a44-9441-4f49-85a2-934e2c55a362",
   "metadata": {},
   "outputs": [],
   "source": [
    "res = agent.chat(\"what is the healthcare philosophy in singapore\", tool_choice=\"auto\")\n",
    "res.response"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "de387041-706c-4be8-ab31-fe8bd8b16bc1",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb8a8676-c070-4652-8c00-436be3135c12",
   "metadata": {},
   "outputs": [],
   "source": [
    "res = agent.chat(message=\"I am interested in 4th chapter, can you test my understanding of this chapter?\",\n",
    "                 tool_choice=\"auto\")\n",
    "res.response"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "81709cbf-9a5e-482f-ae6a-ba361b8219dc",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "adf26268-e40a-4ebd-a737-6b203ddc4444",
   "metadata": {},
   "outputs": [],
   "source": [
    "res = agent.stream_chat(\"what is the healthcare philosophy in singapore\", tool_choice=\"auto\")\n",
    "for r in res.response_gen:\n",
    "    print(r, end=\"\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4f2df51-553f-493d-874c-662ecb499e36",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4bf66e59-a394-4e6b-b7d5-af6b1612c97b",
   "metadata": {},
   "outputs": [],
   "source": [
    "res = agent.stream_chat(message=\"I am interested in 4th chapter, can you test my understanding of this chapter?\",\n",
    "                        tool_choice=\"auto\")\n",
    "for r in res.response_gen:\n",
    "    print(r, end=\"\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "540c0f71-048a-4a64-9818-e2b1cffc0db7",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fbeabf28-30f9-4d7f-a4b9-21cd08a9b128",
   "metadata": {},
   "outputs": [],
   "source": [
    "res = agent.stream_chat(\"what is the result of 328123 + 2891230\", tool_choice=\"auto\")\n",
    "for r in res.response_gen:\n",
    "    print(r, end=\"\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19b7e12c-0729-4181-acce-53a3a95b67b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "328123 + 2891230"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bca4c0b2-5165-4943-af1f-d3168ee88fcd",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}