refactor: use better model
app.py
CHANGED
@@ -20,7 +20,7 @@ logging.basicConfig(level=logging.INFO)
 DATA_DIR = "/data" if os.path.exists("/data") else "."
 DATASET_DIR = os.path.join(DATA_DIR, "rag_dataset")
 DATASET_PATH = os.path.join(DATASET_DIR, "dataset")
-MODEL_PATH = "google/
+MODEL_PATH = "google/flan-t5-small"
 
 # Constants for better maintainability
 MAX_ABSTRACT_LENGTH = 1000
@@ -182,7 +182,7 @@ class ModelHandler:
 
     @st.cache_resource
     def load_model(self):
-        """Load model with
+        """Load FLAN-T5 Small model with optimized settings"""
        if self.model is None:
            try:
                self.tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
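For reference, a minimal standalone sketch of how the new checkpoint could be loaded; the hunk above only shows the tokenizer call, so the AutoModelForSeq2SeqLM line and the helper name load_flan_t5 are assumptions rather than part of this commit.

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

MODEL_PATH = "google/flan-t5-small"

def load_flan_t5(model_path: str = MODEL_PATH):
    # Tokenizer load mirrors the hunk above; the model class is assumed from
    # FLAN-T5 being an encoder-decoder (seq2seq) architecture.
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
    return tokenizer, model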
@@ -199,13 +199,26 @@ class ModelHandler:
         return True
 
     def generate_answer(self, question: str, context: str, max_length: int = 512) -> str:
-        """Generate answer with
+        """Generate answer with FLAN-T5 optimized parameters"""
        if not self.load_model():
            return "Error: Model loading failed. Please try again later."
 
        try:
-            #
-            input_text =
+            # FLAN-T5 responds better to direct instruction prompts
+            input_text = f"""Answer the following question about autism using the provided research context.
+Research Context:
+{context}
+
+Question: {question}
+
+Instructions:
+- Be specific and evidence-based
+- Use clear, accessible language
+- Focus on practical implications
+- Cite research when relevant
+- Be respectful of neurodiversity
+
+Answer:"""
 
            inputs = self.tokenizer(
                input_text,
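As a usage illustration, the instruction-style prompt added above can be exercised on its own; build_prompt is a hypothetical helper name introduced here for illustration, not something this commit adds.

def build_prompt(question: str, context: str) -> str:
    # Same template as the added lines above, factored out for a quick test.
    return f"""Answer the following question about autism using the provided research context.
Research Context:
{context}

Question: {question}

Instructions:
- Be specific and evidence-based
- Use clear, accessible language
- Focus on practical implications
- Cite research when relevant
- Be respectful of neurodiversity

Answer:"""

# Example call with hypothetical inputs:
print(build_prompt("How common are sensory sensitivities?", "Example abstract text..."))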
@@ -219,22 +232,22 @@ class ModelHandler:
            outputs = self.model.generate(
                **inputs,
                max_length=max_length,
-                min_length=
-                num_beams=
-                length_penalty=1.
-                temperature=0.
-                repetition_penalty=1.
+                min_length=100,  # Reduced for FLAN-T5 Small
+                num_beams=3,  # Adjusted for better performance
+                length_penalty=1.0,  # More neutral, for concise answers
+                temperature=0.6,  # More deterministic
+                repetition_penalty=1.2,
                early_stopping=True,
-                no_repeat_ngram_size=
+                no_repeat_ngram_size=2,
                do_sample=True,
-                top_k=
-                top_p=0.
+                top_k=30,
+                top_p=0.92
            )
 
            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            response = TextProcessor.clean_text(response)
 
-            if len(response.strip()) <
+            if len(response.strip()) < 50:  # Adjusted for FLAN-T5's shorter responses
                return self._get_fallback_response()
 
            return self._format_response(response)
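To see the new decoding configuration in one place, here is a self-contained sketch of tokenizing a prompt and generating with the parameters introduced above; the tokenizer keyword arguments (truncation, return_tensors, max_length) are assumptions, since the diff elides that call.

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")

prompt = "Answer the following question about autism using the provided research context. ..."
# Assumed tokenizer settings; the commit does not show them.
inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)

outputs = model.generate(
    **inputs,
    max_length=512,
    min_length=100,
    num_beams=3,
    length_penalty=1.0,
    temperature=0.6,
    repetition_penalty=1.2,
    early_stopping=True,
    no_repeat_ngram_size=2,
    do_sample=True,
    top_k=30,
    top_p=0.92,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))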
@@ -243,30 +256,6 @@ class ModelHandler:
            logging.error(f"Error generating response: {str(e)}")
            return "Error: Could not generate response. Please try again."
 
-    @staticmethod
-    def _create_enhanced_prompt(question: str, context: str) -> str:
-        """Create an enhanced prompt for better response quality"""
-        return f"""Context: {context}
-
-Question: {question}
-
-Instructions:
-1. Provide a clear, evidence-based answer
-2. Include specific findings from the research
-3. Explain practical implications
-4. Use accessible language
-5. Address the question directly
-6. Include relevant examples
-
-Response should be:
-- Accurate and scientific
-- Easy to understand
-- Practical and applicable
-- Respectful of neurodiversity
-- Supported by the provided research
-
-Generate a comprehensive response:"""
-
    @staticmethod
    def _get_fallback_response() -> str:
        """Provide a structured fallback response"""