import os
import shutil
import json

from huggingface_hub import HfApi
from langchain_community.vectorstores.faiss import FAISS
from langchain.docstore.document import Document
import streamlit as st


class HFVectorDB:
    """FAISS vector store persisted to a Hugging Face dataset repo."""

    def __init__(self, hf_repo_id, hf_token, local_index_dir="/tmp/vector_index", embedding_model=None):
        self.hf_repo_id = hf_repo_id
        self.hf_token = hf_token
        self.local_index_dir = local_index_dir
        self.embedding_model = embedding_model
        os.makedirs(self.local_index_dir, exist_ok=True)
        self.index = None

        self.api = HfApi(token=self.hf_token)

        # Fetch any existing index files from the Hub, then load them into memory.
        self._download_index_files()
        self._load_index()

    def _download_index_files(self):
        """Download index.faiss and index.pkl from the HF dataset repo into the local dir."""
        try:
            faiss_path = self.api.hf_hub_download(repo_id=self.hf_repo_id, filename="index.faiss", repo_type="dataset")
            pkl_path = self.api.hf_hub_download(repo_id=self.hf_repo_id, filename="index.pkl", repo_type="dataset")

            shutil.copy(faiss_path, os.path.join(self.local_index_dir, "index.faiss"))
            shutil.copy(pkl_path, os.path.join(self.local_index_dir, "index.pkl"))
            st.sidebar.write(f"✅ Downloaded FAISS index files from HF repo to {self.local_index_dir}")
        except Exception as e:
            st.sidebar.write(f"⚠️ Could not download FAISS index files: {e}")

    def _load_index(self):
        """Load the FAISS index from local disk, or start without one."""
        try:
            if self.embedding_model is None:
                st.sidebar.error("Error: embedding_model must be provided to load a FAISS index.")
                self.index = None
                return

            # allow_dangerous_deserialization is required because index.pkl is
            # unpickled; only load indexes from repos you trust.
            self.index = FAISS.load_local(self.local_index_dir, self.embedding_model, allow_dangerous_deserialization=True)
            st.sidebar.write("✅ Loaded FAISS index from local disk")
        except Exception as e:
            st.sidebar.write(f"ℹ️ No local FAISS index found, starting with an empty index ({e})")
            self.index = None

    def save_index(self):
        """Persist the index locally, then push the files to the HF repo."""
        if self.index is not None:
            self.index.save_local(self.local_index_dir)
            st.sidebar.write("✅ Saved FAISS index locally")
            self._upload_index_files()
        else:
            st.sidebar.write("⚠️ No FAISS index to save")

    def _upload_index_files(self):
        """Upload index.faiss and index.pkl to the HF dataset repo."""
        try:
            self.api.upload_file(
                path_or_fileobj=os.path.join(self.local_index_dir, "index.faiss"),
                path_in_repo="index.faiss",
                repo_id=self.hf_repo_id,
                repo_type="dataset",
                commit_message="Update FAISS index.faiss",
            )
            self.api.upload_file(
                path_or_fileobj=os.path.join(self.local_index_dir, "index.pkl"),
                path_in_repo="index.pkl",
                repo_id=self.hf_repo_id,
                repo_type="dataset",
                commit_message="Update FAISS index.pkl",
            )
            st.sidebar.write("✅ Uploaded FAISS index files to HF repo")
        except Exception as e:
            st.sidebar.error(f"❌ Error uploading FAISS index files: {e}")

    def add_documents(self, docs):
        """Add documents, creating the index on first use, then persist it."""
        if self.index is None:
            if self.embedding_model is None:
                st.sidebar.error("Error: embedding_model must be provided to create a new FAISS index.")
                return
            self.index = FAISS.from_documents(docs, self.embedding_model)
            st.sidebar.write("✅ Created new FAISS index")
        else:
            self.index.add_documents(docs)
            st.sidebar.write("✅ Added documents to FAISS index")
        self.save_index()

    def similarity_search(self, query, k=5):
        """Return the k most similar documents, or [] if no index exists yet."""
        if self.index is None:
            st.sidebar.write("⚠️ No index found, returning empty results")
            return []
        return self.index.similarity_search(query, k=k)
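
    # Note: the LangChain FAISS wrapper also accepts a metadata filter, e.g.
    #     self.index.similarity_search(query, k=k, filter={"ticker": "EXMP"})
    # which would allow per-ticker retrieval over the metadata stored below.
    # The filter kwarg exists in langchain_community's FAISS; the ticker value
    # here is illustrative, not one used by this module.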

    def upload_top_picks_jsonl(self, path="top_picks.jsonl"):
        """Upload top_picks.jsonl to the HF dataset repo."""
        try:
            self.api.upload_file(
                path_or_fileobj=path,
                path_in_repo="top_picks.jsonl",
                repo_id=self.hf_repo_id,
                repo_type="dataset",
                commit_message="Update top_picks.jsonl",
            )
            st.sidebar.write("✅ Uploaded top_picks.jsonl to HF repo")
        except Exception as e:
            st.sidebar.error(f"❌ Error uploading top_picks.jsonl: {e}")


def save_top_picks_json(vector_db, top_picks, date, path="top_picks.jsonl"):
    """Append the day's top picks to a local JSONL file and upload it to the HF repo."""
    record = {
        "date": date.isoformat(),
        "top_picks": top_picks,
    }
    # Write the record before uploading; append mode keeps one JSON line per
    # run (use "w" instead to keep only the latest snapshot).
    with open(path, "a") as f:
        f.write(json.dumps(record) + "\n")
    vector_db.upload_top_picks_jsonl(path)
    st.sidebar.write(f"✅ Saved top picks to {path}")


def add_top_picks_to_vector_db(vector_db, top_picks, date):
    """Convert each pick into a LangChain Document and add it to the vector DB."""
    docs = []
    for pick in top_picks:
        content = (
            f"{pick['company']} ({pick['ticker']}):\n"
            f"Sentiment: {pick['sentiment']}\n"
            f"Critical News: {pick['critical_news']}\n"
            f"Impact: {pick['impact_summary']}\n"
            f"Action: {pick['action']}\n"
            f"Reason: {pick['reason']}"
        )
        metadata = {
            "ticker": pick["ticker"],
            "company": pick["company"],
            "sentiment": pick["sentiment"],
            "action": pick["action"],
            "date": date.isoformat(),
        }
        docs.append(Document(page_content=content, metadata=metadata))
    vector_db.add_documents(docs)
    st.sidebar.write("✅ Added top picks to vector DB")
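

# --- Usage sketch ---
# A minimal, illustrative example of wiring the pieces together. The embedding
# model, repo id, token env var, and the sample pick below are assumptions for
# demonstration, not values defined in this module.
#
# from datetime import date
# from langchain_community.embeddings import HuggingFaceEmbeddings
#
# embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
# vector_db = HFVectorDB(
#     hf_repo_id="your-username/your-dataset-repo",  # hypothetical dataset repo
#     hf_token=os.environ["HF_TOKEN"],               # hypothetical env var
#     embedding_model=embeddings,
# )
# picks = [{
#     "company": "Example Corp", "ticker": "EXMP", "sentiment": "positive",
#     "critical_news": "...", "impact_summary": "...",
#     "action": "buy", "reason": "...",
# }]
# add_top_picks_to_vector_db(vector_db, picks, date.today())
# save_top_picks_json(vector_db, picks, date.today())
# results = vector_db.similarity_search("EXMP sentiment")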