Spaces:
Runtime error
Runtime error
| import json | |
| from typing import AsyncGenerator, Optional, List | |
| import httpx | |
| import logging | |
| from transformers import AutoTokenizer | |
| from components.llm.utils import convert_to_openai_format | |
| from components.llm.common import ChatRequest, LlmParams, LlmApi, LlmPredictParams | |
| logging.basicConfig( | |
| level=logging.DEBUG, | |
| format="%(asctime)s - %(message)s", | |
| ) | |
| class DeepInfraApi(LlmApi): | |
| """ | |
| Класс для работы с API vllm. | |
| """ | |
| def __init__(self, params: LlmParams): | |
| super().__init__() | |
| super().set_params(params) | |
| print('Tokenizer initialization.') | |
| # self.tokenizer = AutoTokenizer.from_pretrained(params.tokenizer if params.tokenizer is not None else params.model) | |
| print(f"Tokenizer initialized for model {params.model}.") | |
| async def get_models(self) -> List[str]: | |
| """ | |
| Выполняет GET-запрос к API для получения списка доступных моделей. | |
| Возвращает: | |
| list[str]: Список идентификаторов моделей. | |
| Если произошла ошибка или данные недоступны, возвращается пустой список. | |
| Исключения: | |
| Все ошибки HTTP-запросов логируются в консоль, но не выбрасываются дальше. | |
| """ | |
| try: | |
| async with httpx.AsyncClient() as client: | |
| response = await client.get(f"{self.params.url}/v1/openai/models", headers=super().create_headers()) | |
| if response.status_code == 200: | |
| json_data = response.json() | |
| return [item['id'] for item in json_data.get('data', [])] | |
| except httpx.RequestError as error: | |
| print('Error fetching models:', error) | |
| return [] | |
| def create_messages(self, prompt: str, system_prompt: str = None) -> List[dict]: | |
| """ | |
| Создает сообщения для LLM на основе переданного промпта и системного промпта (если он задан). | |
| Args: | |
| prompt (str): Пользовательский промпт. | |
| Returns: | |
| list[dict]: Список сообщений с ролями и содержимым. | |
| """ | |
| actual_prompt = self.apply_llm_template_to_prompt(prompt) | |
| messages = [] | |
| if system_prompt is not None: | |
| messages.append({"role": "system", "content": system_prompt}) | |
| else: | |
| if self.params.predict_params and self.params.predict_params.system_prompt: | |
| messages.append({"role": "system", "content": self.params.predict_params.system_prompt}) | |
| messages.append({"role": "user", "content": actual_prompt}) | |
| return messages | |
| def apply_llm_template_to_prompt(self, prompt: str) -> str: | |
| """ | |
| Применяет шаблон LLM к переданному промпту, если он задан. | |
| Args: | |
| prompt (str): Пользовательский промпт. | |
| Returns: | |
| str: Промпт с примененным шаблоном (или оригинальный, если шаблон отсутствует). | |
| """ | |
| actual_prompt = prompt | |
| if self.params.template is not None: | |
| actual_prompt = self.params.template.replace("{{PROMPT}}", actual_prompt) | |
| return actual_prompt | |
| async def tokenize(self, prompt: str) -> Optional[dict]: | |
| """ | |
| Токенизирует входной текстовый промпт. | |
| Args: | |
| prompt (str): Текст, который нужно токенизировать. | |
| Returns: | |
| dict: Словарь с токенами и их количеством или None в случае ошибки. | |
| """ | |
| try: | |
| tokens = self.tokenizer.encode(prompt, add_special_tokens=True) | |
| return {"result": tokens, "num_tokens": len(tokens), "max_length": self.params.context_length} | |
| except Exception as e: | |
| print(f"Tokenization error: {e}") | |
| return None | |
| async def detokenize(self, tokens: List[int]) -> Optional[str]: | |
| """ | |
| Детокенизирует список токенов обратно в строку. | |
| Args: | |
| tokens (List[int]): Список токенов, который нужно преобразовать в текст. | |
| Returns: | |
| str: Восстановленный текст или None в случае ошибки. | |
| """ | |
| try: | |
| text = self.tokenizer.decode(tokens, skip_special_tokens=True) | |
| return text | |
| except Exception as e: | |
| print(f"Detokenization error: {e}") | |
| return None | |
| def create_chat_request(self, chat_request: ChatRequest, system_prompt, params: LlmPredictParams) -> dict: | |
| """ | |
| Создает запрос для предсказания на основе параметров LLM. | |
| Args: | |
| prompt (str): Промпт для запроса. | |
| Returns: | |
| dict: Словарь с параметрами для выполнения запроса. | |
| """ | |
| request = { | |
| "stream": False, | |
| "model": self.params.model, | |
| } | |
| predict_params = params | |
| if predict_params: | |
| if predict_params.stop: | |
| non_empty_stop = list(filter(lambda o: o != "", predict_params.stop)) | |
| if non_empty_stop: | |
| request["stop"] = non_empty_stop | |
| if predict_params.n_predict is not None: | |
| request["max_tokens"] = int(predict_params.n_predict or 0) | |
| request["temperature"] = float(predict_params.temperature or 0) | |
| if predict_params.top_k is not None: | |
| request["top_k"] = int(predict_params.top_k) | |
| if predict_params.top_p is not None: | |
| request["top_p"] = float(predict_params.top_p) | |
| if predict_params.min_p is not None: | |
| request["min_p"] = float(predict_params.min_p) | |
| if predict_params.seed is not None: | |
| request["seed"] = int(predict_params.seed) | |
| if predict_params.n_keep is not None: | |
| request["n_keep"] = int(predict_params.n_keep) | |
| if predict_params.cache_prompt is not None: | |
| request["cache_prompt"] = bool(predict_params.cache_prompt) | |
| if predict_params.repeat_penalty is not None: | |
| request["repetition_penalty"] = float(predict_params.repeat_penalty) | |
| if predict_params.repeat_last_n is not None: | |
| request["repeat_last_n"] = int(predict_params.repeat_last_n) | |
| if predict_params.presence_penalty is not None: | |
| request["presence_penalty"] = float(predict_params.presence_penalty) | |
| if predict_params.frequency_penalty is not None: | |
| request["frequency_penalty"] = float(predict_params.frequency_penalty) | |
| request["messages"] = convert_to_openai_format(chat_request, system_prompt) | |
| return request | |
| async def create_request(self, prompt: str, system_prompt: str = None) -> dict: | |
| """ | |
| Создает запрос для предсказания на основе параметров LLM. | |
| Args: | |
| prompt (str): Промпт для запроса. | |
| Returns: | |
| dict: Словарь с параметрами для выполнения запроса. | |
| """ | |
| request = { | |
| "stream": False, | |
| "model": self.params.model, | |
| } | |
| predict_params = self.params.predict_params | |
| if predict_params: | |
| if predict_params.stop: | |
| non_empty_stop = list(filter(lambda o: o != "", predict_params.stop)) | |
| if non_empty_stop: | |
| request["stop"] = non_empty_stop | |
| if predict_params.n_predict is not None: | |
| request["max_tokens"] = int(predict_params.n_predict or 0) | |
| request["temperature"] = float(predict_params.temperature or 0) | |
| if predict_params.top_k is not None: | |
| request["top_k"] = int(predict_params.top_k) | |
| if predict_params.top_p is not None: | |
| request["top_p"] = float(predict_params.top_p) | |
| if predict_params.min_p is not None: | |
| request["min_p"] = float(predict_params.min_p) | |
| if predict_params.seed is not None: | |
| request["seed"] = int(predict_params.seed) | |
| if predict_params.n_keep is not None: | |
| request["n_keep"] = int(predict_params.n_keep) | |
| if predict_params.cache_prompt is not None: | |
| request["cache_prompt"] = bool(predict_params.cache_prompt) | |
| if predict_params.repeat_penalty is not None: | |
| request["repetition_penalty"] = float(predict_params.repeat_penalty) | |
| if predict_params.repeat_last_n is not None: | |
| request["repeat_last_n"] = int(predict_params.repeat_last_n) | |
| if predict_params.presence_penalty is not None: | |
| request["presence_penalty"] = float(predict_params.presence_penalty) | |
| if predict_params.frequency_penalty is not None: | |
| request["frequency_penalty"] = float(predict_params.frequency_penalty) | |
| request["messages"] = self.create_messages(prompt, system_prompt) | |
| return request | |
| async def trim_sources(self, sources: str, user_request: str, system_prompt: str = None) -> dict: | |
| raise NotImplementedError("This function is not supported.") | |
| async def predict_chat(self, request: ChatRequest, system_prompt, params: LlmPredictParams) -> str: | |
| """ | |
| Выполняет запрос к API и возвращает результат. | |
| Args: | |
| prompt (str): Входной текст для предсказания. | |
| Returns: | |
| str: Сгенерированный текст. | |
| """ | |
| async with httpx.AsyncClient() as client: | |
| request = self.create_chat_request(request, system_prompt, params) | |
| response = await client.post(f"{self.params.url}/v1/openai/chat/completions", headers=super().create_headers(), json=request, timeout=httpx.Timeout(connect=5.0, read=60.0, write=180, pool=10)) | |
| if response.status_code == 200: | |
| return response.json()["choices"][0]["message"]["content"] | |
| else: | |
| logging.error(f"Request failed: status code {response.status_code}") | |
| logging.error(response.text) | |
| async def predict_chat_stream(self, request: ChatRequest, system_prompt, params: LlmPredictParams) -> str: | |
| """ | |
| Выполняет запрос к API с поддержкой потокового вывода (SSE) и возвращает результат. | |
| Args: | |
| prompt (str): Входной текст для предсказания. | |
| Returns: | |
| str: Сгенерированный текст. | |
| """ | |
| async with httpx.AsyncClient() as client: | |
| request = self.create_chat_request(request, system_prompt, params) | |
| request["stream"] = True | |
| print(super().create_headers()) | |
| async with client.stream("POST", f"{self.params.url}/v1/openai/chat/completions", json=request, headers=super().create_headers()) as response: | |
| if response.status_code != 200: | |
| # Если ошибка, читаем ответ для получения подробностей | |
| error_content = await response.aread() | |
| raise Exception(f"API error: {error_content.decode('utf-8')}") | |
| # Для хранения результата | |
| generated_text = "" | |
| # Асинхронное чтение построчно | |
| async for line in response.aiter_lines(): | |
| if line.startswith("data: "): # SSE-сообщения начинаются с "data: " | |
| try: | |
| # Парсим JSON из строки | |
| data = json.loads(line[len("data: "):].strip()) | |
| if data == "[DONE]": # Конец потока | |
| break | |
| if "choices" in data and data["choices"]: | |
| # Получаем текст из текущего токена | |
| token_value = data["choices"][0].get("delta", {}).get("content", "") | |
| generated_text += token_value | |
| except json.JSONDecodeError: | |
| continue # Игнорируем строки, которые не удается декодировать | |
| return generated_text.strip() | |
| async def get_predict_chat_generator(self, request: ChatRequest, system_prompt: str, | |
| params: LlmPredictParams) -> AsyncGenerator[str, None]: | |
| """ | |
| Выполняет потоковый запрос к API и возвращает токены по мере их генерации. | |
| Args: | |
| request (ChatRequest): История чата. | |
| system_prompt (str): Системный промпт. | |
| params (LlmPredictParams): Параметры предсказания. | |
| Yields: | |
| str: Токены ответа LLM. | |
| """ | |
| params | |
| async with httpx.AsyncClient() as client: | |
| request_data = self.create_chat_request(request, system_prompt, params) | |
| request_data["stream"] = True | |
| async with client.stream( | |
| "POST", | |
| f"{self.params.url}/v1/openai/chat/completions", | |
| json=request_data, | |
| headers=super().create_headers() | |
| ) as response: | |
| if response.status_code != 200: | |
| error_content = await response.aread() | |
| raise Exception(f"API error: {error_content.decode('utf-8')}") | |
| async for line in response.aiter_lines(): | |
| if line.startswith("data: "): | |
| try: | |
| data = json.loads(line[len("data: "):].strip()) | |
| if data == "[DONE]": | |
| break | |
| if "choices" in data and data["choices"]: | |
| token_value = data["choices"][0].get("delta", {}).get("content", "") | |
| if token_value: | |
| yield token_value | |
| except json.JSONDecodeError: | |
| continue | |
| async def predict(self, prompt: str, system_prompt: str) -> str: | |
| """ | |
| Выполняет запрос к API и возвращает результат. | |
| Args: | |
| prompt (str): Входной текст для предсказания. | |
| Returns: | |
| str: Сгенерированный текст. | |
| """ | |
| async with httpx.AsyncClient() as client: | |
| request = await self.create_request(prompt, system_prompt) | |
| response = await client.post(f"{self.params.url}/v1/openai/chat/completions", headers=super().create_headers(), json=request, timeout=httpx.Timeout(connect=5.0, read=60.0, write=180, pool=10)) | |
| if response.status_code == 200: | |
| return response.json()["choices"][0]["message"]["content"] | |
| else: | |
| logging.info(f"Request {prompt} failed: status code {response.status_code}") | |
| logging.info(response.text) | |
| async def trim_prompt(self, prompt: str, system_prompt: str = None): | |
| result = await self.tokenize(prompt) | |
| result_system = None | |
| system_prompt_length = 0 | |
| if system_prompt is not None: | |
| result_system = await self.tokenize(system_prompt) | |
| if result_system is not None: | |
| system_prompt_length = len(result_system["result"]) | |
| # в случае ошибки при токенизации, вернем исходную строку безопасной длины | |
| if result["result"] is None or (system_prompt is not None and result_system is None): | |
| return prompt[int(self.params.context_length / 3)] | |
| #вероятно, часть уходит на форматирование чата, надо проверить | |
| max_length = result["max_length"] - len(result["result"]) - system_prompt_length - self.params.predict_params.n_predict | |
| detokenized_str = await self.detokenize(result["result"][:max_length]) | |
| # в случае ошибки при детокенизации, вернем исходную строку безопасной длины | |
| if detokenized_str is None: | |
| return prompt[self.params.context_length / 3] | |
| return detokenized_str | |