import hashlib
from datetime import datetime
from multiprocessing.pool import ThreadPool
from typing import List

import pandas as pd
import requests
from clickhouse_connect import get_client
from langchain.schema.embeddings import Embeddings
from langchain.vectorstores.myscale import MyScaleSettings, MyScaleWithoutJSON
from streamlit.runtime.uploaded_file_manager import UploadedFile

parser_url = "https://api.unstructured.io/general/v0/general"


def parse_files(api_key, user_id, files: List[UploadedFile], collection="default"):
    """Parse uploaded files into text rows via the Unstructured API."""

    def parse_file(file: UploadedFile):
        headers = {
            "accept": "application/json",
            "unstructured-api-key": api_key,
        }
        data = {"strategy": "auto", "ocr_languages": ["eng"]}
        # getvalue() returns the whole buffer without moving the read cursor,
        # so the same bytes are hashed and uploaded.
        file_hash = hashlib.sha256(file.getvalue()).hexdigest()
        file_data = {"files": (file.name, file.getvalue(), file.type)}
        response = requests.post(
            parser_url, headers=headers, data=data, files=file_data
        )
        json_response = response.json()
        if response.status_code != 200:
            raise ValueError(str(json_response))
        # Keep only narrative paragraphs longer than ten words. entity_id is a
        # hash of the file contents plus the paragraph text, so re-uploading
        # the same file produces the same IDs and ReplacingMergeTree
        # deduplicates the rows.
        texts = [
            {
                "text": t["text"],
                "file_name": t["metadata"]["filename"],
                "entity_id": hashlib.sha256(
                    (file_hash + t["text"]).encode()
                ).hexdigest(),
                "user_id": user_id,
                "collection_id": collection,
                "created_by": datetime.now(),
            }
            for t in json_response
            if t["type"] == "NarrativeText" and len(t["text"].split(" ")) > 10
        ]
        return texts

    # Parse files concurrently; p.map (rather than the built-in map) is what
    # actually dispatches the work to the thread pool.
    with ThreadPool(8) as p:
        rows = []
        for r in p.map(parse_file, files):
            rows.extend(r)
        return rows


def extract_embedding(embeddings: Embeddings, texts):
    """Attach an embedding vector to every extracted text row."""
    if len(texts) > 0:
        embs = embeddings.embed_documents([t["text"] for t in texts])
        for i, _ in enumerate(texts):
            texts[i]["vector"] = embs[i]
        return texts
    raise ValueError("No texts extracted!")


class PrivateKnowledgeBase:
    def __init__(
        self,
        host,
        port,
        username,
        password,
        embedding: Embeddings,
        parser_api_key,
        db="chat",
        kb_table="private_kb",
    ) -> None:
        super().__init__()
        # ReplacingMergeTree keyed on entity_id deduplicates re-inserted rows;
        # the CHECK constraint pins the embedding dimension to 768.
        schema_ = f"""
        CREATE TABLE IF NOT EXISTS {db}.{kb_table}(
            entity_id String,
            file_name String,
            text String,
            user_id String,
            collection_id String,
            created_by DateTime,
            vector Array(Float32),
            CONSTRAINT cons_vec_len CHECK length(vector) = 768,
            VECTOR INDEX vidx vector TYPE MSTG('metric_type=Cosine')
        ) ENGINE = ReplacingMergeTree ORDER BY entity_id
        """
        config = MyScaleSettings(
            host=host,
            port=port,
            username=username,
            password=password,
            database=db,
            table=kb_table,
        )
        client = get_client(
            host=config.host,
            port=config.port,
            username=config.username,
            password=config.password,
        )
        client.command("SET allow_experimental_object_type=1")
        client.command(schema_)
        self.parser_api_key = parser_api_key
        self.vstore = MyScaleWithoutJSON(
            embedding=embedding,
            config=config,
            # Column name must match the schema above ("created_by").
            must_have_cols=["file_name", "text", "created_by"],
        )
        self.retriever = self.vstore.as_retriever()

    def list_files(self, user_id):
        query = f"""
        SELECT DISTINCT file_name
        FROM {self.vstore.config.database}.{self.vstore.config.table}
        WHERE user_id = '{user_id}'
        """
        return [r for r in self.vstore.client.query(query).named_results()]

    def add_by_file(
        self, user_id, files: List[UploadedFile], collection="default", **kwargs
    ):
        # Parse -> embed -> bulk-insert the rows as a DataFrame.
        data = parse_files(self.parser_api_key, user_id, files, collection=collection)
        data = extract_embedding(self.vstore.embeddings, data)
        self.vstore.client.insert_df(
            self.vstore.config.table,
            pd.DataFrame(data),
            database=self.vstore.config.database,
        )

    def clear(self, user_id):
        # Lightweight DELETE removes all rows belonging to this user.
        self.vstore.client.command(
            f"DELETE FROM {self.vstore.config.database}.{self.vstore.config.table} "
            f"WHERE user_id='{user_id}'"
        )

    def _get_relevant_documents(self, query, *args, **kwargs):
        return self.retriever._get_relevant_documents(query, *args, **kwargs)

    async def _aget_relevant_documents(self, *args, **kwargs):
        # The retriever method is a coroutine, so it must be awaited.
        return await self.retriever._aget_relevant_documents(*args, **kwargs)
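

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original source): the host,
# credentials, user ID, and Unstructured API key below are placeholders for
# your own values. HuggingFaceEmbeddings is one possible Embeddings
# implementation; its default model (sentence-transformers/all-mpnet-base-v2)
# emits the 768-dimensional vectors the schema's CHECK constraint expects.
# Run with `streamlit run <this_file>.py`.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import streamlit as st
    from langchain.embeddings import HuggingFaceEmbeddings

    kb = PrivateKnowledgeBase(
        host="your-cluster.myscale.com",    # placeholder MyScale host
        port=443,
        username="your-username",           # placeholder credentials
        password="your-password",
        embedding=HuggingFaceEmbeddings(),  # defaults to a 768-dim model
        parser_api_key="your-unstructured-api-key",  # placeholder key
    )

    # Upload files through Streamlit, ingest them, then list what's stored.
    uploaded = st.file_uploader("Upload documents", accept_multiple_files=True)
    if uploaded:
        kb.add_by_file(user_id="demo-user", files=uploaded)
        st.write(kb.list_files("demo-user"))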