# StockMarketInsights / stock_vector_db.py
# Uploaded by rajat5ranjan — "updated data" (commit 831f2ed)
import os
import shutil
import json
from datetime import date # Assuming you'll use this for the date object
from huggingface_hub import hf_hub_download, HfApi # Import HfApi
from langchain_community.vectorstores.faiss import FAISS
from langchain.docstore.document import Document # Import Document
import streamlit as st
import tempfile
class HFVectorDB:
    """FAISS vector store persisted in a Hugging Face dataset repo.

    The index is stored as ``index.faiss`` / ``index.pkl`` in the dataset
    repo and mirrored into ``local_index_dir``. All status and error
    reporting goes through the Streamlit sidebar.
    """

    def __init__(self, hf_repo_id, hf_token, local_index_dir="/tmp/vector_index", embedding_model=None):
        """Set up local storage, connect to the hub, and try to load an index.

        Args:
            hf_repo_id: Target Hugging Face dataset repo id.
            hf_token: Auth token for the repo.
            local_index_dir: Local mirror directory for the index files.
            embedding_model: LangChain embeddings object; required to load
                or create a FAISS index.
        """
        self.hf_repo_id = hf_repo_id
        self.hf_token = hf_token
        self.local_index_dir = local_index_dir
        self.embedding_model = embedding_model
        os.makedirs(self.local_index_dir, exist_ok=True)
        self.index = None
        # Single authenticated client reused for all hub operations.
        self.api = HfApi(token=self.hf_token)
        # Pull any existing index files from the repo, then try to load them.
        self._download_index_files()
        self._load_index()

    def _download_index_files(self):
        """Best-effort download of index.faiss/index.pkl into local_index_dir.

        A failure (e.g. fresh repo with no index yet) is reported but not
        raised, so the instance starts with an empty index instead.
        """
        try:
            faiss_path = self.api.hf_hub_download(repo_id=self.hf_repo_id, filename="index.faiss", repo_type="dataset")
            pkl_path = self.api.hf_hub_download(repo_id=self.hf_repo_id, filename="index.pkl", repo_type="dataset")
            shutil.copy(faiss_path, os.path.join(self.local_index_dir, "index.faiss"))
            shutil.copy(pkl_path, os.path.join(self.local_index_dir, "index.pkl"))
            st.sidebar.write(f"βœ… Downloaded FAISS index files from HF repo using HfApi to {self.local_index_dir}")
        except Exception as e:
            st.sidebar.write(f"⚠️ Could not download FAISS index files: {e}")

    def _load_index(self):
        """Load the FAISS index from local_index_dir, or leave self.index as None."""
        try:
            # An embeddings object is mandatory for FAISS deserialization.
            if self.embedding_model is None:
                st.sidebar.error("Error: embedding_model must be provided to load a FAISS index.")
                self.index = None
                return
            self.index = FAISS.load_local(self.local_index_dir, self.embedding_model, allow_dangerous_deserialization=True)
            st.sidebar.write("βœ… Loaded FAISS index from local")
        except Exception as e:
            # Expected on first run when no index files exist yet.
            st.sidebar.write(f"ℹ️ No local FAISS index found, starting with empty index {e}")
            self.index = None

    def save_index(self):
        """Persist the in-memory index locally and push it to the HF repo."""
        if self.index is not None:
            self.index.save_local(self.local_index_dir)
            st.sidebar.write("βœ… Saved FAISS index locally")
            self._upload_index_files()
        else:
            st.sidebar.write("⚠️ No FAISS index to save")

    def _upload_index_files(self):
        """Upload index.faiss and index.pkl to the dataset repo."""
        try:
            self.api.upload_file(
                path_or_fileobj=os.path.join(self.local_index_dir, "index.faiss"),
                path_in_repo="index.faiss",
                repo_id=self.hf_repo_id,
                repo_type="dataset",
                commit_message="Update FAISS index.faiss"
            )
            self.api.upload_file(
                path_or_fileobj=os.path.join(self.local_index_dir, "index.pkl"),
                path_in_repo="index.pkl",
                repo_id=self.hf_repo_id,
                repo_type="dataset",
                commit_message="Update FAISS index.pkl"
            )
            st.sidebar.write("βœ… Uploaded FAISS index files to HF repo using HfApi")
        except Exception as e:
            # Fix: report in the sidebar like every other message in this class
            # (was st.error, which surfaced in the main panel).
            st.sidebar.error(f"❌ Error uploading FAISS index files: {e}")

    def add_documents(self, docs):
        """Add LangChain Documents to the index (creating it if needed) and persist.

        Args:
            docs: Iterable of langchain Document objects; an empty batch is a no-op
                (FAISS.from_documents would otherwise fail on it).
        """
        if not docs:
            # Guard: creating/updating an index from an empty batch is pointless
            # and FAISS.from_documents([]) raises.
            st.sidebar.write("⚠️ No documents provided, nothing to add")
            return
        if self.index is None:
            # An embeddings object is mandatory to build a new index.
            if self.embedding_model is None:
                st.sidebar.error("Error: embedding_model must be provided to create a new FAISS index.")
                return
            self.index = FAISS.from_documents(docs, self.embedding_model)
            st.sidebar.write("βœ… Created new FAISS index")
        else:
            self.index.add_documents(docs)
            st.sidebar.write("βœ… Added documents to FAISS index")
        # Persist every successful mutation straight to the hub.
        self.save_index()

    def similarity_search(self, query, k=5):
        """Return the k most similar stored documents for `query` ([] if no index)."""
        if self.index is None:
            st.sidebar.write("⚠️ No index found, returning empty results")
            return []
        return self.index.similarity_search(query, k=k)

    def upload_top_picks_jsonl(self, path="top_picks.jsonl"):
        """Uploads the top_picks.jsonl to the HF dataset repo"""
        try:
            self.api.upload_file(
                path_or_fileobj=path,
                path_in_repo="top_picks.jsonl",
                repo_id=self.hf_repo_id,
                repo_type="dataset",
                commit_message="Update top_picks.jsonl"
            )
            st.sidebar.write("βœ… Uploaded top_picks.jsonl to HF repo")
        except Exception as e:
            st.sidebar.error(f"❌ Error uploading top_picks.jsonl: {e}")
