wakeupmh committed
Commit 8f85101 · 1 Parent(s): 17a97cf

fix: class

Files changed (1): app.py (+40 -36)
app.py CHANGED
@@ -179,46 +179,51 @@ class ModelHandler:
     def __init__(self):
         self.model = None
         self.tokenizer = None
-
+        self._initialize_model()
+
+    @staticmethod
     @st.cache_resource
-    def load_model(self):
+    def _load_model():
         """Load FLAN-T5 Small model with optimized settings"""
-        if self.model is None:
-            try:
-                self.tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
-                self.model = T5ForConditionalGeneration.from_pretrained(
-                    MODEL_PATH,
-                    device_map={"": "cpu"},
-                    torch_dtype=torch.float32,
-                    low_cpu_mem_usage=True
-                )
-                return True
-            except Exception as e:
-                logging.error(f"Error loading model: {str(e)}")
-                return False
-        return True
+        try:
+            tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
+            model = T5ForConditionalGeneration.from_pretrained(
+                MODEL_PATH,
+                device_map={"": "cpu"},
+                torch_dtype=torch.float32,
+                low_cpu_mem_usage=True
+            )
+            return model, tokenizer
+        except Exception as e:
+            logging.error(f"Error loading model: {str(e)}")
+            return None, None
+
+    def _initialize_model(self):
+        """Initialize model and tokenizer"""
+        self.model, self.tokenizer = self._load_model()
 
     def generate_answer(self, question: str, context: str, max_length: int = 512) -> str:
         """Generate answer with FLAN-T5 optimized parameters"""
-        if not self.load_model():
+        if self.model is None or self.tokenizer is None:
             return "Error: Model loading failed. Please try again later."
 
         try:
             # FLAN-T5 responds better to direct instruction prompts
             input_text = f"""Answer the following question about autism using the provided research context.
-            Research Context:
-            {context}
 
-            Question: {question}
+Research Context:
+{context}
+
+Question: {question}
 
-            Instructions:
-            - Be specific and evidence-based
-            - Use clear, accessible language
-            - Focus on practical implications
-            - Cite research when relevant
-            - Be respectful of neurodiversity
+Instructions:
+- Be specific and evidence-based
+- Use clear, accessible language
+- Focus on practical implications
+- Cite research when relevant
+- Be respectful of neurodiversity
 
-            Answer:"""
+Answer:"""
 
             inputs = self.tokenizer(
                 input_text,
@@ -232,10 +237,10 @@ class ModelHandler:
             outputs = self.model.generate(
                 **inputs,
                 max_length=max_length,
-                min_length=100,  # Reduced for FLAN-T5 Small
-                num_beams=3,  # Adjusted for better performance
-                length_penalty=1.0,  # More neutral, for concise answers
-                temperature=0.6,  # More deterministic
+                min_length=100,
+                num_beams=3,
+                length_penalty=1.0,
+                temperature=0.6,
                 repetition_penalty=1.2,
                 early_stopping=True,
                 no_repeat_ngram_size=2,
@@ -247,7 +252,7 @@ class ModelHandler:
             response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
             response = TextProcessor.clean_text(response)
 
-            if len(response.strip()) < 50:  # Adjusted for FLAN-T5's shorter responses
+            if len(response.strip()) < 50:
                 return self._get_fallback_response()
 
             return self._format_response(response)
@@ -259,10 +264,10 @@ class ModelHandler:
     @staticmethod
     def _get_fallback_response() -> str:
         """Provide a structured fallback response"""
-        return """Based on the available research, I cannot provide a specific answer to your question. However, I can suggest:
+        return """Based on the available research, I cannot provide a specific answer to your question. Please try:
 
-1. Try rephrasing your question to focus on specific aspects of autism
-2. Consider asking about:
+1. Rephrasing your question to be more specific
+2. Asking about:
    - Specific behaviors or characteristics
    - Intervention strategies
    - Research findings
@@ -273,7 +278,6 @@ This will help me provide more accurate, research-based information."""
     @staticmethod
     def _format_response(response: str) -> str:
         """Format the response for better readability"""
-        # Add section headers
         sections = response.split('\n\n')
         formatted_sections = []
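Why the restructuring matters: st.cache_resource builds its cache key from the decorated function's arguments, so decorating the bound method load_model(self) drags self into the key (Streamlit generally cannot hash arbitrary objects and expects an underscore-prefixed parameter such as _self to exclude it). Making the loader a no-argument @staticmethod that returns the (model, tokenizer) pair sidesteps the problem, and __init__ simply unpacks the cached pair. A minimal standalone sketch of the same pattern, assuming "google/flan-t5-small" as a stand-in for the app's MODEL_PATH:

import logging

import streamlit as st
import torch
from transformers import AutoTokenizer, T5ForConditionalGeneration

MODEL_PATH = "google/flan-t5-small"  # stand-in; the real value is defined elsewhere in app.py

@st.cache_resource
def load_model():
    """Runs once per server process; later calls reuse the cached objects."""
    try:
        tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
        model = T5ForConditionalGeneration.from_pretrained(
            MODEL_PATH,
            device_map={"": "cpu"},     # pin all weights to the CPU
            torch_dtype=torch.float32,  # full precision; FLAN-T5 Small fits in RAM
            low_cpu_mem_usage=True,     # avoid materializing a second copy while loading
        )
        return model, tokenizer
    except Exception as e:
        logging.error(f"Error loading model: {e}")
        return None, None

model, tokenizer = load_model()

Returning (None, None) on failure is what lets generate_answer degrade to a plain sentinel check (if self.model is None or self.tokenizer is None) instead of re-invoking the loader on every request.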
 
 
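The decoding settings themselves are unchanged; the commit only strips the Portuguese inline comments. For reference, a standalone invocation with those settings might look like the sketch below, reusing model and tokenizer from the sketch above; the question and context values are hypothetical. One caveat worth knowing: temperature only takes effect when do_sample=True, so under pure beam search as configured here it is effectively inert.

# Hypothetical inputs; in app.py these come from the UI and the retrieval step.
question = "What sensory sensitivities are common in autism?"
context = "Retrieved research passages would go here."

prompt = (
    "Answer the following question about autism using the provided research context.\n\n"
    f"Research Context:\n{context}\n\n"
    f"Question: {question}\n\n"
    "Answer:"
)

inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
outputs = model.generate(
    **inputs,
    max_length=512,
    min_length=100,           # floor so the small model does not stop after one clause
    num_beams=3,              # modest beam width keeps CPU latency tolerable
    length_penalty=1.0,       # neutral: no bias toward longer or shorter beams
    temperature=0.6,          # inert without do_sample=True (see note above)
    repetition_penalty=1.2,   # discourage token-level repetition
    early_stopping=True,      # stop once every beam has finished
    no_repeat_ngram_size=2,   # forbid verbatim 2-gram repeats
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))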