# mypy: ignore-errors
from __future__ import annotations

import io
import json
import logging
from typing import TYPE_CHECKING, Any

import boto3  # type: ignore
from llama_index.bridge.pydantic import Field
from llama_index.llms import (
    CompletionResponse,
    CustomLLM,
    LLMMetadata,
)
from llama_index.llms.base import (
    llm_chat_callback,
    llm_completion_callback,
)
from llama_index.llms.generic_utils import (
    completion_response_to_chat_response,
    stream_completion_response_to_chat_response,
)
from llama_index.llms.llama_utils import (
    completion_to_prompt as generic_completion_to_prompt,
)
from llama_index.llms.llama_utils import (
    messages_to_prompt as generic_messages_to_prompt,
)

if TYPE_CHECKING:
    from collections.abc import Sequence

    from llama_index.callbacks import CallbackManager
    from llama_index.llms import (
        ChatMessage,
        ChatResponse,
        ChatResponseGen,
        CompletionResponseGen,
    )

logger = logging.getLogger(__name__)


class LineIterator:
    r"""A helper class for parsing the byte stream input from a TGI container.

    The output of the model will be in the following format:
    ```
    b'data:{"token": {"text": " a"}}\n\n'
    b'data:{"token": {"text": " challenging"}}\n\n'
    b'data:{"token": {"text": " problem"
    b'}}'
    ...
    ```

    While usually each PayloadPart event from the event stream will contain a byte array
    with a full json, this is not guaranteed and some of the json objects may be split
    across PayloadPart events. For example:
    ```
    {'PayloadPart': {'Bytes': b'{"outputs": '}}
    {'PayloadPart': {'Bytes': b'[" problem"]}\n'}}
    ```

    This class accounts for this by concatenating the bytes received from each
    PayloadPart event in an internal buffer and only yielding complete lines
    (ending with a '\n' character) from '__next__'. It keeps track of the last
    read position so that previous bytes are not exposed again, and any pending
    line that does not yet end with a '\n' is kept in the buffer until the rest
    of it arrives and can be concatenated.
    """

    def __init__(self, stream: Any) -> None:
        """Line iterator initializer."""
        self.byte_iterator = iter(stream)
        self.buffer = io.BytesIO()
        self.read_pos = 0

    def __iter__(self) -> Any:
        """Self iterator."""
        return self

    def __next__(self) -> Any:
        """Next element from iterator."""
        while True:
            self.buffer.seek(self.read_pos)
            line = self.buffer.readline()
            if line and line[-1] == ord("\n"):
                self.read_pos += len(line)
                return line[:-1]
            try:
                chunk = next(self.byte_iterator)
            except StopIteration:
                if self.read_pos < self.buffer.getbuffer().nbytes:
                    continue
                raise
            if "PayloadPart" not in chunk:
                logger.warning("Unknown event type=%s", chunk)
                continue
            self.buffer.seek(0, io.SEEK_END)
            self.buffer.write(chunk["PayloadPart"]["Bytes"])
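

# Illustrative sketch only (not used by the module): LineIterator accepts any iterable
# of event-stream chunks shaped like the boto3 response stream, so the hand-written
# fake chunks below work too. It shows how a JSON payload split across two PayloadPart
# events is re-assembled into a single '\n'-terminated line.
#
#   fake_stream = [
#       {"PayloadPart": {"Bytes": b'data:{"token": {"text": " a"}}\n'}},
#       {"PayloadPart": {"Bytes": b'data:{"token": '}},
#       {"PayloadPart": {"Bytes": b'{"text": " b"}}\n'}},
#   ]
#   for line in LineIterator(fake_stream):
#       print(line)  # b'data:{"token": {"text": " a"}}', then the re-joined second line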


class SagemakerLLM(CustomLLM):
    """Sagemaker Inference Endpoint models.

    To use, you must supply the endpoint name from your deployed
    Sagemaker model & the region where it is deployed.

    To authenticate, the AWS client uses the following methods to
    automatically load credentials:
    https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html

    If a specific credential profile should be used, you must pass
    the name of the profile from the ~/.aws/credentials file that is to be used.

    Make sure the credentials / roles used have the required policies to
    access the Sagemaker endpoint.
    See: https://docs.aws.amazon.com/IAM/latest/UserGuide/access_policies.html
    """

    endpoint_name: str = Field(
        description="The name of the deployed Sagemaker inference endpoint to invoke."
    )
    temperature: float = Field(description="The temperature to use for sampling.")
    max_new_tokens: int = Field(description="The maximum number of tokens to generate.")
    context_window: int = Field(
        description="The maximum number of context tokens for the model."
    )
    messages_to_prompt: Any = Field(
        description="The function to convert messages to a prompt.", exclude=True
    )
    completion_to_prompt: Any = Field(
        description="The function to convert a completion to a prompt.", exclude=True
    )
    generate_kwargs: dict[str, Any] = Field(
        default_factory=dict, description="Kwargs used for generation."
    )
    model_kwargs: dict[str, Any] = Field(
        default_factory=dict, description="Kwargs used for model initialization."
    )
    verbose: bool = Field(description="Whether to print verbose output.")

    _boto_client: Any = boto3.client(
        "sagemaker-runtime",
    )  # TODO make it an optional field
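
    # Sketch (assumption, not wired in yet): if the named credentials profile or region
    # mentioned in the class docstring needs to be set explicitly, the default client
    # above could be replaced with a session-scoped one. `profile_name` and
    # `region_name` below are illustrative values, not module configuration.
    #
    #   _boto_client: Any = boto3.Session(
    #       profile_name="my-profile", region_name="us-east-1"
    #   ).client("sagemaker-runtime")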

    def __init__(
        self,
        endpoint_name: str | None = "",
        temperature: float = 0.1,
        max_new_tokens: int = 512,  # to review defaults
        context_window: int = 2048,  # to review defaults
        messages_to_prompt: Any = None,
        completion_to_prompt: Any = None,
        callback_manager: CallbackManager | None = None,
        generate_kwargs: dict[str, Any] | None = None,
        model_kwargs: dict[str, Any] | None = None,
        verbose: bool = True,
    ) -> None:
        """SagemakerLLM initializer."""
        model_kwargs = model_kwargs or {}
        model_kwargs.update({"n_ctx": context_window, "verbose": verbose})
        messages_to_prompt = messages_to_prompt or generic_messages_to_prompt
        completion_to_prompt = completion_to_prompt or generic_completion_to_prompt
        generate_kwargs = generate_kwargs or {}
        generate_kwargs.update(
            {"temperature": temperature, "max_tokens": max_new_tokens}
        )

        super().__init__(
            endpoint_name=endpoint_name,
            temperature=temperature,
            context_window=context_window,
            max_new_tokens=max_new_tokens,
            messages_to_prompt=messages_to_prompt,
            completion_to_prompt=completion_to_prompt,
            callback_manager=callback_manager,
            generate_kwargs=generate_kwargs,
            model_kwargs=model_kwargs,
            verbose=verbose,
        )

    @property
    def inference_params(self):
        # TODO expose the rest of the params
        return {
            "do_sample": True,
            "top_p": 0.7,
            "temperature": self.temperature,
            "top_k": 50,
            "max_new_tokens": self.max_new_tokens,
        }

    @property
    def metadata(self) -> LLMMetadata:
        """Get LLM metadata."""
        return LLMMetadata(
            context_window=self.context_window,
            num_output=self.max_new_tokens,
            model_name="Sagemaker LLama 2",
        )

    @llm_completion_callback()
    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        self.generate_kwargs.update({"stream": False})

        is_formatted = kwargs.pop("formatted", False)
        if not is_formatted:
            prompt = self.completion_to_prompt(prompt)

        request_params = {
            "inputs": prompt,
            "stream": False,
            "parameters": self.inference_params,
        }

        resp = self._boto_client.invoke_endpoint(
            EndpointName=self.endpoint_name,
            Body=json.dumps(request_params),
            ContentType="application/json",
        )

        response_body = resp["Body"]
        response_str = response_body.read().decode("utf-8")
        # Parse the JSON body rather than eval-ing the raw response string.
        response_dict = json.loads(response_str)

        return CompletionResponse(
            text=response_dict[0]["generated_text"][len(prompt) :], raw=resp
        )
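
    # For reference (assumed TGI-style contract, inferred from the parsing above and
    # not verified against a specific container version): the non-streaming request
    # and response bodies are expected to look roughly like this.
    #
    #   request:  {"inputs": "<prompt>", "stream": false,
    #              "parameters": {"do_sample": true, "top_p": 0.7, "temperature": 0.1,
    #                             "top_k": 50, "max_new_tokens": 512}}
    #   response: [{"generated_text": "<prompt><completion>"}]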

    @llm_completion_callback()
    def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
        def get_stream():
            text = ""
            request_params = {
                "inputs": prompt,
                "stream": True,
                "parameters": self.inference_params,
            }
            resp = self._boto_client.invoke_endpoint_with_response_stream(
                EndpointName=self.endpoint_name,
                Body=json.dumps(request_params),
                ContentType="application/json",
            )
            event_stream = resp["Body"]
            start_json = b"{"
            stop_token = "<|endoftext|>"

            for line in LineIterator(event_stream):
                if line != b"" and start_json in line:
                    data = json.loads(line[line.find(start_json) :].decode("utf-8"))
                    if data["token"]["text"] != stop_token:
                        delta = data["token"]["text"]
                        text += delta
                        yield CompletionResponse(delta=delta, text=text, raw=data)

        return get_stream()

    @llm_chat_callback()
    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
        prompt = self.messages_to_prompt(messages)
        completion_response = self.complete(prompt, formatted=True, **kwargs)
        return completion_response_to_chat_response(completion_response)

    @llm_chat_callback()
    def stream_chat(
        self, messages: Sequence[ChatMessage], **kwargs: Any
    ) -> ChatResponseGen:
        prompt = self.messages_to_prompt(messages)
        completion_response = self.stream_complete(prompt, formatted=True, **kwargs)
        return stream_completion_response_to_chat_response(completion_response)
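

# Usage sketch (assumptions: "my-llama2-endpoint" is a hypothetical endpoint name and
# the ambient AWS credentials are allowed to invoke it; this block is illustrative and
# not executed by the module):
#
#   llm = SagemakerLLM(endpoint_name="my-llama2-endpoint")
#   print(llm.complete("What is SageMaker?").text)
#   for chunk in llm.stream_complete("What is SageMaker?"):
#       print(chunk.delta, end="", flush=True)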