muryshev's picture
update
308de05
raw
history blame
3.4 kB
import logging
from ntr_fileparser import ParsedDocument
from ntr_text_fragmentation import (
ChunkingStrategy,
LinkerEntity,
register_chunking_strategy,
register_entity,
DocumentAsEntity,
Chunk,
)
from components.llm.common import LlmPredictParams
from components.llm.deepinfra_api import DeepInfraApi
from components.llm.prompts import PROMPT_APPENDICES
from components.services.llm_config import LLMConfigService
# Module-scoped logger shared by everything in this file.
logger = logging.getLogger(__name__)
# Registry name of the appendices chunking strategy; the same value is stored
# as the ``groupper`` of every Appendix entity the strategy produces.
APPENDICES_CHUNKER = "appendices"
@register_entity
class Appendix(Chunk):
    """Entity that stores the appendices extracted from a document."""
@register_chunking_strategy(APPENDICES_CHUNKER)
class AppendicesProcessor(ChunkingStrategy):
    """Chunking strategy that asks an LLM to extract a document's appendices.

    The document's name, paragraphs and tables are substituted into
    ``PROMPT_APPENDICES``. The bracketed part of the LLM reply becomes a
    single :class:`Appendix` entity owned by the document. Only the
    asynchronous API (:meth:`chunk_async`) is supported.
    """

    def __init__(
        self,
        llm_api: DeepInfraApi,
        llm_config_service: LLMConfigService,
    ):
        """Store the LLM client and snapshot the default generation parameters.

        Args:
            llm_api: Client used to send the extraction prompt.
            llm_config_service: Source of the default LLM configuration. It
                is read once here, so later configuration changes are not
                picked up by this instance.
        """
        self.prompt = PROMPT_APPENDICES
        self.llm_api = llm_api
        p = llm_config_service.get_default()
        # NOTE(review): these parameters are built but never passed to
        # ``self.llm_api.predict`` in ``chunk_async``. Confirm whether
        # predict() is supposed to receive them.
        self.llm_params = LlmPredictParams(
            temperature=p.temperature,
            top_p=p.top_p,
            min_p=p.min_p,
            seed=p.seed,
            frequency_penalty=p.frequency_penalty,
            presence_penalty=p.presence_penalty,
            n_predict=p.n_predict,
        )

    def chunk(
        self, document: ParsedDocument, doc_entity: DocumentAsEntity
    ) -> list[LinkerEntity]:
        """Synchronous chunking is not supported by this strategy.

        Raises:
            NotImplementedError: Always. Use :meth:`chunk_async` instead.
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} поддерживает только асинхронный вызов. "
            "Используйте метод chunk_async или другую стратегию."
        )

    async def chunk_async(
        self, document: ParsedDocument, doc_entity: DocumentAsEntity
    ) -> list[LinkerEntity]:
        """Extract the document's appendices with a single LLM call.

        Args:
            document: Parsed document whose name, paragraph texts and table
                strings are sent to the LLM.
            doc_entity: Entity of the document. Its ``id`` becomes the
                ``owner_id`` of the created appendix.

        Returns:
            A one-element list containing the :class:`Appendix` entity. The
            list is empty when the LLM reply is missing, malformed or blank,
            or when it marks the appendix as useless.
        """
        text = self._collect_text(document)
        prompt = self._format_prompt(text)
        response = await self.llm_api.predict(prompt=prompt, system_prompt=None)
        processed = self._postprocess_llm_response(response)
        if processed is None:
            return []
        entity = Appendix(
            text=processed,
            in_search_text=processed,
            number_in_relation=0,
            groupper=APPENDICES_CHUNKER,
        )
        entity.owner_id = doc_entity.id
        return [entity]

    @staticmethod
    def _collect_text(document: ParsedDocument) -> str:
        """Join the document name, paragraph texts and table strings, one per line."""
        # A single flat join guarantees a newline between every pair of parts,
        # including between the last paragraph and the first table.
        parts = [document.name or ""]
        parts.extend(paragraph.text for paragraph in document.paragraphs)
        parts.extend(table.to_string() for table in document.tables)
        return "\n".join(parts)

    def _format_prompt(self, text: str) -> str:
        """Substitute the document text into the ``{replace_me}`` prompt slot."""
        return self.prompt.format(replace_me=text)

    def _postprocess_llm_response(self, response: str | None) -> str | None:
        """Extract the appendix text from the LLM reply.

        The text between the first ``[`` and the first ``]`` that follows it
        is returned unchanged.

        Args:
            response: Raw LLM reply, or ``None`` if there was no reply.

        Returns:
            The extracted text, or ``None`` in any of these cases:
            the reply is ``None``; it has no ``[`` ... ``]`` pair; the
            bracketed content is the ``%%`` marker (the appendix was judged
            useless); or the bracketed content is blank.
        """
        if response is None:
            return None
        start = response.find('[')
        # Search for the closing bracket only after the opening one, so a
        # stray ']' earlier in the reply does not invalidate it.
        end = response.find(']', start + 1) if start != -1 else -1
        if start == -1 or end == -1:
            logger.warning("Некорректный формат ответа LLM: %s", response)
            return None
        extracted_text = response[start + 1 : end]
        stripped = extracted_text.strip()
        # '%%' is the marker meaning "appendix is not useful". Compare after
        # stripping so that padding such as "[ %% ]" is also recognised.
        if stripped == '%%':
            logger.info('Приложение признано бесполезным')
            return None
        # An empty "[]" reply would otherwise create an empty Appendix entity.
        if not stripped:
            return None
        return extracted_text