EricGEGE commited on
Commit
b8255a8
·
verified ·
1 Parent(s): a3ff768

Upload 4 files

Browse files
Files changed (5) hide show
  1. .gitattributes +1 -0
  2. app.py +195 -0
  3. requirements.txt +12 -0
  4. vectordb/1.index +3 -0
  5. vectordb/1.pkl +3 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ vectordb/1.index filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+ from typing import List, Tuple
4
+ from functools import cached_property
5
+ from pydantic import BaseModel, Field
6
+ from openai import OpenAI
7
+ import faiss
8
+ import pickle
9
+ import numpy as np
10
+ from dotenv import load_dotenv
11
+ import gradio as gr
12
+ from datetime import datetime
13
+ from sentence_transformers import SentenceTransformer
14
+
15
+ # Load environment variables from .env file
16
+ load_dotenv()
17
+
18
+ # Configure logging
19
+ logging.basicConfig(level=logging.INFO)
20
+ logger = logging.getLogger(__name__)
21
+
22
+ NO_DATA_MESSAGE = "I apologize, but I encountered an error processing your request."
23
+
24
+ class LocalEmbedding:
25
+ """Local embedding model wrapper"""
26
+ def __init__(self, model_name: str = "all-MiniLM-L6-v2"):
27
+ self.model = SentenceTransformer(model_name)
28
+ self.vector_dim = self.model.get_sentence_embedding_dimension()
29
+
30
+ def get_embedding(self, text: str) -> List[float]:
31
+ """Get embedding using local model"""
32
+ try:
33
+ embedding = self.model.encode(text)
34
+ return embedding.tolist()
35
+ except Exception as e:
36
+ logger.error(f"Error getting embedding: {e}")
37
+ return []
38
+
39
+ class DeepSeekChat(BaseModel):
40
+ """DeepSeek chat model wrapper"""
41
+ api_key: str = Field(default=os.getenv("DEEPSEEK_API_KEY"))
42
+ base_url: str = Field(default="https://api.siliconflow.cn/v1")
43
+
44
+ class Config:
45
+ """Pydantic config class"""
46
+ arbitrary_types_allowed = True
47
+
48
+ @cached_property
49
+ def client(self) -> OpenAI:
50
+ """Create and cache OpenAI client instance"""
51
+ return OpenAI(api_key=self.api_key, base_url=self.base_url)
52
+
53
+ def chat(
54
+ self,
55
+ system_message: str,
56
+ user_message: str,
57
+ context: str = "",
58
+ model: str = "deepseek-ai/DeepSeek-V3",
59
+ max_tokens: int = 1024,
60
+ temperature: float = 0.7,
61
+ ) -> str:
62
+ """Send chat request to DeepSeek API"""
63
+ messages = []
64
+
65
+ # Add system message if provided
66
+ if system_message:
67
+ messages.append({"role": "system", "content": system_message})
68
+
69
+ # Add context if provided
70
+ if context:
71
+ messages.append({"role": "user", "content": context})
72
+
73
+ # Add user message
74
+ messages.append({"role": "user", "content": user_message})
75
+
76
+ try:
77
+ response = self.client.chat.completions.create(
78
+ model=model,
79
+ messages=messages,
80
+ max_tokens=max_tokens,
81
+ temperature=temperature,
82
+ )
83
+ return response.choices[0].message.content
84
+ except Exception as e:
85
+ logger.error(f"Error in DeepSeek API call: {e}")
86
+ return NO_DATA_MESSAGE
87
+
88
+ class PDFChatbot:
89
+ def __init__(self, index_path: str, texts_path: str, model_name: str = "all-MiniLM-L6-v2"):
90
+ if not os.getenv("DEEPSEEK_API_KEY"):
91
+ raise ValueError("DEEPSEEK_API_KEY not found in .env file")
92
+
93
+ # Initialize models
94
+ logger.info("Initializing models...")
95
+ self.chat_model = DeepSeekChat()
96
+ self.embedding_model = LocalEmbedding(model_name)
97
+
98
+ # Load vector database
99
+ logger.info("Loading vector database...")
100
+ self.index = faiss.read_index(index_path)
101
+ with open(texts_path, 'rb') as f:
102
+ self.texts = pickle.load(f)
103
+
104
+ # Chat settings
105
+ self.system_message = """You are a knowledgeable AI assistant that helps users understand the content of the provided document.
106
+ Use the context provided to answer questions accurately and comprehensively. If the answer cannot be found in the context,
107
+ clearly state that the information is not available in the document."""
108
+
109
+ # Create conversation log file with timestamp
110
+ self.log_file = f"pdf_chat_{datetime.now().strftime('%Y%m%d_%H%M%S')}.txt"
111
+ self.log_conversation("Conversation started")
112
+
113
+ def log_conversation(self, message, role="system"):
114
+ """Log conversation with timestamp to file"""
115
+ timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
116
+ with open(self.log_file, "a", encoding="utf-8") as f:
117
+ f.write(f"[{timestamp}] {role}: {message}\n")
118
+
119
+ def get_relevant_context(self, query: str, k: int = 3) -> str:
120
+ """Get most relevant context for the query"""
121
+ try:
122
+ # Get query embedding
123
+ query_embedding = self.embedding_model.get_embedding(query)
124
+ if not query_embedding:
125
+ return ""
126
+
127
+ # Search for similar contexts
128
+ query_vector = np.array([query_embedding]).astype('float32')
129
+ distances, indices = self.index.search(query_vector, k)
130
+
131
+ # Combine relevant contexts
132
+ relevant_texts = [self.texts[i] for i in indices[0]]
133
+ return "\n".join(relevant_texts)
134
+ except Exception as e:
135
+ logger.error(f"Error getting relevant context: {e}")
136
+ return ""
137
+
138
+ def chat(self, message, history):
139
+ """Process chat message and return response"""
140
+ try:
141
+ # Log user message
142
+ self.log_conversation(message, "user")
143
+
144
+ # Get relevant context
145
+ context = self.get_relevant_context(message)
146
+
147
+ # If context is found, add it to the prompt
148
+ context_prompt = f"Based on the following context from the document:\n{context}\n\nPlease answer the question." if context else ""
149
+
150
+ # Get response from DeepSeek
151
+ response = self.chat_model.chat(
152
+ system_message=self.system_message,
153
+ user_message=message,
154
+ context=context_prompt
155
+ )
156
+
157
+ # Log assistant response
158
+ self.log_conversation(response, "assistant")
159
+
160
+ return response
161
+ except Exception as e:
162
+ logger.error(f"Error in chat: {e}")
163
+ return NO_DATA_MESSAGE
164
+
165
+ def main():
166
+ try:
167
+ # Replace these paths with your actual vector database files
168
+ index_path = "vectordb/1.index"
169
+ texts_path = "vectordb/1.pkl"
170
+
171
+ # Initialize chatbot
172
+ chatbot = PDFChatbot(index_path, texts_path)
173
+
174
+ # Create Gradio interface
175
+ iface = gr.ChatInterface(
176
+ fn=chatbot.chat,
177
+ title="PDF Document Assistant",
178
+ description="Ask questions about the loaded PDF document. I'll help you understand its contents.",
179
+ theme=gr.themes.Soft(),
180
+ examples=[
181
+ "What is the main topic of this document?",
182
+ "Can you summarize the key points?",
183
+ "What are the conclusions drawn in this document?"
184
+ ],
185
+ )
186
+
187
+ # Launch the interface
188
+ iface.launch(share=False)
189
+
190
+ except Exception as e:
191
+ logger.error(f"Failed to initialize chatbot: {e}")
192
+ raise
193
+
194
+ if __name__ == "__main__":
195
+ main()
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ faiss_cpu
2
+ gradio
3
+ numpy
4
+ openai
5
+ pandas
6
+ pydantic
7
+ PyPDF2
8
+ python-dotenv
9
+ python-telegram-bot
10
+ scikit_learn
11
+ sentence_transformers
12
+ tqdm
vectordb/1.index ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0c6a170d5a1183cccef3a5a201a23decb25d25080059541585eb7f00eb2baef1
3
+ size 617517
vectordb/1.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9f2e21530bcb8e482312277bfb26ba13fc22b619c4fa0aae2cee3b832d760226
3
+ size 402981