# StockMarketInsights / stock_vector_db.py
# rajat5ranjan's picture
# Update stock_vector_db.py
# ebdb41f verified
# raw
# history blame
# 4.23 kB
import os
import shutil
from huggingface_hub import hf_hub_download, Repository
from langchain_community.vectorstores.faiss import FAISS
import streamlit as st
class HFVectorDB:
    """FAISS vector index whose files are mirrored in a Hugging Face Hub repo.

    On construction the two index files (``index.faiss`` and ``index.pkl``)
    are downloaded from ``hf_repo_id`` into ``local_index_dir`` and loaded.
    ``add_documents`` saves the index locally and pushes both files back to
    the repo. Progress and problems are reported through ``st.write``.
    """

    # File names written by ``save_local`` and expected by ``load_local``.
    _INDEX_FILES = ("index.faiss", "index.pkl")

    def __init__(self, hf_repo_id, hf_token, local_index_dir="/tmp/vector_index", embedding_model=None):
        """Create the local directory, then fetch and load any existing index.

        Args:
            hf_repo_id: Hub repository id holding the index files.
            hf_token: Access token passed to the Hub download/clone calls.
            local_index_dir: Directory where the index files are kept locally.
            embedding_model: Embedding object handed to FAISS when loading or
                creating the index.
        """
        self.hf_repo_id = hf_repo_id
        self.hf_token = hf_token
        self.local_index_dir = local_index_dir
        self.embedding_model = embedding_model
        os.makedirs(self.local_index_dir, exist_ok=True)
        self.index = None
        # Download index files from HF repo (if they exist), then load them.
        self._download_index_files()
        self._load_index()

    def _download_index_files(self):
        """Best-effort copy of the index files from the Hub into the local dir.

        Any failure (missing repo, missing file, network error, ...) is
        reported via ``st.write`` and otherwise ignored.
        """
        try:
            # Fetch both files before copying either one, so a failed second
            # download cannot leave a mismatched faiss/pkl pair on disk.
            cached_paths = [
                hf_hub_download(repo_id=self.hf_repo_id, filename=name, use_auth_token=self.hf_token)
                for name in self._INDEX_FILES
            ]
            for cached, name in zip(cached_paths, self._INDEX_FILES):
                shutil.copy(cached, os.path.join(self.local_index_dir, name))
            st.write("✅ Downloaded FAISS index files from HF repo")
        except Exception as e:
            st.write(f"⚠️ Could not download FAISS index files: {e}")

    def _load_index(self):
        """Load the FAISS index from ``local_index_dir`` into ``self.index``.

        ``self.index`` is set to ``None`` when the index files are absent or
        cannot be loaded. The two cases are reported separately so that a
        failed load is not mistaken for "no index yet".
        """
        files_present = all(
            os.path.isfile(os.path.join(self.local_index_dir, name))
            for name in self._INDEX_FILES
        )
        if not files_present:
            st.write("ℹ️ No local FAISS index found, starting with empty index")
            self.index = None
            return
        try:
            # NOTE(security): per the LangChain docs, load_local unpickles
            # index.pkl, so the Hub repo must be trusted. Recent
            # langchain_community releases refuse to load unless
            # allow_dangerous_deserialization=True is passed; that opt-in is
            # deliberately not added here -- confirm the repo is trusted first.
            self.index = FAISS.load_local(self.local_index_dir, self.embedding_model)
        except Exception as e:
            # Surface the reason: after a failed load, the next
            # add_documents() builds a fresh index and pushes it over the
            # remote files.
            st.write(f"⚠️ Could not load FAISS index, starting with empty index: {e}")
            self.index = None
        else:
            st.write("✅ Loaded FAISS index from local")

    def save_index(self):
        """Write the index to ``local_index_dir`` and push it to the Hub repo.

        Does nothing except report a warning when there is no index.
        """
        if self.index is not None:
            self.index.save_local(self.local_index_dir)
            st.write("✅ Saved FAISS index locally")
            self._upload_index_files()
        else:
            st.write("⚠️ No FAISS index to save")

    def _upload_index_files(self):
        """Clone the Hub repo, copy the index files in, commit and push.

        Errors from the clone/commit/push calls propagate to the caller.
        """
        repo_local_path = "/tmp/hf_dataset_repo_clone"
        # Always start from a fresh clone; a leftover checkout may be stale.
        if os.path.exists(repo_local_path):
            shutil.rmtree(repo_local_path)
        repo = Repository(local_dir=repo_local_path, clone_from=self.hf_repo_id, use_auth_token=self.hf_token)
        for name in self._INDEX_FILES:
            shutil.copy(os.path.join(self.local_index_dir, name), os.path.join(repo_local_path, name))
        repo.git_add(auto_lfs_track=True)
        repo.git_commit("Update FAISS index files")
        repo.git_push()
        st.write("✅ Uploaded FAISS index files to HF repo")

    def add_documents(self, docs):
        """Add ``docs`` to the index (creating it if needed) and persist it.

        Args:
            docs: Documents to embed and store. An empty collection is a
                no-op: nothing is embedded, saved or uploaded.
        """
        if not docs:
            # FAISS cannot build an index from zero documents, and there is
            # nothing new to push.
            st.write("⚠️ No documents to add, FAISS index unchanged")
            return
        if self.index is None:
            self.index = FAISS.from_documents(docs, self.embedding_model)
            st.write("✅ Created new FAISS index")
        else:
            self.index.add_documents(docs)
            st.write("✅ Added documents to FAISS index")
        self.save_index()

    def similarity_search(self, query, k=5):
        """Return the ``k`` most similar documents for ``query``.

        Returns an empty list (and reports a warning) when no index exists.
        """
        if self.index is None:
            st.write("⚠️ No index found, returning empty results")
            return []
        return self.index.similarity_search(query, k=k)
