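"""Vector-store persistence helpers for a Streamlit stock-picks app.

HFVectorDB keeps a FAISS index (index.faiss + index.pkl) in sync with a Hugging Face
dataset repo: it downloads the index files on startup, loads them with the supplied
embedding model, and re-uploads them after every change. The module-level helpers
merge each day's top picks into top_picks.jsonl in the same repo and embed them as
documents in the index, skipping (ticker, date) duplicates.
"""
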
import os
import shutil
import json
from datetime import date
from huggingface_hub import hf_hub_download, HfApi
from langchain_community.vectorstores.faiss import FAISS
from langchain.docstore.document import Document
import streamlit as st
import tempfile

class HFVectorDB:
    def __init__(self, hf_repo_id, hf_token, local_index_dir="/tmp/vector_index", embedding_model=None):
        self.hf_repo_id = hf_repo_id
        self.hf_token = hf_token
        self.local_index_dir = local_index_dir
        self.embedding_model = embedding_model
        os.makedirs(self.local_index_dir, exist_ok=True)
        self.index = None
        # Initialize HfApi
        self.api = HfApi(token=self.hf_token)
        # Download index files from HF repo (if exist)
        self._download_index_files()
        self._load_index()

    def _download_index_files(self):
        try:
            # Using self.api.hf_hub_download for consistency
            faiss_path = self.api.hf_hub_download(repo_id=self.hf_repo_id, filename="index.faiss", repo_type="dataset")
            pkl_path = self.api.hf_hub_download(repo_id=self.hf_repo_id, filename="index.pkl", repo_type="dataset")

            shutil.copy(faiss_path, os.path.join(self.local_index_dir, "index.faiss"))
            shutil.copy(pkl_path, os.path.join(self.local_index_dir, "index.pkl"))
            st.sidebar.write(f"✅ Downloaded FAISS index files from HF repo using HfApi to {self.local_index_dir}")
        except Exception as e:
            st.sidebar.write(f"⚠️ Could not download FAISS index files: {e}")

    def _load_index(self):
        try:
            # Ensure embedding_model is provided when loading a FAISS index
            if self.embedding_model is None:
                st.sidebar.error("Error: embedding_model must be provided to load a FAISS index.")
                self.index = None
                return

            self.index = FAISS.load_local(
                self.local_index_dir,
                self.embedding_model,
                allow_dangerous_deserialization=True,
            )
            st.sidebar.write("✅ Loaded FAISS index from local")
        except Exception as e:
            st.sidebar.write(f"ℹ️ No local FAISS index found, starting with empty index {e}")
            self.index = None

    def save_index(self):
        if self.index is not None:
            self.index.save_local(self.local_index_dir)
            st.sidebar.write("✅ Saved FAISS index locally")
            self._upload_index_files()
        else:
            st.sidebar.write("⚠️ No FAISS index to save")

    def _upload_index_files(self):
        try:
            # Upload index.faiss
            self.api.upload_file(
                path_or_fileobj=os.path.join(self.local_index_dir, "index.faiss"),
                path_in_repo="index.faiss",
                repo_id=self.hf_repo_id,
                repo_type="dataset",
                commit_message="Update FAISS index.faiss"
            )
            # Upload index.pkl
            self.api.upload_file(
                path_or_fileobj=os.path.join(self.local_index_dir, "index.pkl"),
                path_in_repo="index.pkl",
                repo_id=self.hf_repo_id,
                repo_type="dataset",
                commit_message="Update FAISS index.pkl"
            )
            st.sidebar.write("✅ Uploaded FAISS index files to HF repo using HfApi")
        except Exception as e:
            st.error(f"❌ Error uploading FAISS index files: {e}")


    def add_documents(self, docs):
        if self.index is None:
            # Ensure embedding_model is provided when creating a new FAISS index
            if self.embedding_model is None:
                st.sidebar.error("Error: embedding_model must be provided to create a new FAISS index.")
                return

            self.index = FAISS.from_documents(docs, self.embedding_model)
            st.sidebar.write("✅ Created new FAISS index")
        else:
            self.index.add_documents(docs)
            st.sidebar.write("✅ Added documents to FAISS index")
        self.save_index()

    def similarity_search(self, query, k=5):
        if self.index is None:
            st.sidebar.write("⚠️ No index found, returning empty results")
            return []
        return self.index.similarity_search(query, k=k)

    def upload_top_picks_jsonl(self, path="top_picks.jsonl"):
        """Uploads the top_picks.jsonl to the HF dataset repo"""
        try:
            self.api.upload_file(
                path_or_fileobj=path,
                path_in_repo="top_picks.jsonl",
                repo_id=self.hf_repo_id,
                repo_type="dataset",
                commit_message="Update top_picks.jsonl"
            )
            st.sidebar.write("✅ Uploaded top_picks.jsonl to HF repo")
        except Exception as e:
            st.sidebar.error(f"❌ Error uploading top_picks.jsonl: {e}")



def save_top_picks_json(vector_db, top_picks, date, path="top_picks.jsonl"):
    """Merge the given day's picks into top_picks.jsonl (one JSON record per line) and upload it to the HF repo."""
    record = {
        "date": date.isoformat(),
        "top_picks": top_picks
    }

    # Step 1: Download existing top_picks.jsonl from HF
    try:
        tmp_file = hf_hub_download(
            repo_id=vector_db.hf_repo_id,
            filename="top_picks.jsonl",
            repo_type="dataset",
            token=vector_db.hf_token,
            cache_dir=tempfile.mkdtemp()
        )
        with open(tmp_file, "r") as f:
            existing_lines = f.readlines()
    except Exception as e:
        st.sidebar.warning(f"⚠️ No existing top_picks.jsonl found or failed to download: {e}")
        existing_lines = []

    # Step 2: Parse and filter existing records
    records = []
    for line in existing_lines:
        try:
            rec = json.loads(line)
            if rec["date"] != date.isoformat():  # Remove existing entry for the same date
                records.append(rec)
        except json.JSONDecodeError:
            continue
    
    # Step 3: Add current record
    records.append(record)

    # Step 4: Overwrite local top_picks.jsonl
    with open(path, "w") as f:
        for rec in records:
            f.write(json.dumps(rec) + "\n")

    vector_db.upload_top_picks_jsonl(path)
    st.sidebar.write(f"✅ Saved top picks to {path}")

def add_top_picks_to_vector_db(vector_db, top_picks, date):
    """Embed each pick as a Document in the vector DB, skipping (ticker, date) pairs that are already indexed."""
    existing_docs = []
    if vector_db.index is not None:
        try:
            existing_docs = vector_db.index.docstore._dict.values()
        except Exception:
            pass

    existing_keys = set()
    for doc in existing_docs:
        key = (doc.metadata.get("ticker"), doc.metadata.get("date"))
        existing_keys.add(key)

    new_docs = []
    for pick in top_picks:
        key = (pick["ticker"], date.isoformat())
        if key in existing_keys:
            continue  # Skip duplicates

        content = (
            f"{pick['company']} ({pick['ticker']}):\n"
            f"Sentiment: {pick['sentiment']}\n"
            f"Critical News: {pick['critical_news']}\n"
            f"Impact: {pick['impact_summary']}\n"
            f"Action: {pick['action']}\n"
            f"Reason: {pick['reason']}"
        )
        metadata = {
            "ticker": pick["ticker"],
            "company": pick["company"],
            "sentiment": pick["sentiment"],
            "action": pick["action"],
            "date": date.isoformat()
        }
        new_docs.append(Document(page_content=content, metadata=metadata))

    if new_docs:
        vector_db.add_documents(new_docs)
        st.sidebar.write(f"✅ Added {len(new_docs)} new documents to vector DB")
    else:
        st.sidebar.write("ℹ️ No new documents to add to vector DB")