StockMarketInsights / stock_vector_db.py
rajat5ranjan's picture
Update stock_vector_db.py
b0051f3 verified
raw
history blame
6.27 kB
import os
import shutil
import json
from datetime import date # Assuming you'll use this for the date object
from huggingface_hub import hf_hub_download, HfApi # Import HfApi
from langchain_community.vectorstores.faiss import FAISS
from langchain.docstore.document import Document # Import Document
import streamlit as st
class HFVectorDB:
    """FAISS vector index mirrored to a Hugging Face dataset repository.

    On construction the persisted index files are downloaded from the dataset
    repo ``hf_repo_id`` into ``local_index_dir`` and loaded into ``self.index``.
    ``add_documents`` saves the index locally and uploads the files back to the
    repo after each addition. Status and failures are reported through
    Streamlit widgets; the methods catch their own I/O errors instead of
    raising them to the caller.
    """

    # The two files that together make up a saved index; they are always
    # downloaded, copied and uploaded as a pair, in this order.
    _INDEX_FILES = ("index.faiss", "index.pkl")

    def __init__(self, hf_repo_id, hf_token, local_index_dir="/tmp/vector_index", embedding_model=None):
        """Create the local index directory, then fetch and load any existing index.

        Args:
            hf_repo_id: Dataset repository id used for downloads and uploads.
            hf_token: Token handed to ``HfApi``.
            local_index_dir: Directory holding the local copy of the index files.
            embedding_model: Embedding object passed to FAISS when loading or
                creating an index. Without it no index can be loaded or built.
        """
        self.hf_repo_id = hf_repo_id
        self.hf_token = hf_token
        self.local_index_dir = local_index_dir
        self.embedding_model = embedding_model
        os.makedirs(self.local_index_dir, exist_ok=True)
        self.index = None
        self.api = HfApi(token=self.hf_token)
        # Pull the persisted index (if the repo has one) before trying to load it.
        self._download_index_files()
        self._load_index()

    def _download_index_files(self):
        """Download the index files from the HF repo and copy them into ``local_index_dir``.

        Both downloads must succeed before anything is copied, so a failed
        download never leaves a half-updated pair in the local directory.
        Any error is reported in the sidebar and otherwise ignored.
        """
        try:
            downloaded = [
                self.api.hf_hub_download(repo_id=self.hf_repo_id, filename=name, repo_type="dataset")
                for name in self._INDEX_FILES
            ]
            for name, cached_path in zip(self._INDEX_FILES, downloaded):
                shutil.copy(cached_path, os.path.join(self.local_index_dir, name))
            st.sidebar.write(f"βœ… Downloaded FAISS index files from HF repo using HfApi to {self.local_index_dir}")
        except Exception as e:
            st.sidebar.write(f"⚠️ Could not download FAISS index files: {e}")

    def _load_index(self):
        """Load the FAISS index from ``local_index_dir`` into ``self.index``.

        ``self.index`` is left as ``None`` when no embedding model was supplied
        or when loading fails for any reason (including missing files).
        """
        try:
            if self.embedding_model is None:
                st.sidebar.error("Error: embedding_model must be provided to load a FAISS index.")
                self.index = None
                return
            # SECURITY: allow_dangerous_deserialization=True lets load_local
            # unpickle index.pkl, a file that was downloaded from the HF repo.
            # Unpickling can execute arbitrary code, so hf_repo_id must only
            # ever point at a repository whose contents are trusted.
            self.index = FAISS.load_local(
                self.local_index_dir,
                self.embedding_model,
                allow_dangerous_deserialization=True,
            )
            st.sidebar.write("βœ… Loaded FAISS index from local")
        except Exception as e:
            st.sidebar.write(f"ℹ️ No local FAISS index found, starting with empty index {e}")
            self.index = None

    def save_index(self):
        """Write the index to ``local_index_dir`` and upload it; no-op warning if there is no index."""
        if self.index is None:
            st.sidebar.write("⚠️ No FAISS index to save")
            return
        self.index.save_local(self.local_index_dir)
        st.sidebar.write("βœ… Saved FAISS index locally")
        self._upload_index_files()

    def _upload_index_files(self):
        """Upload the local index files to the HF dataset repo, one commit per file.

        An upload error is shown with ``st.error`` and not re-raised.
        """
        try:
            for name in self._INDEX_FILES:
                self.api.upload_file(
                    path_or_fileobj=os.path.join(self.local_index_dir, name),
                    path_in_repo=name,
                    repo_id=self.hf_repo_id,
                    repo_type="dataset",
                    commit_message=f"Update FAISS {name}",
                )
            st.sidebar.write("βœ… Uploaded FAISS index files to HF repo using HfApi")
        except Exception as e:
            st.error(f"❌ Error uploading FAISS index files: {e}")

    def add_documents(self, docs):
        """Add ``docs`` to the index (creating it if needed), then save and upload.

        Args:
            docs: Documents to embed and store. An empty or ``None`` value is
                skipped with a sidebar warning and leaves the index untouched.
        """
        # FAISS cannot build or extend an index from an empty batch; skipping
        # here also avoids a pointless save and upload of an unchanged index.
        if not docs:
            st.sidebar.write("⚠️ No documents to add, FAISS index unchanged")
            return
        if self.index is None:
            if self.embedding_model is None:
                st.sidebar.error("Error: embedding_model must be provided to create a new FAISS index.")
                return
            self.index = FAISS.from_documents(docs, self.embedding_model)
            st.sidebar.write("βœ… Created new FAISS index")
        else:
            self.index.add_documents(docs)
            st.sidebar.write("βœ… Added documents to FAISS index")
        self.save_index()

    def similarity_search(self, query, k=5):
        """Return up to ``k`` documents most similar to ``query``; ``[]`` when no index is loaded."""
        if self.index is None:
            st.sidebar.write("⚠️ No index found, returning empty results")
            return []
        return self.index.similarity_search(query, k=k)

    def upload_top_picks_jsonl(self, path="top_picks.jsonl"):
        """Upload the local file at ``path`` to the HF dataset repo as ``top_picks.jsonl``.

        An upload error is shown in the sidebar and not re-raised.
        """
        try:
            self.api.upload_file(
                path_or_fileobj=path,
                path_in_repo="top_picks.jsonl",
                repo_id=self.hf_repo_id,
                repo_type="dataset",
                commit_message="Update top_picks.jsonl",
            )
            st.sidebar.write("βœ… Uploaded top_picks.jsonl to HF repo")
        except Exception as e:
            st.sidebar.error(f"❌ Error uploading top_picks.jsonl: {e}")
