rajat5ranjan commited on
Commit
14c46c2
·
verified ·
1 Parent(s): cd61a3c

Update stock_vector_db.py

Browse files
Files changed (1) hide show
  1. stock_vector_db.py +61 -91
stock_vector_db.py CHANGED
@@ -1,106 +1,76 @@
1
  import os
2
- import json
3
- from typing import List, Optional
4
- from datetime import datetime
5
- from langchain.vectorstores import FAISS
6
- from langchain.embeddings import HuggingFaceEmbeddings
7
- from langchain.schema import Document
8
  import streamlit as st
9
 
10
class StockVectorDB:
    """Store daily stock-pick summaries in a local FAISS index with a JSON log.

    The FAISS index is lazily created on the first call to ``store_top_picks``
    and persisted under ``index_path``; every stored pick is also appended to
    a human-readable JSON audit log at ``log_path``.
    """

    def __init__(self, index_path: str = "vector_index", log_path: str = "vector_log.json", embedding_model: Optional[str] = None):
        self.index_path = index_path
        self.log_path = log_path
        # Default to a small general-purpose sentence-embedding model.
        model = embedding_model or "sentence-transformers/all-MiniLM-L6-v2"
        self.embedding_model = HuggingFaceEmbeddings(model_name=model)

        # Load the FAISS index if one was previously saved; otherwise delay
        # creation until the first batch of documents arrives.
        if os.path.exists(index_path):
            st.write(f"πŸ” Loading existing FAISS index from '{index_path}'")
            self.index = FAISS.load_local(index_path, self.embedding_model)
        else:
            st.write("πŸ†• No FAISS index found. Will create when documents are added.")
            self.index = None  # Delay creation

        # Load the JSON log, tolerating a missing or corrupt file.
        if os.path.exists(self.log_path):
            with open(self.log_path, "r") as f:
                try:
                    self.log_data = json.load(f)
                except json.JSONDecodeError:
                    self.log_data = []
        else:
            self.log_data = []

    def store_top_picks(self, top_picks: List[dict], date: Optional[datetime] = None):
        """Embed and persist a batch of stock-pick dicts for the given date.

        Each pick must carry the keys ticker/company/sentiment/critical_news/
        impact_summary/action/reason; malformed entries are skipped with a
        warning rather than aborting the whole batch.
        """
        date = date or datetime.now()
        formatted_date = date.strftime("%Y-%m-%d")
        docs = []

        for stock in top_picks:
            try:
                content = f"{stock['ticker']} {stock['company']} is {stock['sentiment']} due to: {stock['critical_news']}. Impact: {stock['impact_summary']}. Action: {stock['action']}. Reason: {stock['reason']}"
                metadata = {
                    "date": formatted_date,
                    "ticker": stock["ticker"],
                    "company": stock["company"],
                    "sentiment": stock["sentiment"],
                    "action": stock["action"]
                }
                docs.append(Document(page_content=content, metadata=metadata))

                self.log_data.append({
                    "ticker": stock["ticker"],
                    "company": stock["company"],
                    "date": formatted_date,
                    "sentiment": stock["sentiment"],
                    "action": stock["action"],
                    "reason": stock["reason"]
                })
            except KeyError as e:
                # A missing key invalidates the whole entry; skip it.
                st.write(f"❌ Skipping malformed stock data: {e}")
                continue

        if docs:
            if self.index is None:
                self.index = FAISS.from_documents(docs, self.embedding_model)
            else:
                self.index.add_documents(docs)

            self.save_index()
            self.save_log()
            st.write(f"βœ… Stored {len(docs)} documents for {formatted_date}")
        else:
            st.write("⚠️ No valid documents to store.")

    def save_index(self):
        """Persist the FAISS index to ``index_path``, if one exists."""
        if self.index:
            self.index.save_local(self.index_path)

    def save_log(self):
        """Write the accumulated pick log to the JSON log file."""
        with open(self.log_path, "w") as f:
            json.dump(self.log_data, f, indent=2)

    def search(self, query: str, k: int = 5):
        """Return (and display) the top-``k`` documents most similar to ``query``."""
        if not self.index:
            st.write("⚠️ FAISS index is empty.")
            return []

        # FIX: original used print(), which never reaches the Streamlit UI;
        # every other status message in this class goes through st.write.
        st.write(f"πŸ” Searching for: '{query}' (top {k})")
        results = self.index.similarity_search(query, k=k)
        for res in results:
            st.write(f"\nπŸ“Œ Ticker: {res.metadata.get('ticker')} | Sentiment: {res.metadata.get('sentiment')} | Date: {res.metadata.get('date')}")
            st.write(res.page_content)
            st.write("-" * 80)
        return results

    def backup(self, backup_dir: str = "vector_backups"):
        """Save a timestamped copy of the FAISS index under ``backup_dir``."""
        if not self.index:
            st.write("⚠️ No index to backup.")
            return

        os.makedirs(backup_dir, exist_ok=True)
        ts = datetime.now().strftime("%Y%m%d_%H%M%S")
        backup_path = os.path.join(backup_dir, f"vector_index_{ts}")
        self.index.save_local(backup_path)
        st.write(f"πŸ“¦ Backup saved to {backup_path}")
 
 
 
