{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "b253f4d5",
   "metadata": {},
   "source": [
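    "# ChatGPT Clone using TOMA GPT-JT-6B\n",
    "(adapted from the ChatGPT Clone [notebook](https://github.com/hwchase17/langchain/blob/master/docs/examples/chains/chatgpt_clone.ipynb))\n",
    "\n",
    "This notebook uses GPT-JT-6B, served through Manifest's TOMA client, as a few-shot classifier. The first input of each conversation defines the classes and gives one labeled example. LangChain's conversation memory then carries that context into the follow-up inputs."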
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "b0302886",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "env: TOMA_URL=https://staging.together.xyz/api\n"
     ]
    }
   ],
   "source": [
    "# TOMA API endpoint (staging); presumably read by Manifest's \"toma\" client -- confirm.\n",
    "%set_env TOMA_URL=https://staging.together.xyz/api"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "93a18ea6",
   "metadata": {},
   "source": [
    "Make sure you have both `langchain` and `manifest` installed. To get the most recent versions, run\n",
    "```\n",
    "pip install git+https://github.com/hwchase17/langchain.git\n",
    "pip install git+https://github.com/HazyResearch/manifest.git\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "a99acd89",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "\u001b[1m> Entering new LLMChain chain...\u001b[0m\n",
      "Prompt after formatting:\n",
      "\u001b[32;1m\u001b[1;3mI am a classification model. It will try to classify your input.\n",
      "\n",
      "\n",
      "Input: Classes are \"positive\" and \"negative\". For example given\n",
      "Input: I love this product!\n",
      "Output: positive.\n",
      "I think this movie was one of the worst of the year. Script was boring!\n",
      "Output:\u001b[0m\n",
      "\n",
      "\u001b[1m> Finished LLMChain chain.\u001b[0m\n",
      "negative.\n"
     ]
    }
   ],
   "source": [
    "from manifest import Manifest\n",
    "from langchain.llms.manifest import ManifestWrapper\n",
    "from langchain import ConversationChain, LLMChain, PromptTemplate\n",
    "from langchain.chains.conversation.memory import ConversationalBufferWindowMemory\n",
    "\n",
    "# Model backend: GPT-JT-6B via Manifest's \"toma\" client. Generation is capped at\n",
    "# 150 tokens and stops at the first newline, so each call yields at most one line.\n",
    "manifest = Manifest(\n",
    "    client_name=\"toma\",\n",
    "    engine=\"Together-gpt-JT-6B-v1\",\n",
    "    max_tokens=150,\n",
    "    top_p=0.9,\n",
    "    top_k=40,\n",
    "    stop_sequences=[\"\\n\"],\n",
    ")\n",
    "\n",
    "# Prompt: a fixed instruction, then the remembered conversation ({history}), then\n",
    "# the new input framed as \"Input: ... Output:\" for the model to complete.\n",
    "template = \"\"\"I am a classification model. It will try to classify your input.\n",
    "\n",
    "{history}\n",
    "Input: {human_input}\n",
    "Output:\"\"\"\n",
    "prompt = PromptTemplate(input_variables=[\"history\", \"human_input\"], template=template)\n",
    "\n",
    "# Chain with windowed memory (k=8). Earlier turns, including the few-shot examples from\n",
    "# the first call, are replayed in later prompts until memory.clear() is called.\n",
    "chatgpt_chain = LLMChain(\n",
    "    llm=ManifestWrapper(client=manifest),\n",
    "    prompt=prompt,\n",
    "    verbose=True,\n",
    "    memory=ConversationalBufferWindowMemory(k=8),\n",
    ")\n",
    "\n",
    "# First turn: declare the classes and one worked example, then the text to label.\n",
    "output = chatgpt_chain.predict(\n",
    "    human_input=(\n",
    "        'Classes are \"positive\" and \"negative\". For example given\\n'\n",
    "        \"Input: I love this product!\\n\"\n",
    "        \"Output: positive.\\n\"\n",
    "        \"I think this movie was one of the worst of the year. Script was boring!\"\n",
    "    )\n",
    ")\n",
    "print(output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "4ef711d6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "\u001b[1m> Entering new LLMChain chain...\u001b[0m\n",
      "Prompt after formatting:\n",
      "\u001b[32;1m\u001b[1;3mI am a classification model. It will try to classify your input.\n",
      "\n",
      "Human: Classes are \"positive\" and \"negative\". For example given\n",
      "Input: I love this product!\n",
      "Output: positive.\n",
      "I think this movie was one of the worst of the year. Script was boring!\n",
      "AI: negative.\n",
      "Input: So awesome! I wish I could have gone\n",
      "Output:\u001b[0m\n",
      "\n",
      "\u001b[1m> Finished LLMChain chain.\u001b[0m\n",
      "positive.\n"
     ]
    }
   ],
   "source": [
    "# Follow-up sentiment input: the class list and worked example are replayed from memory.\n",
    "output = chatgpt_chain.predict(\n",
    "    human_input=\"So awesome! I wish I could have gone\",\n",
    ")\n",
    "print(output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "a5d6dac2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "\u001b[1m> Entering new LLMChain chain...\u001b[0m\n",
      "Prompt after formatting:\n",
      "\u001b[32;1m\u001b[1;3mI am a classification model. It will try to classify your input.\n",
      "\n",
      "Human: Classes are \"positive\" and \"negative\". For example given\n",
      "Input: I love this product!\n",
      "Output: positive.\n",
      "I think this movie was one of the worst of the year. Script was boring!\n",
      "AI: negative.\n",
      "Human: So awesome! I wish I could have gone\n",
      "AI: positive.\n",
      "Input: Hate it.\n",
      "Output:\u001b[0m\n",
      "\n",
      "\u001b[1m> Finished LLMChain chain.\u001b[0m\n",
      "negative.\n"
     ]
    }
   ],
   "source": [
    "# Third sentiment input: both earlier exchanges appear in the prompt history.\n",
    "output = chatgpt_chain.predict(\n",
    "    human_input=\"Hate it.\",\n",
    ")\n",
    "print(output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "b9283077",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "\u001b[1m> Entering new LLMChain chain...\u001b[0m\n",
      "Prompt after formatting:\n",
      "\u001b[32;1m\u001b[1;3mI am a classification model. It will try to classify your input.\n",
      "\n",
      "\n",
      "Input: Classes are fruits \"apple\", \"banana\", \"orange\", \"pear\". For example given\n",
      "Input: This fruit rippens off of the tree.\n",
      "Output: banana.\n",
      "Often comes in bosc and bartlett varieties.\n",
      "Output:\u001b[0m\n",
      "\n",
      "\u001b[1m> Finished LLMChain chain.\u001b[0m\n",
      "apple.\n"
     ]
    }
   ],
   "source": [
    "# New task: reset memory so the sentiment exchanges are not replayed in this prompt.\n",
    "chatgpt_chain.memory.clear()\n",
    "output = chatgpt_chain.predict(\n",
    "    human_input=(\n",
    "        'Classes are fruits \"apple\", \"banana\", \"orange\", \"pear\". For example given\\n'\n",
    "        \"Input: This fruit rippens off of the tree.\\n\"\n",
    "        \"Output: banana.\\n\"\n",
    "        \"Often comes in bosc and bartlett varieties.\"\n",
    "    )\n",
    ")\n",
    "print(output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "cd0a23d9",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "\u001b[1m> Entering new LLMChain chain...\u001b[0m\n",
      "Prompt after formatting:\n",
      "\u001b[32;1m\u001b[1;3mI am a classification model. It will try to classify your input.\n",
      "\n",
      "Human: Classes are fruits \"apple\", \"banana\", \"orange\", \"pear\". For example given\n",
      "Input: This fruit rippens off of the tree.\n",
      "Output: banana.\n",
      "Often comes in bosc and bartlett varieties.\n",
      "AI: apple.\n",
      "Input: Often associated with monkeys\n",
      "Output:\u001b[0m\n",
      "\n",
      "\u001b[1m> Finished LLMChain chain.\u001b[0m\n",
      "banana.\n"
     ]
    }
   ],
   "source": [
    "# Follow-up fruit input: the class list and worked example come from memory.\n",
    "output = chatgpt_chain.predict(\n",
    "    human_input=\"Often associated with monkeys\",\n",
    ")\n",
    "print(output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "id": "90db6eb2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "\u001b[1m> Entering new LLMChain chain...\u001b[0m\n",
      "Prompt after formatting:\n",
      "\u001b[32;1m\u001b[1;3mI am a classification model. It will try to classify your input.\n",
      "\n",
      "Human: Classes are fruits \"apple\", \"banana\", \"orange\", \"pear\". For example given\n",
      "Input: This fruit rippens off of the tree.\n",
      "Output: banana.\n",
      "Often comes in bosc and bartlett varieties.\n",
      "AI: apple.\n",
      "Human: Often associated with monkeys\n",
      "AI: banana.\n",
      "Input: Is the color red and often delicious.\n",
      "Output:\u001b[0m\n",
      "\n",
      "\u001b[1m> Finished LLMChain chain.\u001b[0m\n",
      "apple.\n"
     ]
    }
   ],
   "source": [
    "# Another fruit input: both earlier fruit exchanges are replayed as history.\n",
    "output = chatgpt_chain.predict(\n",
    "    human_input=\"Is the color red and often delicious.\",\n",
    ")\n",
    "print(output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "id": "c3806f89",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "\u001b[1m> Entering new LLMChain chain...\u001b[0m\n",
      "Prompt after formatting:\n",
      "\u001b[32;1m\u001b[1;3mI am a classification model. It will try to classify your input.\n",
      "\n",
      "\n",
      "Input: Classes are colors \"red\", \"green\", \"blue\", \"yellow\". For example given\n",
      "Input: The color of a school bus.\n",
      "Output: yellow.\n",
      "Is the color of the sky\n",
      "Output:\u001b[0m\n",
      "\n",
      "\u001b[1m> Finished LLMChain chain.\u001b[0m\n",
      "blue.\n"
     ]
    }
   ],
   "source": [
    "# New task: reset memory so the fruit exchanges are not replayed in this prompt.\n",
    "chatgpt_chain.memory.clear()\n",
    "output = chatgpt_chain.predict(\n",
    "    human_input=(\n",
    "        'Classes are colors \"red\", \"green\", \"blue\", \"yellow\". For example given\\n'\n",
    "        \"Input: The color of a school bus.\\n\"\n",
    "        \"Output: yellow.\\n\"\n",
    "        \"Is the color of the sky\"\n",
    "    )\n",
    ")\n",
    "print(output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "id": "f508f597",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "\u001b[1m> Entering new LLMChain chain...\u001b[0m\n",
      "Prompt after formatting:\n",
      "\u001b[32;1m\u001b[1;3mI am a classification model. It will try to classify your input.\n",
      "\n",
      "Human: Classes are colors \"red\", \"green\", \"blue\", \"yellow\". For example given\n",
      "Input: The color of a school bus.\n",
      "Output: yellow.\n",
      "Is the color of the sky\n",
      "AI: blue.\n",
      "Input: Color of a banana.\n",
      "Output:\u001b[0m\n",
      "\n",
      "\u001b[1m> Finished LLMChain chain.\u001b[0m\n",
      "yellow.\n"
     ]
    }
   ],
   "source": [
    "# Follow-up color input: no class list needed, it is replayed from memory.\n",
    "output = chatgpt_chain.predict(\n",
    "    human_input=\"Color of a banana.\",\n",
    ")\n",
    "print(output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "id": "cbd607f4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "\u001b[1m> Entering new LLMChain chain...\u001b[0m\n",
      "Prompt after formatting:\n",
      "\u001b[32;1m\u001b[1;3mI am a classification model. It will try to classify your input.\n",
      "\n",
      "Human: Classes are colors \"red\", \"green\", \"blue\", \"yellow\". For example given\n",
      "Input: The color of a school bus.\n",
      "Output: yellow.\n",
      "Is the color of the sky\n",
      "AI: blue.\n",
      "Human: Color of a banana.\n",
      "AI: yellow.\n",
      "Input: When someone is sick they are this color.\n",
      "Output:\u001b[0m\n",
      "\n",
      "\u001b[1m> Finished LLMChain chain.\u001b[0m\n",
      "green.\n"
     ]
    }
   ],
   "source": [
    "# Third color input: memory now replays two earlier color exchanges.\n",
    "output = chatgpt_chain.predict(\n",
    "    human_input=\"When someone is sick they are this color.\",\n",
    ")\n",
    "print(output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "id": "d33e0e28",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "\u001b[1m> Entering new LLMChain chain...\u001b[0m\n",
      "Prompt after formatting:\n",
      "\u001b[32;1m\u001b[1;3mI am a classification model. It will try to classify your input.\n",
      "\n",
      "Human: Classes are colors \"red\", \"green\", \"blue\", \"yellow\". For example given\n",
      "Input: The color of a school bus.\n",
      "Output: yellow.\n",
      "Is the color of the sky\n",
      "AI: blue.\n",
      "Human: Color of a banana.\n",
      "AI: yellow.\n",
      "Human: When someone is sick they are this color.\n",
      "AI: green.\n",
      "Input: Color of anger.\n",
      "Output:\u001b[0m\n",
      "\n",
      "\u001b[1m> Finished LLMChain chain.\u001b[0m\n",
      "red.\n"
     ]
    }
   ],
   "source": [
    "# Last color input: all three earlier color exchanges still fit in the k=8 memory window.\n",
    "output = chatgpt_chain.predict(\n",
    "    human_input=\"Color of anger.\",\n",
    ")\n",
    "print(output)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "bootleg",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12 | packaged by conda-forge | (default, Jan 30 2022, 23:36:06) \n[Clang 11.1.0 ]"
  },
  "vscode": {
   "interpreter": {
    "hash": "7a3f97ab0465937066e9b79893b779dfc8a12d73c41f9d98a7bf05133c798250"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}