{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# SageMakerEndpoint\n",
    "\n",
    "[Amazon SageMaker](https://aws.amazon.com/sagemaker/) is a service for building, training, and deploying machine learning (ML) models for any use case with fully managed infrastructure, tools, and workflows.\n",
    "\n",
    "This notebook goes over how to use an LLM hosted on a SageMaker endpoint."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "!pip3 install langchain boto3"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set up"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Set the following parameters of the `SagemakerEndpoint` call:\n",
    "- `endpoint_name`: The name of the endpoint from the deployed SageMaker model.\n",
    "    Must be unique within an AWS Region.\n",
    "- `credentials_profile_name`: The name of the profile in the `~/.aws/credentials` or `~/.aws/config` files, which\n",
    "    has either access keys or role information specified.\n",
    "    If not specified, the default credential profile or, if on an EC2 instance,\n",
    "    credentials from IMDS will be used.\n",
    "    See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from langchain.docstore.document import Document"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "example_doc_1 = \"\"\"\n",
    "Peter and Elizabeth took a taxi to attend the night party in the city. While at the party, Elizabeth collapsed and was rushed to the hospital.\n",
    "Since she was diagnosed with a brain injury, the doctor told Peter to stay beside her until she gets well.\n",
    "Therefore, Peter stayed with her at the hospital for 3 days without leaving.\n",
    "\"\"\"\n",
    "\n",
    "docs = [\n",
    "    Document(\n",
    "        page_content=example_doc_1,\n",
    "    )\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import json\n",
    "from typing import Dict\n",
    "\n",
    "from langchain import PromptTemplate, SagemakerEndpoint\n",
    "from langchain.chains.question_answering import load_qa_chain\n",
    "from langchain.llms.sagemaker_endpoint import ContentHandlerBase\n",
    "\n",
    "query = \"\"\"How long was Elizabeth hospitalized?\n",
    "\"\"\"\n",
    "\n",
    "prompt_template = \"\"\"Use the following pieces of context to answer the question at the end.\n",
    "\n",
    "{context}\n",
    "\n",
    "Question: {question}\n",
    "Answer:\"\"\"\n",
    "PROMPT = PromptTemplate(\n",
    "    template=prompt_template, input_variables=[\"context\", \"question\"]\n",
    ")\n",
    "\n",
    "\n",
    "class ContentHandler(ContentHandlerBase):\n",
    "    # MIME types of the request body and of the expected response.\n",
    "    content_type = \"application/json\"\n",
    "    accepts = \"application/json\"\n",
    "\n",
    "    def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:\n",
    "        # Serialize the prompt and model parameters into a JSON payload.\n",
    "        # The \"prompt\" key is an assumption; use the key your model expects.\n",
    "        input_str = json.dumps({\"prompt\": prompt, **model_kwargs})\n",
    "        return input_str.encode(\"utf-8\")\n",
    "\n",
    "    def transform_output(self, output: bytes) -> str:\n",
    "        # Despite the `bytes` annotation, `output` is read as a stream here.\n",
    "        response_json = json.loads(output.read().decode(\"utf-8\"))\n",
    "        return response_json[0][\"generated_text\"]\n",
    "\n",
    "\n",
    "content_handler = ContentHandler()\n",
    "\n",
    "# Replace the placeholder endpoint name, profile name, and region with your own.\n",
    "chain = load_qa_chain(\n",
    "    llm=SagemakerEndpoint(\n",
    "        endpoint_name=\"endpoint-name\",\n",
    "        credentials_profile_name=\"credentials-profile-name\",\n",
    "        region_name=\"us-west-2\",\n",
    "        model_kwargs={\"temperature\": 1e-10},\n",
    "        content_handler=content_handler,\n",
    "    ),\n",
    "    prompt=PROMPT,\n",
    ")\n",
    "\n",
    "chain({\"input_documents\": docs, \"question\": query}, return_only_outputs=True)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  },
  "vscode": {
   "interpreter": {
    "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}