1
  import os
2
+ import shutil
3
+ from huggingface_hub import hf_hub_download, Repository
4
+ from langchain_community.vectorstores.faiss import FAISS
 
 
 
5
  import streamlit as st
6
 
7
class HFVectorDB:
    """FAISS vector store whose index files are synced with a Hugging Face repo.

    On construction, ``index.faiss``/``index.pkl`` are downloaded from
    ``hf_repo_id`` (if present) into ``local_index_dir`` and loaded; on
    ``save_index`` they are written locally and pushed back to the repo.
    """

    def __init__(self, hf_repo_id, hf_token, local_index_dir="/tmp/vector_index", embedding_model=None):
        self.hf_repo_id = hf_repo_id
        self.hf_token = hf_token
        self.local_index_dir = local_index_dir
        self.embedding_model = embedding_model

        os.makedirs(self.local_index_dir, exist_ok=True)
        self.index = None

        # Pull any previously published index files, then try to load them.
        self._download_index_files()
        self._load_index()

    def _download_index_files(self):
        """Best-effort fetch of index.faiss / index.pkl from the HF repo."""
        try:
            # FIX: `token=` replaces `use_auth_token=`, which is deprecated
            # and removed in recent huggingface_hub releases.
            faiss_path = hf_hub_download(repo_id=self.hf_repo_id, filename="index.faiss", token=self.hf_token)
            pkl_path = hf_hub_download(repo_id=self.hf_repo_id, filename="index.pkl", token=self.hf_token)
            shutil.copy(faiss_path, os.path.join(self.local_index_dir, "index.faiss"))
            shutil.copy(pkl_path, os.path.join(self.local_index_dir, "index.pkl"))
            st.write("βœ… Downloaded FAISS index files from HF repo")
        except Exception as e:
            # Missing files on a fresh repo are expected; report and continue.
            st.write(f"⚠️ Could not download FAISS index files: {e}")

    def _load_index(self):
        """Load the FAISS index from the local dir; leave ``self.index`` None on failure."""
        try:
            self.index = FAISS.load_local(self.local_index_dir, self.embedding_model)
            st.write("βœ… Loaded FAISS index from local")
        except Exception as e:
            # FIX: surface the actual reason instead of silently assuming
            # "not found" — the original swallowed the exception detail,
            # which hides real load failures (e.g. version mismatches).
            st.write(f"ℹ️ No local FAISS index found, starting with empty index ({e})")
            self.index = None

    def save_index(self):
        """Persist the index locally, then push the files to the HF repo."""
        if self.index is not None:
            self.index.save_local(self.local_index_dir)
            st.write("βœ… Saved FAISS index locally")
            self._upload_index_files()
        else:
            st.write("⚠️ No FAISS index to save")

    def _upload_index_files(self):
        """Clone the HF repo, copy the index files in, and commit/push them."""
        repo_local_path = "/tmp/hf_dataset_repo_clone"
        # Start from a clean clone to avoid stale state between pushes.
        if os.path.exists(repo_local_path):
            shutil.rmtree(repo_local_path)

        # NOTE(review): Repository is deprecated in recent huggingface_hub;
        # consider HfApi.upload_file/upload_folder instead. `token=` replaces
        # the removed `use_auth_token=` argument.
        repo = Repository(local_dir=repo_local_path, clone_from=self.hf_repo_id, token=self.hf_token)

        shutil.copy(os.path.join(self.local_index_dir, "index.faiss"), os.path.join(repo_local_path, "index.faiss"))
        shutil.copy(os.path.join(self.local_index_dir, "index.pkl"), os.path.join(repo_local_path, "index.pkl"))

        repo.git_add(auto_lfs_track=True)
        repo.git_commit("Update FAISS index files")
        repo.git_push()
        st.write("βœ… Uploaded FAISS index files to HF repo")

    def add_documents(self, docs):
        """Create the index from ``docs`` or append them, then persist."""
        if self.index is None:
            self.index = FAISS.from_documents(docs, self.embedding_model)
            st.write("βœ… Created new FAISS index")
        else:
            self.index.add_documents(docs)
            st.write("βœ… Added documents to FAISS index")

        self.save_index()

    def similarity_search(self, query, k=5):
        """Return the top-``k`` most similar documents, or [] if no index exists."""
        if self.index is None:
            st.write("⚠️ No index found, returning empty results")
            return []
        return self.index.similarity_search(query, k=k)