refactor: use better model
app.py CHANGED
@@ -20,7 +20,7 @@ logging.basicConfig(level=logging.INFO)
 DATA_DIR = "/data" if os.path.exists("/data") else "."
 DATASET_DIR = os.path.join(DATA_DIR, "rag_dataset")
 DATASET_PATH = os.path.join(DATASET_DIR, "dataset")
-MODEL_PATH = "google/
+MODEL_PATH = "google/flan-t5-small"
 
 # Constants for better maintainability
 MAX_ABSTRACT_LENGTH = 1000
@@ -182,7 +182,7 @@ class ModelHandler:
 
     @st.cache_resource
     def load_model(self):
-        """Load model with
+        """Load FLAN-T5 Small model with optimized settings"""
         if self.model is None:
             try:
                 self.tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
@@ -199,13 +199,26 @@ class ModelHandler:
             return True
 
    def generate_answer(self, question: str, context: str, max_length: int = 512) -> str:
-        """Generate answer with
+        """Generate answer with FLAN-T5 optimized parameters"""
        if not self.load_model():
            return "Error: Model loading failed. Please try again later."
 
        try:
-            #
-            input_text =
+            # FLAN-T5 responds better to direct instruction prompts
+            input_text = f"""Answer the following question about autism using the provided research context.
+Research Context:
+{context}
+
+Question: {question}
+
+Instructions:
+- Be specific and evidence-based
+- Use clear, accessible language
+- Focus on practical implications
+- Cite research when relevant
+- Be respectful of neurodiversity
+
+Answer:"""
 
            inputs = self.tokenizer(
                input_text,
@@ -219,22 +232,22 @@ class ModelHandler:
            outputs = self.model.generate(
                **inputs,
                max_length=max_length,
-                min_length=
-                num_beams=
-                length_penalty=1.
-                temperature=0.
-                repetition_penalty=1.
+                min_length=100,  # Reduced for FLAN-T5 Small
+                num_beams=3,  # Adjusted for better performance
+                length_penalty=1.0,  # More neutral, for concise answers
+                temperature=0.6,  # More deterministic
+                repetition_penalty=1.2,
                early_stopping=True,
-                no_repeat_ngram_size=
+                no_repeat_ngram_size=2,
                do_sample=True,
-                top_k=
-                top_p=0.
+                top_k=30,
+                top_p=0.92
            )
 
            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            response = TextProcessor.clean_text(response)
 
-            if len(response.strip()) <
+            if len(response.strip()) < 50:  # Adjusted for FLAN-T5's shorter responses
                return self._get_fallback_response()
 
            return self._format_response(response)
@@ -243,30 +256,6 @@ class ModelHandler:
            logging.error(f"Error generating response: {str(e)}")
            return "Error: Could not generate response. Please try again."
 
-    @staticmethod
-    def _create_enhanced_prompt(question: str, context: str) -> str:
-        """Create an enhanced prompt for better response quality"""
-        return f"""Context: {context}
-
-Question: {question}
-
-Instructions:
-1. Provide a clear, evidence-based answer
-2. Include specific findings from the research
-3. Explain practical implications
-4. Use accessible language
-5. Address the question directly
-6. Include relevant examples
-
-Response should be:
-- Accurate and scientific
-- Easy to understand
-- Practical and applicable
-- Respectful of neurodiversity
-- Supported by the provided research
-
-Generate a comprehensive response:"""
-
    @staticmethod
    def _get_fallback_response() -> str:
        """Provide a structured fallback response"""
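For quick experimentation outside the Space, the change can be reproduced in a standalone script. This is a minimal sketch, not the app itself: the prompt template and the generate() arguments are copied from the diff, while the model class (AutoModelForSeq2SeqLM; the hunks never show which class app.py uses), the tokenizer keyword arguments (truncated in the hunk), and the sample question/context strings are assumptions for illustration.

# Standalone sketch of the new generation path; see hedges above.
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

MODEL_PATH = "google/flan-t5-small"

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_PATH)  # assumed model class

question = "How do sensory sensitivities affect daily routines?"  # placeholder
context = "Abstract: ..."  # placeholder; the app passes retrieved abstracts here

# Instruction-style prompt introduced by this commit
input_text = f"""Answer the following question about autism using the provided research context.
Research Context:
{context}

Question: {question}

Instructions:
- Be specific and evidence-based
- Use clear, accessible language
- Focus on practical implications
- Cite research when relevant
- Be respectful of neurodiversity

Answer:"""

# Tokenizer kwargs beyond input_text are assumed (the hunk truncates them)
inputs = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=512)

# Same generation arguments the diff sets
outputs = model.generate(
    **inputs,
    max_length=512,
    min_length=100,
    num_beams=3,
    length_penalty=1.0,
    temperature=0.6,
    repetition_penalty=1.2,
    early_stopping=True,
    no_repeat_ngram_size=2,
    do_sample=True,
    top_k=30,
    top_p=0.92,
)

print(tokenizer.decode(outputs[0], skip_special_tokens=True))

One detail worth noting: do_sample=True combined with num_beams=3 selects beam-sample decoding in transformers, so outputs vary from run to run; dropping do_sample would make the beam search deterministic.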