def save_top_picks_json(vector_db, top_picks, date, path="top_picks.jsonl"):
    """Merge today's picks into top_picks.jsonl and upload it to the HF repo.

    Downloads the existing JSONL from the hub (best effort), drops any prior
    record for the same date, appends the new record, rewrites `path`
    locally, and uploads it via `vector_db.upload_top_picks_jsonl`.

    Args:
        vector_db: HFVectorDB whose repo id / token are used for the transfer.
        top_picks: Serializable payload to store under "top_picks".
        date: datetime.date (anything with .isoformat()) keying the record.
        path: Local JSONL file to (re)write.
    """
    today = date.isoformat()  # hoisted: reused for the record and per-line filter
    record = {
        "date": today,
        "top_picks": top_picks,
    }
    # Step 1: fetch the current file from the hub (best effort — a missing
    # file just means we start from scratch).
    existing_lines = []
    try:
        # TemporaryDirectory cleans itself up; mkdtemp() leaked a dir per call.
        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_file = hf_hub_download(
                repo_id=vector_db.hf_repo_id,
                filename="top_picks.jsonl",
                repo_type="dataset",
                token=vector_db.hf_token,
                cache_dir=tmp_dir
            )
            with open(tmp_file, "r", encoding="utf-8") as f:
                existing_lines = f.readlines()
    except Exception as e:
        st.sidebar.warning(f"⚠️ No existing top_picks.jsonl found or failed to download: {e}")
    # Step 2: keep every record except the one for this date (it is replaced).
    records = []
    for line in existing_lines:
        try:
            rec = json.loads(line)
        except json.JSONDecodeError:
            continue  # skip corrupt lines rather than failing the whole save
        if not isinstance(rec, dict):
            continue  # valid JSON but not a record (e.g. a bare string/list)
        # .get() guards against well-formed records that lack a "date" key,
        # which previously raised an uncaught KeyError.
        if rec.get("date") != today:
            records.append(rec)
    # Step 3: append today's record.
    records.append(record)
    # Step 4: rewrite the local file and push it to the hub.
    with open(path, "w", encoding="utf-8") as f:
        for rec in records:
            f.write(json.dumps(rec) + "\n")
    vector_db.upload_top_picks_jsonl(path)
    st.sidebar.write(f"βœ… Saved top picks to {path}")
def add_top_picks_to_vector_db(vector_db, top_picks, date):
    """Index one Document per pick, skipping (ticker, date) pairs already stored.

    Args:
        vector_db: HFVectorDB to add documents to.
        top_picks: Iterable of dicts with keys 'ticker', 'company',
            'sentiment', 'critical_news', 'impact_summary', 'action', 'reason'.
        date: datetime.date (anything with .isoformat()) stamped on each doc.
    """
    day = date.isoformat()  # hoisted: used as dedup key and metadata value
    # Collect (ticker, date) keys already stored so reruns don't duplicate docs.
    existing_keys = set()
    if vector_db.index is not None:
        try:
            # NOTE(review): reaches into FAISS internals (docstore._dict) —
            # private API; confirm it survives langchain upgrades.
            for doc in vector_db.index.docstore._dict.values():
                existing_keys.add((doc.metadata.get("ticker"), doc.metadata.get("date")))
        except Exception as e:
            # Best effort: failure here only risks duplicate inserts, but log
            # the cause instead of silently swallowing it (was `pass`).
            st.sidebar.write(f"⚠️ Could not read existing documents for dedup: {e}")
    new_docs = []
    for pick in top_picks:
        if (pick["ticker"], day) in existing_keys:
            continue  # already indexed for this date
        content = (
            f"{pick['company']} ({pick['ticker']}):\n"
            f"Sentiment: {pick['sentiment']}\n"
            f"Critical News: {pick['critical_news']}\n"
            f"Impact: {pick['impact_summary']}\n"
            f"Action: {pick['action']}\n"
            f"Reason: {pick['reason']}"
        )
        metadata = {
            "ticker": pick["ticker"],
            "company": pick["company"],
            "sentiment": pick["sentiment"],
            "action": pick["action"],
            "date": day
        }
        new_docs.append(Document(page_content=content, metadata=metadata))
    if new_docs:
        vector_db.add_documents(new_docs)
        st.sidebar.write(f"βœ… Added {len(new_docs)} new documents to vector DB")
    else:
        st.sidebar.write("ℹ️ No new documents to add to vector DB")