def save_top_picks_json(vector_db, top_picks, date, path="top_picks.jsonl"):
    """Append the day's top picks to a JSONL file and upload that file.

    Args:
        vector_db: Object exposing ``upload_top_picks_jsonl(path)``.
        top_picks: JSON-serializable collection of picks for ``date``.
        date: Object with an ``isoformat()`` method; stored under ``"date"``.
        path: Local JSONL file that receives the record and is then uploaded.
    """
    record = {
        "date": date.isoformat(),
        "top_picks": top_picks,
    }
    # Previously the record was built but never written, so the upload sent a
    # stale or missing file. Serialize first so that a failure cannot leave a
    # partial line behind, then append exactly one line per call.
    # NOTE(review): only the local file is appended to; history that exists
    # solely in the remote repo is not fetched first - confirm this is intended.
    try:
        line = json.dumps(record, ensure_ascii=False)
        with open(path, "a", encoding="utf-8") as jsonl_file:
            jsonl_file.write(line + "\n")
    except (TypeError, ValueError, OSError) as e:
        st.sidebar.error(f"❌ Error saving top picks to {path}: {e}")
        return
    vector_db.upload_top_picks_jsonl(path)
    st.sidebar.write(f"βœ… Saved top picks to {path}")
def add_top_picks_to_vector_db(vector_db, top_picks, date):
    """Turn each pick into a ``Document`` and hand the batch to ``vector_db.add_documents``.

    Every pick must provide the keys ``company``, ``ticker``, ``sentiment``,
    ``critical_news``, ``impact_summary``, ``action`` and ``reason``.
    """
    vector_db.add_documents([_top_pick_document(pick, date) for pick in top_picks])
    st.sidebar.write("βœ… Added top picks to vector DB")


def _top_pick_document(pick, date):
    """Build the ``Document`` (text summary plus metadata) for a single pick."""
    summary = (
        f"{pick['company']} ({pick['ticker']}):\n"
        f"Sentiment: {pick['sentiment']}\n"
        f"Critical News: {pick['critical_news']}\n"
        f"Impact: {pick['impact_summary']}\n"
        f"Action: {pick['action']}\n"
        f"Reason: {pick['reason']}"
    )
    return Document(
        page_content=summary,
        metadata={
            "ticker": pick["ticker"],
            "company": pick["company"],
            "sentiment": pick["sentiment"],
            "action": pick["action"],
            "date": date.isoformat(),
        },
    )