Spaces:
Runtime error
Runtime error
| import hashlib | |
| from datetime import datetime | |
| from multiprocessing.pool import ThreadPool | |
| from typing import List | |
| import requests | |
| from clickhouse_sqlalchemy import types, engines | |
| from langchain.schema.embeddings import Embeddings | |
| from sqlalchemy import Column, Text | |
| from streamlit.runtime.uploaded_file_manager import UploadedFile | |
def parse_files(api_key, user_id, files: List[UploadedFile]):
    """Send uploaded files to the Unstructured API and collect text rows.

    Every file is posted to the hosted Unstructured endpoint using a pool of
    eight worker threads. From each response, only elements of type
    ``NarrativeText`` whose text has more than ten space-separated tokens
    are kept.

    Args:
        api_key: Value sent in the ``unstructured-api-key`` request header.
        user_id: Stored verbatim in the ``user_id`` field of every row.
        files: Uploaded files to parse; each must expose ``name``, ``type``
            and ``getvalue()``.

    Returns:
        A list of dicts with the keys ``text``, ``file_name``, ``entity_id``,
        ``user_id`` and ``created_by``. Rows of different files appear in
        completion order, not in input order.

    Raises:
        ValueError: If the API answers with a status code other than 200.
            The message is the decoded JSON body when there is one, else the
            raw response text.
    """

    def parse_file(file: UploadedFile):
        """Parse a single uploaded file into a list of row dicts."""
        headers = {
            "accept": "application/json",
            "unstructured-api-key": api_key,
        }
        data = {"strategy": "auto", "ocr_languages": ["eng"]}
        # Take the bytes once via getvalue() and use them for both the hash
        # and the upload. read() only yields what is left after the current
        # stream position, so an already-consumed file object would hash as
        # empty and produce colliding entity ids.
        content = file.getvalue()
        file_hash = hashlib.sha256(content).hexdigest()
        file_data = {"files": (file.name, content, file.type)}
        response = requests.post(
            url="https://api.unstructured.io/general/v0/general",
            headers=headers,
            data=data,
            files=file_data,
        )
        # Check the status before trusting the body: an error response is
        # not guaranteed to be JSON, and the caller should still see the
        # server's message in that case.
        if response.status_code != 200:
            try:
                detail = str(response.json())
            except ValueError:
                detail = response.text
            raise ValueError(detail)
        json_response = response.json()
        texts = [
            {
                "text": t["text"],
                "file_name": t["metadata"]["filename"],
                # Id derived from the file content and the passage text.
                "entity_id": hashlib.sha256(
                    (file_hash + t["text"]).encode()
                ).hexdigest(),
                "user_id": user_id,
                "created_by": datetime.now(),
            }
            for t in json_response
            if t["type"] == "NarrativeText" and len(t["text"].split(" ")) > 10
        ]
        return texts

    rows = []
    with ThreadPool(8) as p:
        for r in p.imap_unordered(parse_file, files):
            rows.extend(r)
    return rows
def extract_embedding(embeddings: Embeddings, texts):
    """Attach an embedding vector to every text row, in place.

    Args:
        embeddings: Object whose ``embed_documents`` method receives the list
            of ``"text"`` values and returns one vector per text, in order.
        texts: Sequence of dicts that each contain a ``"text"`` key. Every
            dict gains a ``"vector"`` key.

    Returns:
        The same ``texts`` object that was passed in, after mutation.

    Raises:
        ValueError: If ``texts`` is empty or ``None``.
        IndexError: If ``embed_documents`` returns fewer vectors than there
            are texts (rows before the gap have already been updated).
    """
    if not texts:
        raise ValueError("No texts extracted!")

    # Keep the embedder and its output under separate names instead of
    # rebinding the ``embeddings`` parameter to the result list.
    vectors = embeddings.embed_documents([row["text"] for row in texts])
    # Index explicitly rather than zip() so a short result still fails
    # loudly instead of silently leaving rows without a vector.
    for position, row in enumerate(texts):
        row["vector"] = vectors[position]
    return texts
def create_message_history_table(table_name: str, base_class):
    """Build the model class for the chat-history table.

    Args:
        table_name: Used as the ``__tablename__`` of the generated class.
        base_class: Class the model inherits from (presumably a SQLAlchemy
            declarative base; confirm against the caller).

    Returns:
        The newly defined ``Message`` class.
    """
    # MergeTree storage: partitioned per session, sorted by (id, msg_id).
    storage_engine = engines.MergeTree(
        partition_by="session_id",
        order_by=("id", "msg_id"),
    )
    table_options = {"comment": "Store Chat History"}

    class Message(base_class):
        __tablename__ = table_name

        id = Column(types.Float64)
        session_id = Column(Text)
        user_id = Column(Text)
        msg_id = Column(Text, primary_key=True)
        type = Column(Text)
        # The column name is a misspelling of "additionals" made by a
        # former developer. Presumably it is kept so the model keeps
        # matching existing tables; verify before renaming.
        addtionals = Column(Text)
        message = Column(Text)

        __table_args__ = (storage_engine, table_options)

    return Message
def create_session_table(table_name: str, DynamicBase):
    """Build the model class for the session / system-prompt table.

    Args:
        table_name: Used as the ``__tablename__`` of the generated class.
        DynamicBase: Class the model inherits from (presumably a SQLAlchemy
            declarative base; confirm against the caller).

    Returns:
        The newly defined ``Session`` class.
    """

    class Session(DynamicBase):
        __tablename__ = table_name

        user_id = Column(Text)
        session_id = Column(Text, primary_key=True)
        system_prompt = Column(Text)
        # Creation time of the session (despite the "create_by" name).
        create_by = Column(types.DateTime)
        # Spelled correctly here, unlike the misspelled "addtionals"
        # column of the chat-history table.
        additionals = Column(Text)

        # MergeTree storage sorted by the session_id column defined above.
        __table_args__ = (
            engines.MergeTree(order_by=session_id),
            {"comment": "Store Session and Prompts"},
        )

    return Session