Project / core /pipeline /chataipipeline.py
puzan789's picture
updated
ad87194
raw
history blame
3.29 kB
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.prompts import PromptTemplate
from core.prompts.custom_prompts import _custom_prompts
from core.services.answer_query.answerquery import AnswerQuery
from core.services.document.add_document import AddDocument
from core.services.embeddings.Qdrant_BM25_embedding import qdrant_bm25_embedding
from core.services.embeddings.jina_embeddings import jina_embedding
from core.services.get_links.web_scraper import WebScraper
from core.services.ocr.replicate_ocr.replicate_ocr import ReplicateOCR as OCRService
from core.services.pdf_extraction.image_pdf.image_pdf_text_extraction import get_text_from_image_pdf
from core.services.pdf_extraction.text_pdf.text_pdf_extraction import extract_text_from_pdf
from core.services.website_url.text_extraction_urlsnew import WebScrapertext
from core.utils.utils import json_parser
class ChatAIPipeline:
    """Single entry point that wires together the chat pipeline services.

    The constructor builds the two prompts, the dense and sparse embedding
    objects and one instance of each service wrapper. Every public method
    forwards its arguments to one of those collaborators (or to a
    module-level extraction function) and returns the result unchanged.
    """

    def __init__(self):
        """Create the prompts, embeddings and service objects."""
        rag_template = _custom_prompts["RAG_ANSWER_PROMPT"]
        follow_up_template = _custom_prompts["FOLLOW_UP_PROMPT"]

        answer_prompt = ChatPromptTemplate.from_template(rag_template)

        # The same parser instance supplies the format instructions for the
        # follow-up prompt and is handed to the answer service below.
        parser = json_parser()
        format_instructions = parser.get_format_instructions()
        questions_prompt = PromptTemplate(
            template=follow_up_template,
            input_variables=["context"],
            partial_variables={"format_instructions": format_instructions},
        )

        # Both embeddings are shared by the document and the query services.
        dense = self.vector_embedding = jina_embedding()
        sparse = self.sparse_embedding = qdrant_bm25_embedding()

        self.add_document_service = AddDocument(dense, sparse)
        self.answer_query_service = AnswerQuery(
            vector_embedding=dense,
            sparse_embedding=sparse,
            prompt=answer_prompt,
            follow_up_prompt=questions_prompt,
            json_parser=parser,
        )
        self.get_website_links = WebScraper()
        # NOTE(review): no method visible in this class uses ocr_service;
        # confirm whether external callers rely on the attribute.
        self.ocr_service = OCRService()
        self.web_text_extractor = WebScrapertext()

    def add_document_(self, texts: list[tuple[str]], vectorstore: str):
        """Pass ``texts`` and ``vectorstore`` to the document service.

        Returns whatever ``AddDocument.add_documents`` returns.
        """
        service = self.add_document_service
        return service.add_documents(texts=texts, vectorstore=vectorstore)

    def answer_query_(self, query: str, vectorstore: str, llm_model: str = "llama-3.3-70b-versatile"):
        """Run ``query`` through the answer service.

        ``llm_model`` is forwarded under the service's ``llmModel`` keyword.
        The service result is unpacked into exactly three values, which are
        returned as a tuple ``(output, follow_up_questions, source)``.
        """
        answer, follow_ups, origin = self.answer_query_service.answer_query(
            query=query,
            vectorstore=vectorstore,
            llmModel=llm_model,
        )
        return answer, follow_ups, origin

    def get_links_(self, url: str, timeout: int):
        """Return the result of ``WebScraper.get_links`` for ``url`` and ``timeout``."""
        scraper = self.get_website_links
        return scraper.get_links(url=url, timeout=timeout)

    def image_pdf_text_extraction_(self, image_pdf: bytes):
        """Hand ``image_pdf`` to ``get_text_from_image_pdf`` as ``pdf_bytes``."""
        extracted = get_text_from_image_pdf(pdf_bytes=image_pdf)
        return extracted

    def text_pdf_extraction_(self, pdf: str):
        """Hand ``pdf`` to ``extract_text_from_pdf`` as ``pdf_path``."""
        extracted = extract_text_from_pdf(pdf_path=pdf)
        return extracted

    def website_url_text_extraction_(self, url: str):
        """Return the web text extractor's result for a single ``url``."""
        extractor = self.web_text_extractor
        return extractor.extract_text_from_url(url=url)

    def website_url_text_extraction_list_(self, urls: list):
        """Return the web text extractor's result for a list of ``urls``."""
        extractor = self.web_text_extractor
        return extractor.extract_text_from_urls(urls=urls)