import os
import shutil
import json
from datetime import date
from huggingface_hub import hf_hub_download, HfApi
from langchain_community.vectorstores.faiss import FAISS
from langchain.docstore.document import Document # Import Document
import streamlit as st
import tempfile
class HFVectorDB:
    """FAISS vector store persisted to a Hugging Face dataset repo."""

    def __init__(self, hf_repo_id, hf_token, local_index_dir="/tmp/vector_index", embedding_model=None):
        self.hf_repo_id = hf_repo_id
        self.hf_token = hf_token
        self.local_index_dir = local_index_dir
        self.embedding_model = embedding_model
        os.makedirs(self.local_index_dir, exist_ok=True)
        self.index = None
        self.api = HfApi(token=self.hf_token)
        # Download any existing index files from the HF repo, then try to load them.
        self._download_index_files()
        self._load_index()
    def _download_index_files(self):
        """Fetch index.faiss / index.pkl from the HF dataset repo into the local index dir."""
        try:
            # Use self.api.hf_hub_download so the stored token is reused.
            faiss_path = self.api.hf_hub_download(repo_id=self.hf_repo_id, filename="index.faiss", repo_type="dataset")
            pkl_path = self.api.hf_hub_download(repo_id=self.hf_repo_id, filename="index.pkl", repo_type="dataset")
            shutil.copy(faiss_path, os.path.join(self.local_index_dir, "index.faiss"))
            shutil.copy(pkl_path, os.path.join(self.local_index_dir, "index.pkl"))
            st.sidebar.write(f"✅ Downloaded FAISS index files from HF repo to {self.local_index_dir}")
        except Exception as e:
            st.sidebar.write(f"⚠️ Could not download FAISS index files: {e}")
    def _load_index(self):
        try:
            # A FAISS index cannot be deserialized without the embedding model it was built with.
            if self.embedding_model is None:
                st.sidebar.error("Error: embedding_model must be provided to load a FAISS index.")
                self.index = None
                return
            self.index = FAISS.load_local(self.local_index_dir, self.embedding_model, allow_dangerous_deserialization=True)
            st.sidebar.write("✅ Loaded FAISS index from local")
        except Exception as e:
            st.sidebar.write(f"ℹ️ No local FAISS index found, starting with an empty index: {e}")
            self.index = None
    def save_index(self):
        if self.index is not None:
            self.index.save_local(self.local_index_dir)
            st.sidebar.write("✅ Saved FAISS index locally")
            self._upload_index_files()
        else:
            st.sidebar.write("⚠️ No FAISS index to save")
    def _upload_index_files(self):
        try:
            # Upload index.faiss
            self.api.upload_file(
                path_or_fileobj=os.path.join(self.local_index_dir, "index.faiss"),
                path_in_repo="index.faiss",
                repo_id=self.hf_repo_id,
                repo_type="dataset",
                commit_message="Update FAISS index.faiss"
            )
            # Upload index.pkl
            self.api.upload_file(
                path_or_fileobj=os.path.join(self.local_index_dir, "index.pkl"),
                path_in_repo="index.pkl",
                repo_id=self.hf_repo_id,
                repo_type="dataset",
                commit_message="Update FAISS index.pkl"
            )
            st.sidebar.write("✅ Uploaded FAISS index files to HF repo")
        except Exception as e:
            st.sidebar.error(f"❌ Error uploading FAISS index files: {e}")
    def add_documents(self, docs):
        if self.index is None:
            # A new FAISS index can only be created if an embedding model is available.
            if self.embedding_model is None:
                st.sidebar.error("Error: embedding_model must be provided to create a new FAISS index.")
                return
            self.index = FAISS.from_documents(docs, self.embedding_model)
            st.sidebar.write("✅ Created new FAISS index")
        else:
            self.index.add_documents(docs)
            st.sidebar.write("✅ Added documents to FAISS index")
        self.save_index()
    def similarity_search(self, query, k=5):
        if self.index is None:
            st.sidebar.write("⚠️ No index found, returning empty results")
            return []
        return self.index.similarity_search(query, k=k)
    def upload_top_picks_jsonl(self, path="top_picks.jsonl"):
        """Uploads the top_picks.jsonl to the HF dataset repo"""
        try:
            self.api.upload_file(
                path_or_fileobj=path,
                path_in_repo="top_picks.jsonl",
                repo_id=self.hf_repo_id,
                repo_type="dataset",
                commit_message="Update top_picks.jsonl"
            )
            st.sidebar.write("✅ Uploaded top_picks.jsonl to HF repo")
        except Exception as e:
            st.sidebar.error(f"❌ Error uploading top_picks.jsonl: {e}")
def save_top_picks_json(vector_db, top_picks, date, path="top_picks.jsonl"):
    """Merge today's top picks into top_picks.jsonl and push the file to the HF repo."""
    record = {
        "date": date.isoformat(),
        "top_picks": top_picks
    }
    # Step 1: Download the existing top_picks.jsonl from HF (if any)
    try:
        tmp_file = hf_hub_download(
            repo_id=vector_db.hf_repo_id,
            filename="top_picks.jsonl",
            repo_type="dataset",
            token=vector_db.hf_token,
            cache_dir=tempfile.mkdtemp()
        )
        with open(tmp_file, "r") as f:
            existing_lines = f.readlines()
    except Exception as e:
        st.sidebar.warning(f"⚠️ No existing top_picks.jsonl found or failed to download: {e}")
        existing_lines = []
    # Step 2: Parse existing records, dropping any entry for the same date
    records = []
    for line in existing_lines:
        try:
            rec = json.loads(line)
            if rec["date"] != date.isoformat():
                records.append(rec)
        except json.JSONDecodeError:
            continue
    # Step 3: Append the current record
    records.append(record)
    # Step 4: Overwrite the local top_picks.jsonl and upload it
    with open(path, "w") as f:
        for rec in records:
            f.write(json.dumps(rec) + "\n")
    vector_db.upload_top_picks_jsonl(path)
    st.sidebar.write(f"✅ Saved top picks to {path}")
def add_top_picks_to_vector_db(vector_db, top_picks, date):
    """Convert top picks into Documents and add them to the vector DB, skipping duplicates."""
    # Collect (ticker, date) keys already in the index so re-runs don't duplicate entries.
    existing_docs = []
    if vector_db.index is not None:
        try:
            existing_docs = vector_db.index.docstore._dict.values()
        except Exception:
            pass
    existing_keys = set()
    for doc in existing_docs:
        key = (doc.metadata.get("ticker"), doc.metadata.get("date"))
        existing_keys.add(key)
    new_docs = []
    for pick in top_picks:
        key = (pick["ticker"], date.isoformat())
        if key in existing_keys:
            continue  # Skip duplicates
        content = (
            f"{pick['company']} ({pick['ticker']}):\n"
            f"Sentiment: {pick['sentiment']}\n"
            f"Critical News: {pick['critical_news']}\n"
            f"Impact: {pick['impact_summary']}\n"
            f"Action: {pick['action']}\n"
            f"Reason: {pick['reason']}"
        )
        metadata = {
            "ticker": pick["ticker"],
            "company": pick["company"],
            "sentiment": pick["sentiment"],
            "action": pick["action"],
            "date": date.isoformat()
        }
        new_docs.append(Document(page_content=content, metadata=metadata))
    if new_docs:
        vector_db.add_documents(new_docs)
        st.sidebar.write(f"✅ Added {len(new_docs)} new documents to vector DB")
    else:
        st.sidebar.write("ℹ️ No new documents to add to vector DB")