def save_top_picks_json(top_picks, date, path="top_picks.jsonl"):
    """Append one record of top picks to a JSON Lines file.

    Each call appends a single line of the form
    ``{"date": <date.isoformat()>, "top_picks": <top_picks>}``.

    Args:
        top_picks: JSON-serialisable value stored under ``"top_picks"``.
        date: Object exposing ``isoformat()`` (e.g. ``datetime.date``).
        path: File to append to; it is created if it does not exist.

    Raises:
        TypeError: If ``top_picks`` is not JSON-serialisable. Nothing is
            written in that case.
    """
    # The module's top-level imports do not include json, so import it here.
    import json

    record = {
        "date": date.isoformat(),
        "top_picks": top_picks,
    }
    # Serialise before opening the file so a failure cannot leave a partial
    # line behind.
    line = json.dumps(record) + "\n"
    with open(path, "a", encoding="utf-8") as f:
        f.write(line)
    st.write(f"✅ Saved top picks to {path}")
def add_top_picks_to_vector_db(vector_db, top_picks, date):
    """Turn each top pick into a document and add them to ``vector_db``.

    Args:
        vector_db: Object exposing ``add_documents(docs)``.
        top_picks: Iterable of mappings with the keys ``company``,
            ``ticker``, ``sentiment``, ``critical_news``, ``impact_summary``,
            ``action`` and ``reason``.
        date: Object exposing ``isoformat()``; stored in each document's
            metadata.

    Raises:
        KeyError: If a pick lacks one of the keys listed above.
    """
    # The module's top-level imports do not include Document; take it from
    # the langchain_community package the file already depends on.
    from langchain_community.docstore.document import Document

    iso_date = date.isoformat()  # identical for every pick, so compute once
    docs = []
    for pick in top_picks:
        content = (
            f"{pick['company']} ({pick['ticker']}):\n"
            f"Sentiment: {pick['sentiment']}\n"
            f"Critical News: {pick['critical_news']}\n"
            f"Impact: {pick['impact_summary']}\n"
            f"Action: {pick['action']}\n"
            f"Reason: {pick['reason']}"
        )
        metadata = {
            "ticker": pick["ticker"],
            "company": pick["company"],
            "sentiment": pick["sentiment"],
            "action": pick["action"],
            "date": iso_date,
        }
        docs.append(Document(page_content=content, metadata=metadata))
    if not docs:
        # Nothing to embed; forwarding an empty batch cannot create an index.
        st.write("⚠️ No top picks to add to vector DB")
        return
    vector_db.add_documents(docs)
    st.write("✅ Added top picks to vector DB")