muryshev's picture
update
0dffae9
raw
history blame
17.9 kB
import json
from typing import AsyncGenerator, Optional, List
import httpx
import logging
from transformers import AutoTokenizer
from components.llm.utils import convert_to_openai_format
from components.llm.common import ChatRequest, LlmParams, LlmApi, LlmPredictParams
logging.basicConfig(
level=logging.DEBUG,
format="%(asctime)s - %(message)s",
)
class DeepInfraApi(LlmApi):
"""
Класс для работы с API vllm.
"""
def __init__(self, params: LlmParams):
super().__init__()
super().set_params(params)
print('Tokenizer initialization.')
# self.tokenizer = AutoTokenizer.from_pretrained(params.tokenizer if params.tokenizer is not None else params.model)
print(f"Tokenizer initialized for model {params.model}.")
async def get_models(self) -> List[str]:
"""
Выполняет GET-запрос к API для получения списка доступных моделей.
Возвращает:
list[str]: Список идентификаторов моделей.
Если произошла ошибка или данные недоступны, возвращается пустой список.
Исключения:
Все ошибки HTTP-запросов логируются в консоль, но не выбрасываются дальше.
"""
try:
async with httpx.AsyncClient() as client:
response = await client.get(f"{self.params.url}/v1/openai/models", headers=super().create_headers())
if response.status_code == 200:
json_data = response.json()
return [item['id'] for item in json_data.get('data', [])]
except httpx.RequestError as error:
print('Error fetching models:', error)
return []
def create_messages(self, prompt: str, system_prompt: str = None) -> List[dict]:
"""
Создает сообщения для LLM на основе переданного промпта и системного промпта (если он задан).
Args:
prompt (str): Пользовательский промпт.
Returns:
list[dict]: Список сообщений с ролями и содержимым.
"""
actual_prompt = self.apply_llm_template_to_prompt(prompt)
messages = []
if system_prompt is not None:
messages.append({"role": "system", "content": system_prompt})
else:
if self.params.predict_params and self.params.predict_params.system_prompt:
messages.append({"role": "system", "content": self.params.predict_params.system_prompt})
messages.append({"role": "user", "content": actual_prompt})
return messages
def apply_llm_template_to_prompt(self, prompt: str) -> str:
"""
Применяет шаблон LLM к переданному промпту, если он задан.
Args:
prompt (str): Пользовательский промпт.
Returns:
str: Промпт с примененным шаблоном (или оригинальный, если шаблон отсутствует).
"""
actual_prompt = prompt
if self.params.template is not None:
actual_prompt = self.params.template.replace("{{PROMPT}}", actual_prompt)
return actual_prompt
async def tokenize(self, prompt: str) -> Optional[dict]:
"""
Токенизирует входной текстовый промпт.
Args:
prompt (str): Текст, который нужно токенизировать.
Returns:
dict: Словарь с токенами и их количеством или None в случае ошибки.
"""
try:
tokens = self.tokenizer.encode(prompt, add_special_tokens=True)
return {"result": tokens, "num_tokens": len(tokens), "max_length": self.params.context_length}
except Exception as e:
print(f"Tokenization error: {e}")
return None
async def detokenize(self, tokens: List[int]) -> Optional[str]:
"""
Детокенизирует список токенов обратно в строку.
Args:
tokens (List[int]): Список токенов, который нужно преобразовать в текст.
Returns:
str: Восстановленный текст или None в случае ошибки.
"""
try:
text = self.tokenizer.decode(tokens, skip_special_tokens=True)
return text
except Exception as e:
print(f"Detokenization error: {e}")
return None
def create_chat_request(self, chat_request: ChatRequest, system_prompt, params: LlmPredictParams) -> dict:
"""
Создает запрос для предсказания на основе параметров LLM.
Args:
prompt (str): Промпт для запроса.
Returns:
dict: Словарь с параметрами для выполнения запроса.
"""
request = {
"stream": False,
"model": self.params.model,
}
predict_params = params
if predict_params:
if predict_params.stop:
non_empty_stop = list(filter(lambda o: o != "", predict_params.stop))
if non_empty_stop:
request["stop"] = non_empty_stop
if predict_params.n_predict is not None:
request["max_tokens"] = int(predict_params.n_predict or 0)
request["temperature"] = float(predict_params.temperature or 0)
if predict_params.top_k is not None:
request["top_k"] = int(predict_params.top_k)
if predict_params.top_p is not None:
request["top_p"] = float(predict_params.top_p)
if predict_params.min_p is not None:
request["min_p"] = float(predict_params.min_p)
if predict_params.seed is not None:
request["seed"] = int(predict_params.seed)
if predict_params.n_keep is not None:
request["n_keep"] = int(predict_params.n_keep)
if predict_params.cache_prompt is not None:
request["cache_prompt"] = bool(predict_params.cache_prompt)
if predict_params.repeat_penalty is not None:
request["repetition_penalty"] = float(predict_params.repeat_penalty)
if predict_params.repeat_last_n is not None:
request["repeat_last_n"] = int(predict_params.repeat_last_n)
if predict_params.presence_penalty is not None:
request["presence_penalty"] = float(predict_params.presence_penalty)
if predict_params.frequency_penalty is not None:
request["frequency_penalty"] = float(predict_params.frequency_penalty)
request["messages"] = convert_to_openai_format(chat_request, system_prompt)
return request
async def create_request(self, prompt: str, system_prompt: str = None) -> dict:
"""
Создает запрос для предсказания на основе параметров LLM.
Args:
prompt (str): Промпт для запроса.
Returns:
dict: Словарь с параметрами для выполнения запроса.
"""
request = {
"stream": False,
"model": self.params.model,
}
predict_params = self.params.predict_params
if predict_params:
if predict_params.stop:
non_empty_stop = list(filter(lambda o: o != "", predict_params.stop))
if non_empty_stop:
request["stop"] = non_empty_stop
if predict_params.n_predict is not None:
request["max_tokens"] = int(predict_params.n_predict or 0)
request["temperature"] = float(predict_params.temperature or 0)
if predict_params.top_k is not None:
request["top_k"] = int(predict_params.top_k)
if predict_params.top_p is not None:
request["top_p"] = float(predict_params.top_p)
if predict_params.min_p is not None:
request["min_p"] = float(predict_params.min_p)
if predict_params.seed is not None:
request["seed"] = int(predict_params.seed)
if predict_params.n_keep is not None:
request["n_keep"] = int(predict_params.n_keep)
if predict_params.cache_prompt is not None:
request["cache_prompt"] = bool(predict_params.cache_prompt)
if predict_params.repeat_penalty is not None:
request["repetition_penalty"] = float(predict_params.repeat_penalty)
if predict_params.repeat_last_n is not None:
request["repeat_last_n"] = int(predict_params.repeat_last_n)
if predict_params.presence_penalty is not None:
request["presence_penalty"] = float(predict_params.presence_penalty)
if predict_params.frequency_penalty is not None:
request["frequency_penalty"] = float(predict_params.frequency_penalty)
request["messages"] = self.create_messages(prompt, system_prompt)
return request
async def trim_sources(self, sources: str, user_request: str, system_prompt: str = None) -> dict:
raise NotImplementedError("This function is not supported.")
async def predict_chat(self, request: ChatRequest, system_prompt, params: LlmPredictParams) -> str:
"""
Выполняет запрос к API и возвращает результат.
Args:
prompt (str): Входной текст для предсказания.
Returns:
str: Сгенерированный текст.
"""
async with httpx.AsyncClient() as client:
request = self.create_chat_request(request, system_prompt, params)
response = await client.post(f"{self.params.url}/v1/openai/chat/completions", headers=super().create_headers(), json=request, timeout=httpx.Timeout(connect=5.0, read=60.0, write=180, pool=10))
if response.status_code == 200:
return response.json()["choices"][0]["message"]["content"]
else:
logging.error(f"Request failed: status code {response.status_code}")
logging.error(response.text)
async def predict_chat_stream(self, request: ChatRequest, system_prompt, params: LlmPredictParams) -> str:
"""
Выполняет запрос к API с поддержкой потокового вывода (SSE) и возвращает результат.
Args:
prompt (str): Входной текст для предсказания.
Returns:
str: Сгенерированный текст.
"""
async with httpx.AsyncClient() as client:
request = self.create_chat_request(request, system_prompt, params)
request["stream"] = True
print(super().create_headers())
async with client.stream("POST", f"{self.params.url}/v1/openai/chat/completions", json=request, headers=super().create_headers()) as response:
if response.status_code != 200:
# Если ошибка, читаем ответ для получения подробностей
error_content = await response.aread()
raise Exception(f"API error: {error_content.decode('utf-8')}")
# Для хранения результата
generated_text = ""
# Асинхронное чтение построчно
async for line in response.aiter_lines():
if line.startswith("data: "): # SSE-сообщения начинаются с "data: "
try:
# Парсим JSON из строки
data = json.loads(line[len("data: "):].strip())
if data == "[DONE]": # Конец потока
break
if "choices" in data and data["choices"]:
# Получаем текст из текущего токена
token_value = data["choices"][0].get("delta", {}).get("content", "")
generated_text += token_value
except json.JSONDecodeError:
continue # Игнорируем строки, которые не удается декодировать
return generated_text.strip()
async def get_predict_chat_generator(self, request: ChatRequest, system_prompt: str,
params: LlmPredictParams) -> AsyncGenerator[str, None]:
"""
Выполняет потоковый запрос к API и возвращает токены по мере их генерации.
Args:
request (ChatRequest): История чата.
system_prompt (str): Системный промпт.
params (LlmPredictParams): Параметры предсказания.
Yields:
str: Токены ответа LLM.
"""
params
async with httpx.AsyncClient() as client:
request_data = self.create_chat_request(request, system_prompt, params)
request_data["stream"] = True
async with client.stream(
"POST",
f"{self.params.url}/v1/openai/chat/completions",
json=request_data,
headers=super().create_headers()
) as response:
if response.status_code != 200:
error_content = await response.aread()
raise Exception(f"API error: {error_content.decode('utf-8')}")
async for line in response.aiter_lines():
if line.startswith("data: "):
try:
data = json.loads(line[len("data: "):].strip())
if data == "[DONE]":
break
if "choices" in data and data["choices"]:
token_value = data["choices"][0].get("delta", {}).get("content", "")
if token_value:
yield token_value
except json.JSONDecodeError:
continue
async def predict(self, prompt: str, system_prompt: str) -> str:
"""
Выполняет запрос к API и возвращает результат.
Args:
prompt (str): Входной текст для предсказания.
Returns:
str: Сгенерированный текст.
"""
async with httpx.AsyncClient() as client:
request = await self.create_request(prompt, system_prompt)
response = await client.post(f"{self.params.url}/v1/openai/chat/completions", headers=super().create_headers(), json=request, timeout=httpx.Timeout(connect=5.0, read=60.0, write=180, pool=10))
if response.status_code == 200:
return response.json()["choices"][0]["message"]["content"]
else:
logging.info(f"Request {prompt} failed: status code {response.status_code}")
logging.info(response.text)
async def trim_prompt(self, prompt: str, system_prompt: str = None):
result = await self.tokenize(prompt)
result_system = None
system_prompt_length = 0
if system_prompt is not None:
result_system = await self.tokenize(system_prompt)
if result_system is not None:
system_prompt_length = len(result_system["result"])
# в случае ошибки при токенизации, вернем исходную строку безопасной длины
if result["result"] is None or (system_prompt is not None and result_system is None):
return prompt[int(self.params.context_length / 3)]
#вероятно, часть уходит на форматирование чата, надо проверить
max_length = result["max_length"] - len(result["result"]) - system_prompt_length - self.params.predict_params.n_predict
detokenized_str = await self.detokenize(result["result"][:max_length])
# в случае ошибки при детокенизации, вернем исходную строку безопасной длины
if detokenized_str is None:
return prompt[self.params.context_length / 3]
return detokenized_str