muryshev committed
Commit 0dffae9 · 1 parent: e6e0df0
common/decorators.py ADDED
@@ -0,0 +1,8 @@
+
+def singleton(cls):
+    instances = {}
+    def get_instance(*args, **kwargs):
+        if cls not in instances:
+            instances[cls] = cls(*args, **kwargs)
+        return instances[cls]
+    return get_instance
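
The decorator caches one instance per decorated class: the first call constructs the object, every later call returns the cached one and ignores the new constructor arguments. A minimal usage sketch (illustrative only, not part of the commit; the Config class is hypothetical):

    from common.decorators import singleton

    @singleton
    class Config:
        def __init__(self, path: str):
            self.path = path

    a = Config("settings.json")
    b = Config("other.json")  # arguments ignored: the cached instance is returned
    assert a is b

Note that decorating a class this way rebinds its name to the get_instance function, so isinstance checks against the decorated name no longer work; callers only ever interact with the single cached instance.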
components/embedding_extraction.py CHANGED
@@ -6,23 +6,27 @@ import torch
 import torch.nn.functional as F
 from torch.utils.data import DataLoader
 from transformers import (AutoModel, AutoTokenizer, BatchEncoding,
-                          XLMRobertaModel)
+                          XLMRobertaModel, PreTrainedTokenizer, PreTrainedTokenizerFast)
 from transformers.modeling_outputs import \
     BaseModelOutputWithPoolingAndCrossAttentions as EncoderOutput
 
-logger = logging.getLogger(__name__)
+from common.decorators import singleton
 
+logger = logging.getLogger(__name__)
 
+@singleton
 class EmbeddingExtractor:
     """Processes a question text and returns its embedding."""
 
     def __init__(
         self,
-        model_id: str,
+        model_id: str | None,
         device: str | torch.device | None = None,
         batch_size: int = 1,
         do_normalization: bool = True,
         max_len: int = 510,
+        model: XLMRobertaModel = None,
+        tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast = None
     ):
         """
         A class combining the model, the tokenizer and the vectorization parameters.
@@ -33,6 +37,8 @@ class EmbeddingExtractor:
             batch_size: Batch size (default: 1).
             do_normalization: Whether to normalize the vectors (default: True).
             max_len: Maximum text length in tokens (default: 510).
+            model: An already loaded model instance.
+            tokenizer: An already loaded tokenizer instance.
         """
         if device is None:
             device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
@@ -40,11 +46,19 @@ class EmbeddingExtractor:
             device = torch.device(device)
 
         self.device = device
+
         # Model initialization
-        self.tokenizer = AutoTokenizer.from_pretrained(model_id, local_files_only=True)
-        self.model: XLMRobertaModel = AutoModel.from_pretrained(model_id, local_files_only=True).to(
-            self.device
-        )
+        if model is not None and tokenizer is not None:
+            self.tokenizer = tokenizer
+            self.model = model
+        elif model_id is not None:
+            print('EmbeddingExtractor: model loading '+model_id+' to '+str(self.device))
+            self.tokenizer = AutoTokenizer.from_pretrained(model_id, local_files_only=True)
+            self.model: XLMRobertaModel = AutoModel.from_pretrained(model_id, local_files_only=True).to(
+                self.device
+            )
+
+        print('EmbeddingExtractor: model loaded')
         self.model.eval()
         self.model.share_memory()
 
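
With this change the extractor can either load the model itself from model_id or reuse a model and tokenizer that are already in memory, and the @singleton decorator makes every later construction return the same cached instance. An illustrative sketch (not part of the commit; the model path is hypothetical):

    from components.embedding_extraction import EmbeddingExtractor

    # First call: loads tokenizer and weights from a local path (local_files_only=True).
    extractor = EmbeddingExtractor(model_id="/models/multilingual-e5-large", device="cpu")

    # Later calls return the cached instance; the new arguments are ignored.
    same = EmbeddingExtractor(model_id=None)
    assert same is extractor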
 
components/llm/deepinfra_api.py CHANGED
@@ -1,5 +1,5 @@
 import json
-from typing import Optional, List
+from typing import AsyncGenerator, Optional, List
 import httpx
 import logging
 from transformers import AutoTokenizer
@@ -286,7 +286,6 @@ class DeepInfraApi(LlmApi):
                     try:
                         # Parse the JSON payload from the line
                         data = json.loads(line[len("data: "):].strip())
-                        print(data)
                        if data == "[DONE]":  # End of the stream
                             break
                         if "choices" in data and data["choices"]:
@@ -298,6 +297,47 @@ class DeepInfraApi(LlmApi):
 
         return generated_text.strip()
 
+    async def get_predict_chat_generator(self, request: ChatRequest, system_prompt: str,
+                                         params: LlmPredictParams) -> AsyncGenerator[str, None]:
+        """
+        Performs a streaming request to the API and yields tokens as they are generated.
+
+        Args:
+            request (ChatRequest): Chat history.
+            system_prompt (str): System prompt.
+            params (LlmPredictParams): Prediction parameters.
+
+        Yields:
+            str: Tokens of the LLM response.
+        """
+        params
+        async with httpx.AsyncClient() as client:
+            request_data = self.create_chat_request(request, system_prompt, params)
+            request_data["stream"] = True
+
+            async with client.stream(
+                "POST",
+                f"{self.params.url}/v1/openai/chat/completions",
+                json=request_data,
+                headers=super().create_headers()
+            ) as response:
+                if response.status_code != 200:
+                    error_content = await response.aread()
+                    raise Exception(f"API error: {error_content.decode('utf-8')}")
+
+                async for line in response.aiter_lines():
+                    if line.startswith("data: "):
+                        try:
+                            data = json.loads(line[len("data: "):].strip())
+                            if data == "[DONE]":
+                                break
+                            if "choices" in data and data["choices"]:
+                                token_value = data["choices"][0].get("delta", {}).get("content", "")
+                                if token_value:
+                                    yield token_value
+                        except json.JSONDecodeError:
+                            continue
+
     async def predict(self, prompt: str, system_prompt: str) -> str:
         """
         Performs a request to the API and returns the result.
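
The new method exposes DeepInfra's streaming chat completion as an async generator, so callers can consume tokens as they arrive instead of waiting for the full reply. A consumption sketch (illustrative only; llm_api, chat_request, system_prompt and predict_params are assumed to be constructed the same way as in the existing routes):

    import asyncio

    async def stream_answer(llm_api, chat_request, system_prompt, predict_params) -> str:
        # Print tokens as they arrive and return the assembled answer.
        pieces = []
        async for token in llm_api.get_predict_chat_generator(chat_request, system_prompt, predict_params):
            print(token, end="", flush=True)
            pieces.append(token)
        return "".join(pieces)

    # asyncio.run(stream_answer(llm_api, chat_request, system_prompt.text, predict_params))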
routes/llm.py CHANGED
@@ -1,8 +1,11 @@
+import json
 import logging
 import os
-from typing import Annotated, Optional
+from typing import Annotated, AsyncGenerator, Optional
 from uuid import UUID
 
+from fastapi.responses import StreamingResponse
+
 from components.services.dataset import DatasetService
 from components.services.entity import EntityService
 from fastapi import APIRouter, Depends, HTTPException
@@ -42,6 +45,97 @@ llm_params = LlmParams(
 # TODO: move to DI
 llm_api = DeepInfraApi(params=llm_params)
 
+# TODO: move these helpers elsewhere
+def get_last_user_message(chat_request: ChatRequest) -> Optional[Message]:
+    return next(
+        (
+            msg
+            for msg in reversed(chat_request.history)
+            if msg.role == "user"
+            and (msg.searchResults is None or not msg.searchResults)
+        ),
+        None,
+    )
+
+def insert_search_results_to_message(
+    chat_request: ChatRequest, new_content: str
+) -> bool:
+    for msg in reversed(chat_request.history):
+        if msg.role == "user" and (
+            msg.searchResults is None or not msg.searchResults
+        ):
+            msg.content = new_content
+            return True
+    return False
+
+async def sse_generator(request: ChatRequest, llm_api: DeepInfraApi, system_prompt: str,
+                        predict_params: LlmPredictParams,
+                        dataset_service: DatasetService,
+                        entity_service: EntityService) -> AsyncGenerator[str, None]:
+    """
+    Generator that streams the LLM response via SSE.
+    """
+    # Search handling
+    last_query = get_last_user_message(request)
+    if last_query:
+        dataset = dataset_service.get_current_dataset()
+        if dataset is None:
+            raise HTTPException(status_code=400, detail="Dataset not found")
+        _, scores, chunk_ids = entity_service.search_similar(last_query.content, dataset.id)
+        chunks = entity_service.chunk_repository.get_chunks_by_ids(chunk_ids)
+        text_chunks = entity_service.build_text(chunks, scores)
+        search_results_event = {
+            "event": "search_results",
+            "data": f"\n<search-results>\n{text_chunks}\n</search-results>"
+        }
+        yield f"data: {json.dumps(search_results_event, ensure_ascii=False)}\n\n"
+
+        new_message = f'{last_query.content}\n<search-results>\n{text_chunks}\n</search-results>'
+        insert_search_results_to_message(request, new_message)
+
+    # Stream the response tokens
+    async for token in llm_api.get_predict_chat_generator(request, system_prompt, predict_params):
+        token_event = {"event": "token", "data": token}
+        logger.info(f"Streaming token: {token}")
+        yield f"data: {json.dumps(token_event, ensure_ascii=False)}\n\n"
+
+    # Final event
+    yield "data: {\"event\": \"done\"}\n\n"
+
+
+@router.post("/chat/stream")
+async def chat_stream(
+    request: ChatRequest,
+    config: Annotated[Configuration, Depends(DI.get_config)],
+    llm_api: Annotated[DeepInfraApi, Depends(DI.get_llm_service)],
+    prompt_service: Annotated[LlmPromptService, Depends(DI.get_llm_prompt_service)],
+    llm_config_service: Annotated[LLMConfigService, Depends(DI.get_llm_config_service)],
+    entity_service: Annotated[EntityService, Depends(DI.get_entity_service)],
+    dataset_service: Annotated[DatasetService, Depends(DI.get_dataset_service)],
+):
+    try:
+        p = llm_config_service.get_default()
+        system_prompt = prompt_service.get_default()
+
+        predict_params = LlmPredictParams(
+            temperature=p.temperature,
+            top_p=p.top_p,
+            min_p=p.min_p,
+            seed=p.seed,
+            frequency_penalty=p.frequency_penalty,
+            presence_penalty=p.presence_penalty,
+            n_predict=p.n_predict,
+            stop=[],
+        )
+
+        return StreamingResponse(
+            sse_generator(request, llm_api, system_prompt.text, predict_params, dataset_service, entity_service),
+            media_type="text/event-stream",
+            headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}
+        )
+    except Exception as e:
+        logger.error(f"Error in SSE chat stream: {str(e)}", stack_info=True)
+        raise HTTPException(status_code=500, detail=str(e))
 
 @router.post("/chat")
 async def chat(
@@ -68,29 +162,6 @@ async def chat(
             stop=[],
         )
 
-        # TODO: move these helpers elsewhere
-        def get_last_user_message(chat_request: ChatRequest) -> Optional[Message]:
-            return next(
-                (
-                    msg
-                    for msg in reversed(chat_request.history)
-                    if msg.role == "user"
-                    and (msg.searchResults is None or not msg.searchResults)
-                ),
-                None,
-            )
-
-        def insert_search_results_to_message(
-            chat_request: ChatRequest, new_content: str
-        ) -> bool:
-            for msg in reversed(chat_request.history):
-                if msg.role == "user" and (
-                    msg.searchResults is None or not msg.searchResults
-                ):
-                    msg.content = new_content
-                    return True
-            return False
-
         last_query = get_last_user_message(request)
         search_result = None
 
@@ -126,4 +197,4 @@ async def chat(
         logger.error(
             f"Error processing LLM request: {str(e)}", stack_info=True, stacklevel=10
         )
-        return {"error": str(e)}
+        return {"error": str(e)}
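
The new /chat/stream route returns a text/event-stream response in which every `data:` line carries a JSON object with an `event` field (`search_results`, then one `token` event per generated token, then `done`). A client-side consumption sketch (illustrative only; the base URL, route prefix and exact ChatRequest payload shape are assumptions):

    import asyncio
    import json
    import httpx

    async def consume_chat_stream(payload: dict) -> str:
        answer = []
        async with httpx.AsyncClient(timeout=None) as client:
            async with client.stream("POST", "http://localhost:8000/llm/chat/stream", json=payload) as response:
                async for line in response.aiter_lines():
                    if not line.startswith("data: "):
                        continue
                    event = json.loads(line[len("data: "):])
                    if event.get("event") == "done":
                        break
                    if event.get("event") == "token":
                        answer.append(event["data"])
        return "".join(answer)

    # Payload shape below is inferred from get_last_user_message and is an assumption:
    # asyncio.run(consume_chat_stream({"history": [{"role": "user", "content": "...", "searchResults": None}]}))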