wakeupmh committed
Commit 17a97cf · 1 Parent(s): 87866cd

refactor: use better model

Files changed (1): app.py (+27 -38)
app.py CHANGED
@@ -20,7 +20,7 @@ logging.basicConfig(level=logging.INFO)
  DATA_DIR = "/data" if os.path.exists("/data") else "."
  DATASET_DIR = os.path.join(DATA_DIR, "rag_dataset")
  DATASET_PATH = os.path.join(DATASET_DIR, "dataset")
- MODEL_PATH = "google/mt5-base"
+ MODEL_PATH = "google/flan-t5-small"

  # Constants for better maintainability
  MAX_ABSTRACT_LENGTH = 1000
@@ -182,7 +182,7 @@ class ModelHandler:

      @st.cache_resource
      def load_model(self):
-         """Load model with improved error handling and resource management"""
+         """Load FLAN-T5 Small model with optimized settings"""
          if self.model is None:
              try:
                  self.tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
@@ -199,13 +199,26 @@ class ModelHandler:
          return True

      def generate_answer(self, question: str, context: str, max_length: int = 512) -> str:
-         """Generate answer with improved prompt engineering and parameters"""
+         """Generate answer with FLAN-T5 optimized parameters"""
          if not self.load_model():
              return "Error: Model loading failed. Please try again later."

          try:
-             # Improved prompt template
-             input_text = self._create_enhanced_prompt(question, context)
+             # FLAN-T5 responds better to direct instruction prompts
+             input_text = f"""Answer the following question about autism using the provided research context.
+             Research Context:
+             {context}
+
+             Question: {question}
+
+             Instructions:
+             - Be specific and evidence-based
+             - Use clear, accessible language
+             - Focus on practical implications
+             - Cite research when relevant
+             - Be respectful of neurodiversity
+
+             Answer:"""

              inputs = self.tokenizer(
                  input_text,
@@ -219,22 +232,22 @@ class ModelHandler:
              outputs = self.model.generate(
                  **inputs,
                  max_length=max_length,
-                 min_length=200,
-                 num_beams=4,
-                 length_penalty=1.5,
-                 temperature=0.7,
-                 repetition_penalty=1.3,
+                 min_length=100,  # Reduced for FLAN-T5 Small
+                 num_beams=3,  # Adjusted for better performance
+                 length_penalty=1.0,  # More neutral for concise answers
+                 temperature=0.6,  # More deterministic
+                 repetition_penalty=1.2,
                  early_stopping=True,
-                 no_repeat_ngram_size=3,
+                 no_repeat_ngram_size=2,
                  do_sample=True,
-                 top_k=40,
-                 top_p=0.95
+                 top_k=30,
+                 top_p=0.92
              )

              response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
              response = TextProcessor.clean_text(response)

-             if len(response.strip()) < 100:
+             if len(response.strip()) < 50:  # Adjusted for FLAN-T5's shorter responses
                  return self._get_fallback_response()

              return self._format_response(response)
@@ -243,30 +256,6 @@ class ModelHandler:
              logging.error(f"Error generating response: {str(e)}")
              return "Error: Could not generate response. Please try again."

-     @staticmethod
-     def _create_enhanced_prompt(question: str, context: str) -> str:
-         """Create an enhanced prompt for better response quality"""
-         return f"""Context: {context}
-
- Question: {question}
-
- Instructions:
- 1. Provide a clear, evidence-based answer
- 2. Include specific findings from the research
- 3. Explain practical implications
- 4. Use accessible language
- 5. Address the question directly
- 6. Include relevant examples
-
- Response should be:
- - Accurate and scientific
- - Easy to understand
- - Practical and applicable
- - Respectful of neurodiversity
- - Supported by the provided research
-
- Generate a comprehensive response:"""
-
      @staticmethod
      def _get_fallback_response() -> str:
          """Provide a structured fallback response"""