feat: add BART
Files changed:
- app.py (+50 -47)
- requirements.txt (+1 -4)
app.py
CHANGED
@@ -16,23 +16,22 @@ logging.basicConfig(level=logging.INFO)
 DATA_DIR = "/data" if os.path.exists("/data") else "."
 DATASET_DIR = os.path.join(DATA_DIR, "rag_dataset")
 DATASET_PATH = os.path.join(DATASET_DIR, "dataset")
-MODEL_PATH = "
+MODEL_PATH = "t5-small"  # Changed to T5-small for better CPU compatibility
 
 @st.cache_resource
 def load_local_model():
     """Load the local Hugging Face model"""
-
-
-
-
-
-
-
-
-
-
-
-    return model, tokenizer
+    try:
+        tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
+        model = AutoModelForSeq2SeqLM.from_pretrained(
+            MODEL_PATH,
+            device_map={"": "cpu"},  # Force CPU
+            torch_dtype=torch.float32
+        )
+        return model, tokenizer
+    except Exception as e:
+        st.error(f"Error loading model: {str(e)}")
+        return None, None
 
 def fetch_arxiv_papers(query, max_results=5):
     """Fetch papers from arXiv"""
@@ -144,6 +143,9 @@ def generate_answer(question, context, max_length=512):
     """Generate a comprehensive answer using the local model"""
     model, tokenizer = load_local_model()
 
+    if model is None or tokenizer is None:
+        return "Error: Could not load the model. Please try again later."
+
     # Format the context as a structured query
     prompt = f"""Summarize the following research about autism and answer the question.
 
@@ -152,7 +154,7 @@ Research Context:
 
 Question: {question}
 
-Provide a detailed answer that includes:
+Instructions: Based on the research context above, provide a comprehensive answer that covers:
 1. Main findings from the research
 2. Research methods used
 3. Clinical implications
@@ -160,41 +162,42 @@ Provide a detailed answer that includes:
 
 If the research doesn't address the question directly, explain what information is missing."""
 
-
-
-
-    # Move inputs to the same device as model
-    inputs = {k: v.to(model.device) for k, v in inputs.items()}
-
-    with torch.inference_mode():
-        outputs = model.generate(
-            **inputs,
-            max_length=max_length,
-            min_length=200,  # Ensure longer responses
-            num_beams=5,
-            length_penalty=2.0,  # Encourage even longer responses
-            temperature=0.7,
-            no_repeat_ngram_size=3,
-            repetition_penalty=1.3,
-            early_stopping=True
-        )
-
-    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-    # If response is too short or empty, provide a fallback message
-    if len(response.strip()) < 100:
-        return """I apologize, but I couldn't generate a specific answer from the research papers provided.
-        This might be because:
-        1. The research papers don't directly address your question
-        2. The context needs more specific information
-        3. The question might need to be more specific
+    try:
+        # Generate response
+        inputs = tokenizer(prompt, return_tensors="pt", max_length=1024, truncation=True)
 
-
-
-
-
+        with torch.inference_mode():
+            outputs = model.generate(
+                **inputs,
+                max_length=max_length,
+                min_length=100,
+                num_beams=4,
+                length_penalty=1.5,
+                temperature=0.7,
+                repetition_penalty=1.2,
+                early_stopping=True
+            )
+
+        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+        # If response is too short or empty, provide a fallback message
+        if len(response.strip()) < 50:
+            return """I apologize, but I couldn't generate a specific answer from the research papers provided.
+            This might be because:
+            1. The research papers don't directly address your question
+            2. The context needs more specific information
+            3. The question might need to be more specific
+
+            Please try rephrasing your question or ask about a more specific aspect of autism."""
+
+        # Format the response for better readability
+        formatted_response = response.replace(". ", ".\n").replace("• ", "\n• ")
+
+        return formatted_response
 
-
+    except Exception as e:
+        st.error(f"Error generating response: {str(e)}")
+        return "Error: Could not generate response. Please try again with a different question."
 
 # Streamlit App
 st.title("🧩 AMA Autism")
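The core of this change is the loading pattern: device_map={"": "cpu"} pins every module of the model to the CPU (the empty-string key targets the whole model), and torch.float32 avoids half-precision ops that are poorly supported on CPU. A minimal standalone sketch of the same pattern, runnable outside Streamlit — the st.cache_resource decorator and st.error call are swapped for plain Python, and transformers, accelerate, and the CPU torch wheel from requirements.txt are assumed:

# Standalone sketch of the new CPU loading pattern (no Streamlit needed).
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

MODEL_PATH = "t5-small"

def load_local_model():
    """Mirror app.py's load_local_model() without the Streamlit wrappers."""
    try:
        tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
        model = AutoModelForSeq2SeqLM.from_pretrained(
            MODEL_PATH,
            device_map={"": "cpu"},     # "" targets the whole model -> everything on CPU
            torch_dtype=torch.float32,  # full precision; fp16 kernels are spotty on CPU
        )
        return model, tokenizer
    except Exception as e:              # broad catch, matching the app's behavior
        print(f"Error loading model: {e}")
        return None, None

if __name__ == "__main__":
    model, tokenizer = load_local_model()
    if model is not None:
        print(model.device)  # expected: cpu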
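Two caveats about the new generate() call are worth flagging. First, tokenizer(..., max_length=1024) admits prompts twice as long as the 512 tokens t5-small was trained on; T5's relative position buckets let this run, but the tokenizer typically warns and quality can degrade past 512. Second, temperature=0.7 has no effect as committed: do_sample defaults to False, so generate() runs deterministic beam search, and recent transformers versions warn that temperature is ignored unless do_sample=True. A variant sketch with sampling actually enabled — an illustration, not what the commit does; model, tokenizer, prompt, and max_length are as in generate_answer():

# Variant of the committed generate() call with sampling switched on.
import torch

# Stay within t5-small's 512-token training length (the commit uses 1024).
inputs = tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True)

with torch.inference_mode():
    outputs = model.generate(
        **inputs,
        max_length=max_length,
        min_length=100,
        num_beams=4,
        length_penalty=1.5,
        do_sample=True,          # without this, temperature is ignored (with a warning)
        temperature=0.7,
        repetition_penalty=1.2,
        early_stopping=True,
    )
response = tokenizer.decode(outputs[0], skip_special_tokens=True)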
requirements.txt
CHANGED
@@ -4,10 +4,7 @@ datasets>=2.17.0
 --extra-index-url https://download.pytorch.org/whl/cpu
 torch>=2.2.0
 accelerate>=0.26.0
-safetensors>=0.4.1
 numpy>=1.24.0
 pandas>=2.2.0
 requests>=2.31.0
-arxiv>=2.1.0
-lancedb>=0.3.3
-tantivy>=0.19.2
+arxiv>=2.1.0
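Net effect on the environment: safetensors, lancedb, and tantivy are dropped, while arxiv is kept, just moved to the end of the file. To rebuild the Space's environment locally, the usual install should suffice (the extra index URL inside requirements.txt already selects the CPU-only torch wheels):

pip install -r requirements.txt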