import json
import os
import shutil
import tempfile

import streamlit as st
from huggingface_hub import HfApi, hf_hub_download
from langchain.docstore.document import Document
from langchain_community.vectorstores.faiss import FAISS

class HFVectorDB:
    """FAISS vector store persisted to a Hugging Face dataset repo.

    The index files (index.faiss / index.pkl) are mirrored between a local
    directory and the HF repo so the index survives app restarts.
    """

    def __init__(self, hf_repo_id, hf_token, local_index_dir="/tmp/vector_index", embedding_model=None):
        self.hf_repo_id = hf_repo_id
        self.hf_token = hf_token
        self.local_index_dir = local_index_dir
        self.embedding_model = embedding_model
        os.makedirs(self.local_index_dir, exist_ok=True)
        self.index = None

        self.api = HfApi(token=self.hf_token)

        # Pull any previously saved index from the HF repo, then load it.
        self._download_index_files()
        self._load_index()

    def _download_index_files(self):
        try:
            # HfApi.hf_hub_download resolves files into the HF cache; copy
            # them into the working directory so FAISS can load them by name.
            faiss_path = self.api.hf_hub_download(repo_id=self.hf_repo_id, filename="index.faiss", repo_type="dataset")
            pkl_path = self.api.hf_hub_download(repo_id=self.hf_repo_id, filename="index.pkl", repo_type="dataset")

            shutil.copy(faiss_path, os.path.join(self.local_index_dir, "index.faiss"))
            shutil.copy(pkl_path, os.path.join(self.local_index_dir, "index.pkl"))
            st.sidebar.write(f"✅ Downloaded FAISS index files from HF repo to {self.local_index_dir}")
        except Exception as e:
            st.sidebar.write(f"⚠️ Could not download FAISS index files: {e}")

    def _load_index(self):
        try:
            if self.embedding_model is None:
                st.sidebar.error("Error: embedding_model must be provided to load a FAISS index.")
                self.index = None
                return

            # index.pkl is a pickle file, so loading requires an explicit
            # opt-in; only do this for repos you control.
            self.index = FAISS.load_local(
                self.local_index_dir,
                self.embedding_model,
                allow_dangerous_deserialization=True,
            )
            st.sidebar.write("✅ Loaded FAISS index from local files")
        except Exception as e:
            st.sidebar.write(f"ℹ️ No local FAISS index found, starting with an empty index: {e}")
            self.index = None

    def save_index(self):
        if self.index is not None:
            self.index.save_local(self.local_index_dir)
            st.sidebar.write("✅ Saved FAISS index locally")
            self._upload_index_files()
        else:
            st.sidebar.write("⚠️ No FAISS index to save")

    def _upload_index_files(self):
        try:
            self.api.upload_file(
                path_or_fileobj=os.path.join(self.local_index_dir, "index.faiss"),
                path_in_repo="index.faiss",
                repo_id=self.hf_repo_id,
                repo_type="dataset",
                commit_message="Update FAISS index.faiss",
            )
            self.api.upload_file(
                path_or_fileobj=os.path.join(self.local_index_dir, "index.pkl"),
                path_in_repo="index.pkl",
                repo_id=self.hf_repo_id,
                repo_type="dataset",
                commit_message="Update FAISS index.pkl",
            )
            st.sidebar.write("✅ Uploaded FAISS index files to HF repo")
        except Exception as e:
            st.sidebar.error(f"❌ Error uploading FAISS index files: {e}")

    def add_documents(self, docs):
        if self.index is None:
            if self.embedding_model is None:
                st.sidebar.error("Error: embedding_model must be provided to create a new FAISS index.")
                return

            # First batch of documents: build a fresh index from them.
            self.index = FAISS.from_documents(docs, self.embedding_model)
            st.sidebar.write("✅ Created new FAISS index")
        else:
            self.index.add_documents(docs)
            st.sidebar.write("✅ Added documents to FAISS index")
        self.save_index()

    def similarity_search(self, query, k=5):
        if self.index is None:
            st.sidebar.write("⚠️ No index found, returning empty results")
            return []
        return self.index.similarity_search(query, k=k)

    def upload_top_picks_jsonl(self, path="top_picks.jsonl"):
        """Upload the top_picks.jsonl file to the HF dataset repo."""
        try:
            self.api.upload_file(
                path_or_fileobj=path,
                path_in_repo="top_picks.jsonl",
                repo_id=self.hf_repo_id,
                repo_type="dataset",
                commit_message="Update top_picks.jsonl",
            )
            st.sidebar.write("✅ Uploaded top_picks.jsonl to HF repo")
        except Exception as e:
            st.sidebar.error(f"❌ Error uploading top_picks.jsonl: {e}")

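# Example usage of HFVectorDB (illustrative sketch: the repo id, token source,
# and embedding model below are placeholders, not part of this module):
#
#   from langchain_community.embeddings import HuggingFaceEmbeddings
#
#   embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
#   db = HFVectorDB(
#       hf_repo_id="your-username/your-dataset-repo",
#       hf_token=os.environ["HF_TOKEN"],
#       embedding_model=embeddings,
#   )
#   results = db.similarity_search("semiconductor supply chain news", k=3)
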
def save_top_picks_json(vector_db, top_picks, date, path="top_picks.jsonl"):
    """Append the day's top picks to top_picks.jsonl and upload it.

    Any existing record for the same date is replaced, so reruns on the
    same day do not create duplicates.
    """
    record = {
        "date": date.isoformat(),
        "top_picks": top_picks
    }

    # Fetch the current file from the HF repo; a fresh temp cache dir
    # avoids reading a stale cached copy.
    try:
        tmp_file = hf_hub_download(
            repo_id=vector_db.hf_repo_id,
            filename="top_picks.jsonl",
            repo_type="dataset",
            token=vector_db.hf_token,
            cache_dir=tempfile.mkdtemp()
        )
        with open(tmp_file, "r") as f:
            existing_lines = f.readlines()
    except Exception as e:
        st.sidebar.warning(f"⚠️ No existing top_picks.jsonl found or failed to download: {e}")
        existing_lines = []

    # Keep all records except any for today's date, then append the new one.
    records = []
    for line in existing_lines:
        try:
            rec = json.loads(line)
            if rec["date"] != date.isoformat():
                records.append(rec)
        except json.JSONDecodeError:
            continue

    records.append(record)

    with open(path, "w") as f:
        for rec in records:
            f.write(json.dumps(rec) + "\n")

    vector_db.upload_top_picks_jsonl(path)
    st.sidebar.write(f"✅ Saved top picks to {path}")

def add_top_picks_to_vector_db(vector_db, top_picks, date):
    """Index the day's picks as Documents, skipping (ticker, date) pairs already stored."""
    existing_docs = []
    if vector_db.index is not None:
        try:
            # docstore._dict is a private LangChain attribute, so guard the access.
            existing_docs = vector_db.index.docstore._dict.values()
        except Exception:
            pass

    existing_keys = set()
    for doc in existing_docs:
        key = (doc.metadata.get("ticker"), doc.metadata.get("date"))
        existing_keys.add(key)

    new_docs = []
    for pick in top_picks:
        key = (pick["ticker"], date.isoformat())
        if key in existing_keys:
            continue

        content = (
            f"{pick['company']} ({pick['ticker']}):\n"
            f"Sentiment: {pick['sentiment']}\n"
            f"Critical News: {pick['critical_news']}\n"
            f"Impact: {pick['impact_summary']}\n"
            f"Action: {pick['action']}\n"
            f"Reason: {pick['reason']}"
        )
        metadata = {
            "ticker": pick["ticker"],
            "company": pick["company"],
            "sentiment": pick["sentiment"],
            "action": pick["action"],
            "date": date.isoformat()
        }
        new_docs.append(Document(page_content=content, metadata=metadata))

    if new_docs:
        vector_db.add_documents(new_docs)
        st.sidebar.write(f"✅ Added {len(new_docs)} new documents to vector DB")
    else:
        st.sidebar.write("ℹ️ No new documents to add to vector DB")
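
# Sketch of the daily persistence flow (illustrative only: the pick fields
# mirror the accesses above, and `db` is an HFVectorDB built as in the
# example after the class definition):
#
#   from datetime import date
#
#   top_picks = [{
#       "ticker": "ACME", "company": "Acme Corp", "sentiment": "positive",
#       "critical_news": "...", "impact_summary": "...",
#       "action": "buy", "reason": "...",
#   }]
#   today = date.today()
#   save_top_picks_json(db, top_picks, today)          # rewrite + upload the JSONL log
#   add_top_picks_to_vector_db(db, top_picks, today)   # embed and index new picks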