File size: 6,694 Bytes
cfd3735
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "593f7553-7038-498e-96d4-8255e5ce34f0",
   "metadata": {},
   "source": [
    "# Creating a custom Chain\n",
    "\n",
    "To implement your own custom chain, subclass `Chain` and implement the members shown in the example below: the `input_keys` and `output_keys` properties, the `_call` method, and (for async use) the `_acall` method:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "c19c736e-ca74-4726-bb77-0a849bcc2960",
   "metadata": {
    "tags": [],
    "vscode": {
     "languageId": "python"
    }
   },
   "outputs": [],
   "source": [
    "from __future__ import annotations\n",
    "\n",
    "from typing import Any, Dict, List, Optional\n",
    "\n",
    "from pydantic import Extra\n",
    "\n",
    "from langchain.base_language import BaseLanguageModel\n",
    "from langchain.callbacks.manager import (\n",
    "    AsyncCallbackManagerForChainRun,\n",
    "    CallbackManagerForChainRun,\n",
    ")\n",
    "from langchain.chains.base import Chain\n",
    "from langchain.prompts.base import BasePromptTemplate\n",
    "\n",
    "\n",
    "class MyCustomChain(Chain):\n",
    "    \"\"\"\n",
    "    An example of a custom chain.\n",
    "    \"\"\"\n",
    "\n",
    "    prompt: BasePromptTemplate\n",
    "    \"\"\"Prompt object to use.\"\"\"\n",
    "    llm: BaseLanguageModel\n",
    "    output_key: str = \"text\"  #: :meta private:\n",
    "\n",
    "    class Config:\n",
    "        \"\"\"Configuration for this pydantic object.\"\"\"\n",
    "\n",
    "        extra = Extra.forbid\n",
    "        arbitrary_types_allowed = True\n",
    "\n",
    "    @property\n",
    "    def input_keys(self) -> List[str]:\n",
    "        \"\"\"Will be whatever keys the prompt expects.\n",
    "\n",
    "        :meta private:\n",
    "        \"\"\"\n",
    "        return self.prompt.input_variables\n",
    "\n",
    "    @property\n",
    "    def output_keys(self) -> List[str]:\n",
    "        \"\"\"Will always return text key.\n",
    "\n",
    "        :meta private:\n",
    "        \"\"\"\n",
    "        return [self.output_key]\n",
    "\n",
    "    def _call(\n",
    "        self,\n",
    "        inputs: Dict[str, Any],\n",
    "        run_manager: Optional[CallbackManagerForChainRun] = None,\n",
    "    ) -> Dict[str, str]:\n",
    "        # Your custom chain logic goes here\n",
    "        # This is just an example that mimics LLMChain\n",
    "        prompt_value = self.prompt.format_prompt(**inputs)\n",
    "        \n",
    "        # Whenever you call a language model, or another chain, you should pass\n",
    "        # a callback manager to it. This allows the inner run to be tracked by\n",
    "        # any callbacks that are registered on the outer run.\n",
    "        # You can always obtain a callback manager for this by calling\n",
    "        # `run_manager.get_child()` as shown below.\n",
    "        response = self.llm.generate_prompt(\n",
    "            [prompt_value],\n",
    "            callbacks=run_manager.get_child() if run_manager else None\n",
    "        )\n",
    "\n",
    "        # If you want to log something about this run, you can do so by calling\n",
    "        # methods on the `run_manager`, as shown below. This will trigger any\n",
    "        # callbacks that are registered for that event.\n",
    "        if run_manager:\n",
    "            run_manager.on_text(\"Log something about this run\")\n",
    "        \n",
    "        return {self.output_key: response.generations[0][0].text}\n",
    "\n",
    "    async def _acall(\n",
    "        self,\n",
    "        inputs: Dict[str, Any],\n",
    "        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,\n",
    "    ) -> Dict[str, str]:\n",
    "        # Your custom chain logic goes here\n",
    "        # This is just an example that mimics LLMChain\n",
    "        prompt_value = self.prompt.format_prompt(**inputs)\n",
    "        \n",
    "        # Whenever you call a language model, or another chain, you should pass\n",
    "        # a callback manager to it. This allows the inner run to be tracked by\n",
    "        # any callbacks that are registered on the outer run.\n",
    "        # You can always obtain a callback manager for this by calling\n",
    "        # `run_manager.get_child()` as shown below.\n",
    "        response = await self.llm.agenerate_prompt(\n",
    "            [prompt_value],\n",
    "            callbacks=run_manager.get_child() if run_manager else None\n",
    "        )\n",
    "\n",
    "        # If you want to log something about this run, you can do so by calling\n",
    "        # methods on the `run_manager`, as shown below. This will trigger any\n",
    "        # callbacks that are registered for that event.\n",
    "        if run_manager:\n",
    "            await run_manager.on_text(\"Log something about this run\")\n",
    "        \n",
    "        return {self.output_key: response.generations[0][0].text}\n",
    "\n",
    "    @property\n",
    "    def _chain_type(self) -> str:\n",
    "        return \"my_custom_chain\"\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "18361f89",
   "metadata": {
    "vscode": {
     "languageId": "python"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "\u001b[1m> Entering new MyCustomChain chain...\u001b[0m\n",
      "Log something about this run\n",
      "\u001b[1m> Finished chain.\u001b[0m\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'Why did the callback function feel lonely? Because it was always waiting for someone to call it back!'"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from langchain.callbacks.stdout import StdOutCallbackHandler\n",
    "from langchain.chat_models.openai import ChatOpenAI\n",
    "from langchain.prompts.prompt import PromptTemplate\n",
    "\n",
    "\n",
    "chain = MyCustomChain(\n",
    "    prompt=PromptTemplate.from_template('tell us a joke about {topic}'),\n",
    "    llm=ChatOpenAI()\n",
    ")\n",
    "\n",
    "chain.run({'topic': 'callbacks'}, callbacks=[StdOutCallbackHandler()])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}