wakeupmh committed on
Commit
a47c92e
·
1 Parent(s): 8081db6

feat: add BART

Browse files
Files changed (2) hide show
  1. app.py +50 -47
  2. requirements.txt +1 -4
app.py CHANGED
@@ -16,23 +16,22 @@ logging.basicConfig(level=logging.INFO)
16
  DATA_DIR = "/data" if os.path.exists("/data") else "."
17
  DATASET_DIR = os.path.join(DATA_DIR, "rag_dataset")
18
  DATASET_PATH = os.path.join(DATASET_DIR, "dataset")
19
- MODEL_PATH = "facebook/bart-large-cnn" # Changed to BART model which is better for summarization
20
 
21
  @st.cache_resource
22
  def load_local_model():
23
  """Load the local Hugging Face model"""
24
- tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
25
- model = AutoModelForSeq2SeqLM.from_pretrained(
26
- MODEL_PATH,
27
- torch_dtype=torch.float32,
28
- low_cpu_mem_usage=True,
29
- device_map=None # Let PyTorch handle device placement
30
- )
31
-
32
- # Move model to CPU explicitly
33
- model = model.cpu()
34
-
35
- return model, tokenizer
36
 
37
  def fetch_arxiv_papers(query, max_results=5):
38
  """Fetch papers from arXiv"""
@@ -144,6 +143,9 @@ def generate_answer(question, context, max_length=512):
144
  """Generate a comprehensive answer using the local model"""
145
  model, tokenizer = load_local_model()
146
 
 
 
 
147
  # Format the context as a structured query
148
  prompt = f"""Summarize the following research about autism and answer the question.
149
 
@@ -152,7 +154,7 @@ Research Context:
152
 
153
  Question: {question}
154
 
155
- Provide a detailed answer that includes:
156
  1. Main findings from the research
157
  2. Research methods used
158
  3. Clinical implications
@@ -160,41 +162,42 @@ Provide a detailed answer that includes:
160
 
161
  If the research doesn't address the question directly, explain what information is missing."""
162
 
163
- # Generate response
164
- inputs = tokenizer(prompt, return_tensors="pt", max_length=1024, truncation=True)
165
-
166
- # Move inputs to the same device as model
167
- inputs = {k: v.to(model.device) for k, v in inputs.items()}
168
-
169
- with torch.inference_mode():
170
- outputs = model.generate(
171
- **inputs,
172
- max_length=max_length,
173
- min_length=200, # Ensure longer responses
174
- num_beams=5,
175
- length_penalty=2.0, # Encourage even longer responses
176
- temperature=0.7,
177
- no_repeat_ngram_size=3,
178
- repetition_penalty=1.3,
179
- early_stopping=True
180
- )
181
-
182
- response = tokenizer.decode(outputs[0], skip_special_tokens=True)
183
-
184
- # If response is too short or empty, provide a fallback message
185
- if len(response.strip()) < 100:
186
- return """I apologize, but I couldn't generate a specific answer from the research papers provided.
187
- This might be because:
188
- 1. The research papers don't directly address your question
189
- 2. The context needs more specific information
190
- 3. The question might need to be more specific
191
 
192
- Please try rephrasing your question or ask about a more specific aspect of autism."""
193
-
194
- # Format the response for better readability
195
- formatted_response = response.replace(". ", ".\n").replace("• ", "\n• ")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
196
 
197
- return formatted_response
 
 
198
 
199
  # Streamlit App
200
  st.title("🧩 AMA Autism")
 
16
  DATA_DIR = "/data" if os.path.exists("/data") else "."
17
  DATASET_DIR = os.path.join(DATA_DIR, "rag_dataset")
18
  DATASET_PATH = os.path.join(DATASET_DIR, "dataset")
19
+ MODEL_PATH = "t5-small" # Changed to T5-small for better CPU compatibility
20
 
21
  @st.cache_resource
22
  def load_local_model():
23
  """Load the local Hugging Face model"""
24
+ try:
25
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
26
+ model = AutoModelForSeq2SeqLM.from_pretrained(
27
+ MODEL_PATH,
28
+ device_map={"": "cpu"}, # Force CPU
29
+ torch_dtype=torch.float32
30
+ )
31
+ return model, tokenizer
32
+ except Exception as e:
33
+ st.error(f"Error loading model: {str(e)}")
34
+ return None, None
 
35
 
36
  def fetch_arxiv_papers(query, max_results=5):
37
  """Fetch papers from arXiv"""
 
143
  """Generate a comprehensive answer using the local model"""
144
  model, tokenizer = load_local_model()
145
 
146
+ if model is None or tokenizer is None:
147
+ return "Error: Could not load the model. Please try again later."
148
+
149
  # Format the context as a structured query
150
  prompt = f"""Summarize the following research about autism and answer the question.
151
 
 
154
 
155
  Question: {question}
156
 
157
+ Instructions: Based on the research context above, provide a comprehensive answer that covers:
158
  1. Main findings from the research
159
  2. Research methods used
160
  3. Clinical implications
 
162
 
163
  If the research doesn't address the question directly, explain what information is missing."""
164
 
165
+ try:
166
+ # Generate response
167
+ inputs = tokenizer(prompt, return_tensors="pt", max_length=1024, truncation=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
 
169
+ with torch.inference_mode():
170
+ outputs = model.generate(
171
+ **inputs,
172
+ max_length=max_length,
173
+ min_length=100,
174
+ num_beams=4,
175
+ length_penalty=1.5,
176
+ temperature=0.7,
177
+ repetition_penalty=1.2,
178
+ early_stopping=True
179
+ )
180
+
181
+ response = tokenizer.decode(outputs[0], skip_special_tokens=True)
182
+
183
+ # If response is too short or empty, provide a fallback message
184
+ if len(response.strip()) < 50:
185
+ return """I apologize, but I couldn't generate a specific answer from the research papers provided.
186
+ This might be because:
187
+ 1. The research papers don't directly address your question
188
+ 2. The context needs more specific information
189
+ 3. The question might need to be more specific
190
+
191
+ Please try rephrasing your question or ask about a more specific aspect of autism."""
192
+
193
+ # Format the response for better readability
194
+ formatted_response = response.replace(". ", ".\n").replace("• ", "\n• ")
195
+
196
+ return formatted_response
197
 
198
+ except Exception as e:
199
+ st.error(f"Error generating response: {str(e)}")
200
+ return "Error: Could not generate response. Please try again with a different question."
201
 
202
  # Streamlit App
203
  st.title("🧩 AMA Autism")
requirements.txt CHANGED
@@ -4,10 +4,7 @@ datasets>=2.17.0
4
  --extra-index-url https://download.pytorch.org/whl/cpu
5
  torch>=2.2.0
6
  accelerate>=0.26.0
7
- safetensors>=0.4.1
8
  numpy>=1.24.0
9
  pandas>=2.2.0
10
  requests>=2.31.0
11
- arxiv>=2.1.0
12
- lancedb>=0.3.3
13
- tantivy>=0.19.2
 
4
  --extra-index-url https://download.pytorch.org/whl/cpu
5
  torch>=2.2.0
6
  accelerate>=0.26.0
 
7
  numpy>=1.24.0
8
  pandas>=2.2.0
9
  requests>=2.31.0
10
+ arxiv>=2.1.0