Spaces:
Sleeping
Sleeping
import logging
import re
from typing import List, Optional

import streamlit as st
import torch
from transformers import AutoTokenizer, T5ForConditionalGeneration

from utils.text_processor import TextProcessor
# Hugging Face Hub checkpoint that ModelHandler loads its tokenizer and model from.
MODEL_PATH: str = "google/flan-t5-small"
class ModelHandler:
    """Answer autism-related questions with a FLAN-T5 model and self-check them.

    The tokenizer and model are loaded from ``MODEL_PATH`` when the handler is
    constructed. If loading fails, both attributes stay ``None`` and the
    generation methods return fixed fallback messages instead of raising.
    """

    # Token limit for the encoded *prompt*. Kept separate from the generation
    # length: validation generates at most 300 tokens, and reusing that number
    # for input truncation would cut off the end of its prompt (the response
    # format and the "YOUR EVALUATION:" cue).
    _MAX_INPUT_TOKENS = 1000
    # Generated answers are re-paragraphed; a paragraph is closed at the first
    # sentence boundary after it grows past this many characters.
    _MAX_PARAGRAPH_CHARS = 200

    # Sentence boundary: whitespace that follows '.', '!' or '?' (the
    # punctuation itself stays attached to the sentence).
    _SENTENCE_BOUNDARY = re.compile(r"(?<=[.!?])\s+")
    # "VALID: true", "valid: [false]", ... -> captures the first word after the colon.
    _VALID_PATTERN = re.compile(r"\bVALID:\s*\[?\s*([A-Za-z]+)", re.IGNORECASE)
    # A "VERDICT:" line -> captures the remainder of that line.
    _VERDICT_PATTERN = re.compile(
        r"^\s*VERDICT:[ \t]*(.*)$", re.IGNORECASE | re.MULTILINE
    )
    # Inline " 1." / " 2." / " 3." list marker (not a decimal such as " 1.5").
    _INLINE_LIST_MARKER = re.compile(r" ([1-3])\.(?!\d)")
    # "1." / "2." / "3." list marker at the start of a line (not a decimal).
    _LINE_LIST_MARKER = re.compile(r"^([ \t]*)[1-3]\.(?!\d)", re.MULTILINE)

    def __init__(self):
        """Initialize the model handler and load the model eagerly."""
        self.model = None
        self.tokenizer = None
        self._initialize_model()

    def _initialize_model(self):
        """Load the model and tokenizer onto this instance."""
        self.model, self.tokenizer = self._load_model()

    @staticmethod
    def _load_model():
        """Load the T5 model and tokenizer from ``MODEL_PATH``.

        Returns:
            ``(model, tokenizer)``, or ``(None, None)`` if loading fails (the
            error is logged with its traceback).
        """
        # This takes no ``self`` yet is invoked as ``self._load_model()``, so it
        # must be a staticmethod; otherwise constructing the handler raises
        # TypeError.
        # NOTE(review): ``streamlit`` is imported by this module but unused
        # here. If a ModelHandler is created on every Streamlit rerun, caching
        # this loader (e.g. with ``st.cache_resource``) would avoid reloading
        # the weights each time -- confirm how callers construct the handler.
        try:
            tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
            model = T5ForConditionalGeneration.from_pretrained(MODEL_PATH)
            return model, tokenizer
        except Exception as e:
            logging.exception("Error loading model: %s", e)
            return None, None

    def generate_answer(self, query: str, context: str) -> str:
        """Generate an answer based on the research papers context.

        Args:
            query: The user's question.
            context: Research findings text to ground the answer in.

        Returns:
            The generated answer with any "Answer:" label removed, regrouped
            into paragraphs separated by blank lines. On an unexpected error a
            fixed apology message is returned instead.
        """
        base_knowledge = """
Autism, or Autism Spectrum Disorder (ASD), is a complex neurodevelopmental condition that affects how a person perceives and interacts with the world. Key aspects include:
1. Social communication and interaction
2. Repetitive behaviors and specific interests
3. Sensory sensitivities
4. Varying levels of support needs
5. Early developmental differences
6. Unique strengths and challenges
The condition exists on a spectrum, meaning each person's experience is unique. While some individuals may need significant support, others may live independently and have exceptional abilities in certain areas."""
        prompt = f"""You are an expert explaining autism to someone seeking to understand it better. Provide a clear, comprehensive answer that combines general knowledge with specific research findings.
QUESTION:
{query}
GENERAL KNOWLEDGE:
{base_knowledge}
RECENT RESEARCH FINDINGS:
{context}
Instructions for your response:
1. Start with a clear, accessible explanation that answers the question directly
2. Use everyday language while maintaining accuracy
3. Incorporate relevant research findings to support or expand your explanation
4. When citing research, use "According to recent research..." or "A study found..."
5. Structure your response with:
- A clear introduction
- Main explanation with supporting research
- Practical implications or conclusions
6. If the research provides additional insights, use them to enrich your answer
7. Acknowledge if certain aspects aren't covered by the available research
FORMAT:
- Use clear paragraphs
- Explain technical terms
- Be conversational but informative
- Include specific examples when helpful
Please provide your comprehensive answer:"""
        try:
            response = self.generate(
                prompt,
                max_length=1000,
                temperature=0.7,
                max_input_length=self._MAX_INPUT_TOKENS,
            )[0]
            # Drop a leading label the model sometimes echoes.
            response = response.replace("Answer:", "").strip()
            paragraphs = self._split_into_paragraphs(
                response, self._MAX_PARAGRAPH_CHARS
            )
            # Double newlines render as separate paragraphs in Markdown.
            return "\n\n".join(paragraphs)
        except Exception as e:
            logging.exception("Error generating answer: %s", e)
            return "I apologize, but I encountered an error while generating the answer. Please try again or rephrase your question."

    @staticmethod
    def _split_into_paragraphs(text: str, max_chars: int = 200) -> List[str]:
        """Group *text* into readable paragraphs.

        A blank line ends the current paragraph, and so does the first
        sentence boundary after the paragraph exceeds *max_chars* characters.
        Sentences keep their own punctuation and are joined with single
        spaces; a '.' is appended only to a paragraph that does not already
        end with '.', '!', '?' or ':'.
        """
        paragraphs: List[str] = []
        current: List[str] = []

        def flush() -> None:
            # Close the paragraph being built, if any.
            if current:
                paragraph = " ".join(current)
                if not paragraph.endswith((".", "!", "?", ":")):
                    paragraph += "."
                paragraphs.append(paragraph)
                current.clear()

        for line in text.split("\n"):
            if not line.strip():
                flush()
                continue
            for sentence in ModelHandler._SENTENCE_BOUNDARY.split(line.strip()):
                current.append(sentence)
                if len(" ".join(current)) > max_chars:
                    flush()
        flush()
        return paragraphs

    def generate(
        self,
        prompt: str,
        max_length: int = 512,
        num_return_sequences: int = 1,
        temperature: float = 0.7,
        max_input_length: Optional[int] = None,
    ) -> List[str]:
        """Generate text using the T5 model (sampling with top-k/top-p).

        Args:
            prompt: Text fed to the encoder.
            max_length: Maximum length of each generated sequence, in tokens.
            num_return_sequences: Number of sampled outputs to return.
            temperature: Sampling temperature.
            max_input_length: Token limit used to truncate the encoded prompt.
                Defaults to *max_length*, which was the previous behaviour.

        Returns:
            The decoded outputs with special tokens removed. If the model is
            not loaded or generation fails, a one-element list containing a
            generic error message is returned instead of raising.
        """
        error_result = ["An error occurred while generating the response."]
        if self.model is None or self.tokenizer is None:
            logging.error("Error generating text: model or tokenizer is not loaded")
            return error_result
        if max_input_length is None:
            max_input_length = max_length
        try:
            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                max_length=max_input_length,
                truncation=True,
                padding=True,
            )
            with torch.no_grad():
                # ``early_stopping`` is omitted: it only applies to beam search,
                # and this call samples with a single beam.
                outputs = self.model.generate(
                    input_ids=inputs["input_ids"],
                    attention_mask=inputs["attention_mask"],
                    max_length=max_length,
                    num_return_sequences=num_return_sequences,
                    temperature=temperature,
                    do_sample=True,
                    top_p=0.95,
                    top_k=50,
                    no_repeat_ngram_size=3,
                )
            return [
                self.tokenizer.decode(output, skip_special_tokens=True)
                for output in outputs
            ]
        except Exception as e:
            logging.exception("Error generating text: %s", e)
            return error_result

    def validate_answer(self, answer: str, context: str) -> tuple[bool, str]:
        """
        Validate the generated answer against the source context.
        Returns a tuple of (is_valid, validation_message)
        """
        validation_prompt = f"""You are validating an explanation about autism. Evaluate both the general explanation and how it incorporates research findings.
ANSWER TO VALIDATE:
{answer}
RESEARCH CONTEXT:
{context}
EVALUATION CRITERIA:
1. Accuracy of General Information:
- Basic autism concepts explained correctly
- Clear and accessible language
- Balanced perspective
2. Research Integration:
- Research findings used appropriately
- No misrepresentation of studies
- Proper balance of general knowledge and research findings
3. Explanation Quality:
- Clear and logical structure
- Technical terms explained
- Helpful examples or illustrations
RESPOND IN THIS FORMAT:
---
VALID: [true/false]
STRENGTHS: [list main strengths]
CONCERNS: [list any issues]
VERDICT: [final assessment]
---
Example Response:
---
VALID: true
STRENGTHS:
- Clear explanation of autism fundamentals
- Research findings well integrated
- Technical terms properly explained
CONCERNS:
- Minor: Could include more practical examples
VERDICT: The answer provides an accurate and well-supported explanation that effectively combines general knowledge with research findings.
---
YOUR EVALUATION:"""
        try:
            validation_result = self.generate(
                validation_prompt,
                max_length=300,
                temperature=0.3,
                # The prompt embeds the answer and context, so it is much
                # longer than the 300-token output budget.
                max_input_length=self._MAX_INPUT_TOKENS,
            )[0]
            return self._parse_validation(validation_result)
        except Exception as e:
            logging.exception("Error during answer validation: %s", e)
            return True, "Technical validation issue, but answer appears sound."

    @staticmethod
    def _parse_validation(validation_result: str) -> tuple[bool, str]:
        """Turn raw validator output into ``(is_valid, verdict)``.

        1. If the text between the first pair of ``---`` delimiters (or the
           whole text, when there is no such pair) has both a ``VALID:`` value
           and a non-empty ``VERDICT:`` line, those are returned.
        2. Otherwise, if a ``VALID:`` value appears anywhere, it is returned
           with a generic verdict.
        3. Otherwise the format is logged as unexpected and the answer is
           accepted with a generic message.

        ``is_valid`` is True only when the word right after ``VALID:`` is
        "true" (case-insensitive), not whenever "true" appears somewhere.
        """
        parts = validation_result.split("---")
        section = parts[1] if len(parts) >= 3 else validation_result

        valid_match = ModelHandler._VALID_PATTERN.search(section)
        verdict_match = ModelHandler._VERDICT_PATTERN.search(section)
        if valid_match and verdict_match and verdict_match.group(1).strip():
            is_valid = valid_match.group(1).lower() == "true"
            return is_valid, verdict_match.group(1).strip()

        # Fallback parsing for malformed responses
        valid_match = ModelHandler._VALID_PATTERN.search(validation_result)
        if valid_match:
            is_valid = valid_match.group(1).lower() == "true"
            return is_valid, "The answer has been reviewed for accuracy and research alignment."

        logging.warning("Unexpected validation format: %s", validation_result)
        return True, "Answer reviewed for accuracy and clarity."

    @staticmethod
    def _get_fallback_response() -> str:
        """Provide a friendly, helpful fallback response.

        Returns a fixed message that suggests ways to rephrase the question;
        it takes no input.
        """
        return """I apologize, but I couldn't find enough specific research to properly answer your question. To help you get better information, you could:
• Ask about specific aspects of autism you're interested in
• Focus on particular topics like:
- Early signs and diagnosis
- Treatment approaches
- Latest research findings
- Support strategies
This will help me provide more detailed, research-backed information that's relevant to your interests."""

    @staticmethod
    def _format_response(response: str) -> str:
        """Format the response to be more readable and engaging.

        * Inline list markers " 1." / " 2." / " 3." start a new line, with a
          blank line before "1." so each list becomes its own paragraph.
        * In paragraphs that begin with "1.", "2." or "3.", line-leading
          markers are turned into "•" bullets.
        * Paragraphs mentioning "According to" or "Research" are wrapped in
          ``*...*`` (Markdown italics) so citations stand out.

        Only markers not followed by a digit are touched, so decimals such as
        "1.5" are preserved.
        """
        def break_before_marker(match):
            # " 1." -> "\n\n1."; " 2." / " 3." -> "\n2." / "\n3."
            digit = match.group(1)
            return ("\n\n" if digit == "1" else "\n") + digit + "."

        response = ModelHandler._INLINE_LIST_MARKER.sub(break_before_marker, response)

        formatted_paragraphs = []
        for paragraph in response.split("\n\n"):
            # Bullet before italicising: the leading '*' would otherwise hide
            # the list marker from this check.
            if paragraph.strip().startswith(("1.", "2.", "3.")):
                paragraph = ModelHandler._LINE_LIST_MARKER.sub(r"\1•", paragraph)
            if "According to" in paragraph or "Research" in paragraph:
                paragraph = f"*{paragraph}*"
            formatted_paragraphs.append(paragraph)
        return "\n\n".join(formatted_paragraphs)