{
"cells": [
{
"cell_type": "markdown",
"id": "1f83f273",
"metadata": {},
"source": [
"# SageMaker Endpoint Embeddings\n",
"\n",
"Let's load the `SagemakerEndpointEmbeddings` class. You can use it if you host your own model on a SageMaker endpoint, for example a Hugging Face model.\n",
"\n",
"For instructions on how to deploy such a model, see [this guide](https://www.philschmid.de/custom-inference-huggingface-sagemaker). **Note**: To handle batched requests, you need to adjust the return line of the `predict_fn()` function in the custom `inference.py` script:\n",
"\n",
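"Change it from:\n",
"\n",
"`return {\"vectors\": sentence_embeddings[0].tolist()}`\n",
"\n",
"to:\n",
"\n",
"`return {\"vectors\": sentence_embeddings.tolist()}`\n",
"\n",
"Indexing with `[0]` returns only the embedding of the first input, so a batch of several texts would come back as a single vector. Returning the whole tensor gives one embedding per input, which is the `List[List[float]]` that the content handler below expects."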
]
},
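{
"cell_type": "markdown",
"id": "5b2e8c41",
"metadata": {},
"source": [
"The following cell is a minimal sketch of why this change matters. It uses a small NumPy array as a hypothetical stand-in for the `sentence_embeddings` tensor computed in `predict_fn()`; the real script works with a PyTorch tensor, which also provides `.tolist()`. The values and the 3-dimensional size are made up for illustration."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c7d19a3f",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Hypothetical stand-in for the pooled `sentence_embeddings` tensor:\n",
"# one row per input text (2 texts, 3 made-up dimensions)\n",
"sentence_embeddings = np.array([[0.1, 0.2, 0.3],\n",
"                                [0.4, 0.5, 0.6]])\n",
"\n",
"# Original return line: only the first row survives, as a flat list of floats\n",
"single = {\"vectors\": sentence_embeddings[0].tolist()}\n",
"\n",
"# Batched return line: one list of floats per input text\n",
"batched = {\"vectors\": sentence_embeddings.tolist()}\n",
"\n",
"print(len(single[\"vectors\"]))   # 3: the embedding size, not the batch size\n",
"print(len(batched[\"vectors\"]))  # 2: one embedding per input text"
]
},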
{
"cell_type": "code",
"execution_count": null,
"id": "88d366bd",
"metadata": {},
"outputs": [],
"source": [
"%pip install langchain boto3"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "1e9b926a",
"metadata": {},
"outputs": [],
"source": [
"from typing import Dict, List\n",
"import json\n",
"\n",
"from langchain.embeddings import SagemakerEndpointEmbeddings\n",
"from langchain.llms.sagemaker_endpoint import ContentHandlerBase\n",
"\n",
"\n",
"class ContentHandler(ContentHandlerBase):\n",
"    content_type = \"application/json\"\n",
"    accepts = \"application/json\"\n",
"\n",
"    def transform_input(self, inputs: List[str], model_kwargs: Dict) -> bytes:\n",
"        # Serialize the batch of texts (plus any model kwargs) into the JSON request body\n",
"        input_str = json.dumps({\"inputs\": inputs, **model_kwargs})\n",
"        return input_str.encode(\"utf-8\")\n",
"\n",
"    def transform_output(self, output: bytes) -> List[List[float]]:\n",
"        # `output` is the streaming response body, so read it before decoding the JSON\n",
"        response_json = json.loads(output.read().decode(\"utf-8\"))\n",
"        return response_json[\"vectors\"]\n",
"\n",
"\n",
"content_handler = ContentHandler()\n",
"\n",
"embeddings = SagemakerEndpointEmbeddings(\n",
"    # Replace with the name of your own endpoint\n",
"    endpoint_name=\"huggingface-pytorch-inference-2023-03-21-16-14-03-834\",\n",
"    region_name=\"us-east-1\",\n",
"    # credentials_profile_name=\"credentials-profile-name\",\n",
"    content_handler=content_handler,\n",
")"
]
},
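{
"cell_type": "markdown",
"id": "e4a8f0b2",
"metadata": {},
"source": [
"Before calling the endpoint, you can sanity-check the JSON contract between the content handler and `inference.py` offline. The following cell is an illustrative sketch, not part of LangChain: `encode_request` and `decode_response` are hypothetical helpers that mirror `transform_input` and `transform_output` above, `fake_endpoint` is a hypothetical stand-in for the deployed model that returns dummy 2-dimensional vectors in the batched format, and `io.BytesIO` imitates the streaming response body (which is why `transform_output` calls `.read()`)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9f3d6b17",
"metadata": {},
"outputs": [],
"source": [
"import io\n",
"import json\n",
"from typing import Dict, List\n",
"\n",
"\n",
"def encode_request(inputs: List[str], model_kwargs: Dict) -> bytes:\n",
"    # Same serialization as ContentHandler.transform_input above\n",
"    return json.dumps({\"inputs\": inputs, **model_kwargs}).encode(\"utf-8\")\n",
"\n",
"\n",
"def fake_endpoint(body: bytes) -> io.BytesIO:\n",
"    # Hypothetical stand-in for the deployed model: one dummy 2-dim vector\n",
"    # per input text, returned in the batched {\"vectors\": [[...], ...]} format\n",
"    payload = json.loads(body.decode(\"utf-8\"))\n",
"    vectors = [[float(len(text)), 1.0] for text in payload[\"inputs\"]]\n",
"    # BytesIO mimics the streaming response body, which exposes .read()\n",
"    return io.BytesIO(json.dumps({\"vectors\": vectors}).encode(\"utf-8\"))\n",
"\n",
"\n",
"def decode_response(output: io.BytesIO) -> List[List[float]]:\n",
"    # Same parsing as ContentHandler.transform_output above\n",
"    return json.loads(output.read().decode(\"utf-8\"))[\"vectors\"]\n",
"\n",
"\n",
"vectors = decode_response(fake_endpoint(encode_request([\"foo\", \"hello\"], {})))\n",
"print(vectors)  # [[3.0, 1.0], [5.0, 1.0]]"
]
},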
{
"cell_type": "code",
"execution_count": null,
"id": "fe9797b8",
"metadata": {},
"outputs": [],
"source": [
"query_result = embeddings.embed_query(\"foo\")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "76f1b752",
"metadata": {},
"outputs": [],
"source": [
"doc_results = embeddings.embed_documents([\"foo\"])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fff99b21",
"metadata": {},
"outputs": [],
"source": [
"doc_results"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "aaad49f8",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
},
"vscode": {
"interpreter": {
"hash": "7377c2ccc78bc62c2683122d48c8cd1fb85a53850a1b1fc29736ed39852c9885"
}
}
},
"nbformat": 4,
"nbformat_minor": 5
}