import logging
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import streamlit as st
from agno.agent import Agent
from agno.tools.arxiv import ArxivTools
from agno.tools.pubmed import PubmedTools
from agno.models.base import Model
from tenacity import retry, stop_after_attempt, wait_exponential
import time
import datetime

MODEL_PATH = "google/flan-t5-small"


# Simple Response class to wrap the model output
class Response:
    def __init__(self, content):
        # Ensure content is a string and not empty
        if content is None:
            content = ""
        if not isinstance(content, str):
            content = str(content)
        # Store the content
        self.content = content
        # Add tool_calls attribute with default empty list
        self.tool_calls = []
        # Add other attributes that might be needed
        self.audio = None
        self.images = []
        self.citations = []
        self.metadata = {}
        self.finish_reason = "stop"
        self.usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
        # Add timestamp attributes
        current_time = time.time()
        self.created_at = int(current_time)  # Convert to integer
        self.created = int(current_time)
        self.timestamp = datetime.datetime.now().isoformat()
        # Add model info attributes
        self.id = "local-model-response"
        self.model = "local-huggingface"
        self.object = "chat.completion"
        self.choices = [{"index": 0, "message": {"role": "assistant", "content": content}, "finish_reason": "stop"}]
        # Add additional attributes that might be needed
        self.system_fingerprint = ""
        self.is_truncated = False
        self.role = "assistant"

    def __str__(self):
        return self.content if self.content else ""

    def __repr__(self):
        return f"Response(content='{self.content[:50]}{'...' if len(self.content) > 50 else ''}')"


# Custom Model wrapper for local Hugging Face models
class LocalHuggingFaceModel(Model):
    def __init__(self, model, tokenizer, max_length=512):
        super().__init__(id="local-huggingface")
        self.model = model
        self.tokenizer = tokenizer
        self.max_length = max_length

    async def ainvoke(self, prompt: str, **kwargs) -> Response:
        """Async invoke method (delegates to the synchronous invoke)"""
        return self.invoke(prompt=prompt, **kwargs)

    async def ainvoke_stream(self, prompt: str, **kwargs):
        """Async streaming invoke method"""
        result = self.invoke(prompt=prompt, **kwargs)
        yield result

    def invoke(self, prompt: str, **kwargs) -> Response:
        """Synchronous invoke method"""
        try:
            logging.info(f"Invoking model with prompt: {prompt[:100] if prompt else 'None'}...")
            # Check if prompt is None or empty
            if prompt is None:
                logging.warning("None prompt provided to invoke method")
                return Response("No input provided. Please provide a valid prompt.")
            if not prompt.strip():
                logging.warning("Empty prompt provided to invoke method")
                return Response("No input provided. Please provide a non-empty prompt.")
            inputs = self.tokenizer(prompt, return_tensors="pt", padding=True)
            # Configure generation parameters
            generation_config = {
                "max_length": self.max_length,
                "num_return_sequences": 1,
                "do_sample": kwargs.get("do_sample", False),
                "temperature": kwargs.get("temperature", 1.0),
                "top_p": kwargs.get("top_p", 1.0),
            }
            # Generate the answer
            outputs = self.model.generate(**inputs, **generation_config)
            decoded_output = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            # Check if output is empty
            if not decoded_output or not decoded_output.strip():
                logging.warning("Model generated empty output")
                return Response("The model did not generate any output. Please try with a different prompt.")
            logging.info(f"Model generated output: {decoded_output[:100]}...")
            return Response(decoded_output)
        except Exception as e:
            logging.error(f"Error in local model generation: {str(e)}")
            if hasattr(e, 'args') and len(e.args) > 0:
                error_message = e.args[0]
            else:
                error_message = str(e)
            return Response(f"Error during generation: {error_message}")

    def invoke_stream(self, prompt: str, **kwargs):
        """Synchronous streaming invoke method"""
        result = self.invoke(prompt=prompt, **kwargs)
        yield result

    def parse_provider_response(self, response: str) -> str:
        """Parse the provider response"""
        return response

    def parse_provider_response_delta(self, delta: str) -> str:
        """Parse the provider response delta for streaming"""
        return delta

    async def aresponse(self, prompt=None, **kwargs):
        """Async response method - required abstract method"""
        if prompt is None:
            prompt = kwargs.get('input', '')
        content = await self.ainvoke(prompt=prompt, **kwargs)
        return Response(content)

    async def aresponse_stream(self, prompt=None, **kwargs):
        """Async streaming response method - required abstract method"""
        if prompt is None:
            prompt = kwargs.get('input', '')
        async for chunk in self.ainvoke_stream(prompt=prompt, **kwargs):
            yield Response(chunk)

    def response(self, prompt=None, **kwargs):
        """Synchronous response method - required abstract method"""
        if prompt is None:
            prompt = kwargs.get('input', '')
        content = self.invoke(prompt=prompt, **kwargs)
        return Response(content)

    def response_stream(self, prompt=None, **kwargs):
        """Synchronous streaming response method - required abstract method"""
        if prompt is None:
            prompt = kwargs.get('input', '')
        for chunk in self.invoke_stream(prompt=prompt, **kwargs):
            yield Response(chunk)

    def generate(self, prompt: str, **kwargs):
        try:
            inputs = self.tokenizer(prompt, return_tensors="pt", padding=True)
            # Configure generation parameters
            generation_config = {
                "max_length": self.max_length,
                "num_return_sequences": 1,
                "do_sample": kwargs.get("do_sample", False),
                "temperature": kwargs.get("temperature", 1.0),
                "top_p": kwargs.get("top_p", 1.0),
            }
            # Generate the answer
            outputs = self.model.generate(**inputs, **generation_config)
            decoded_output = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            return decoded_output
        except Exception as e:
            logging.error(f"Error in local model generation: {str(e)}")
            if hasattr(e, 'args') and len(e.args) > 0:
                error_message = e.args[0]
            else:
                error_message = str(e)
            return f"Error during generation: {error_message}"


class DummyModel(Model):
    def __init__(self):
        super().__init__(id="dummy-model")

    async def ainvoke(self, prompt: str, **kwargs) -> Response:
        """Async invoke method (delegates to the synchronous invoke)"""
        return self.invoke(prompt=prompt, **kwargs)

    async def ainvoke_stream(self, prompt: str, **kwargs):
        """Async streaming invoke method"""
        result = self.invoke(prompt=prompt, **kwargs)
        yield result

    def invoke(self, prompt: str, **kwargs) -> Response:
        """Synchronous invoke method"""
        return Response("Sorry, the model is not available. Please try again later.")

    def invoke_stream(self, prompt: str, **kwargs):
        """Synchronous streaming invoke method"""
        result = self.invoke(prompt=prompt, **kwargs)
        yield result

    def parse_provider_response(self, response: str) -> str:
        """Parse the provider response"""
        return response

    def parse_provider_response_delta(self, delta: str) -> str:
        """Parse the provider response delta for streaming"""
        return delta

    async def aresponse(self, prompt=None, **kwargs):
        """Async response method - required abstract method"""
        if prompt is None:
            prompt = kwargs.get('input', '')
        content = await self.ainvoke(prompt=prompt, **kwargs)
        return Response(content)

    async def aresponse_stream(self, prompt=None, **kwargs):
        """Async streaming response method - required abstract method"""
        if prompt is None:
            prompt = kwargs.get('input', '')
        async for chunk in self.ainvoke_stream(prompt=prompt, **kwargs):
            yield Response(chunk)

    def response(self, prompt=None, **kwargs):
        """Synchronous response method - required abstract method"""
        if prompt is None:
            prompt = kwargs.get('input', '')
        content = self.invoke(prompt=prompt, **kwargs)
        return Response(content)

    def response_stream(self, prompt=None, **kwargs):
        """Synchronous streaming response method - required abstract method"""
        if prompt is None:
            prompt = kwargs.get('input', '')
        for chunk in self.invoke_stream(prompt=prompt, **kwargs):
            yield Response(chunk)


class ModelHandler:
    def __init__(self):
        """Initialize the model handler"""
        self.model = None
        self.tokenizer = None
        self.translator = None
        self.researcher = None
        self.summarizer = None
        self.presenter = None
        self._initialize_model()

    def _initialize_model(self):
        """Initialize model and tokenizer"""
        self.model, self.tokenizer = self._load_model()
        # Using the local model as fallback
        base_model = self._initialize_local_model()
        self.translator = Agent(
            name="Translator",
            role="You will translate the query to English",
            model=base_model,
            goal="Translate to English",
            instructions=[
                "Translate the query to English"
            ]
        )
        self.researcher = Agent(
            name="Researcher",
            role="You are a research scholar who specializes in autism research.",
            model=base_model,
            tools=[ArxivTools(), PubmedTools()],
            instructions=[
                "You need to understand the context of the question to provide the best answer based on your tools.",
                "Be precise and provide just enough information to be useful.",
                "You must cite the sources used in your answer.",
                "You must create an accessible summary.",
                "The content must be for people without autism knowledge.",
                "Focus on the main findings of the paper, taking the question into consideration.",
                "The answer must be brief."
            ],
            show_tool_calls=True,
        )
        self.summarizer = Agent(
            name="Summarizer",
            role="You are a specialist in summarizing research papers for people without autism knowledge.",
            model=base_model,
            instructions=[
                "You must provide just enough information to be useful.",
                "You must cite the sources used in your answer.",
                "You must be clear and concise.",
                "You must create an accessible summary.",
                "The content must be for people without autism knowledge.",
                "Focus on the main findings of the paper, taking the question into consideration.",
                "The answer must be brief.",
                "Remove everything related to the run itself, like 'Running: transfer_'; just use plain text.",
                "You must use the language provided by the user to present the results.",
                "Add references to the sources used in the answer.",
                "Add emojis to make the presentation more interactive.",
                "Translate the answer to Portuguese."
            ],
            show_tool_calls=True,
            markdown=True,
            add_references=True,
        )
        self.presenter = Agent(
            name="Presenter",
            role="You are a professional researcher who presents the results of the research.",
            model=base_model,
            instructions=[
                "You are multilingual.",
                "You must present the results in a clear and concise manner.",
                "Clean up the presentation to make it more readable.",
                "Remove unnecessary information.",
                "Remove everything related to the run itself, like 'Running: transfer_'; just use plain text.",
                "You must use the language provided by the user to present the results.",
                "Add references to the sources used in the answer.",
                "Add emojis to make the presentation more interactive.",
                "Translate the answer to Portuguese."
            ],
            add_references=True,
        )

    def _format_prompt(self, role, instructions, query):
        """Format the prompt for the model"""
        # Validate inputs
        if not role or not role.strip():
            role = "Assistant"
            logging.warning("Empty role provided to _format_prompt, using default: 'Assistant'")
        if not instructions or not instructions.strip():
            instructions = "Please process the following input."
            logging.warning("Empty instructions provided to _format_prompt, using default instructions")
        if not query or not query.strip():
            query = "No input provided."
            logging.warning("Empty query provided to _format_prompt, using placeholder text")
        # Format the prompt
        formatted_prompt = f"""Task: {role}
Instructions:
{instructions}
Input: {query}
Output:"""
        # Ensure the prompt is not empty
        if not formatted_prompt or not formatted_prompt.strip():
            logging.error("Generated an empty prompt despite validation")
            formatted_prompt = "Please provide a response."
        return formatted_prompt
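
    # Illustrative example (hypothetical inputs) of the prompt this helper produces:
    #   _format_prompt("Translate the following text to English",
    #                  "Provide a direct English translation of the input text.",
    #                  "O que é autismo?")
    # returns the multi-line string:
    #   Task: Translate the following text to English
    #   Instructions:
    #   Provide a direct English translation of the input text.
    #   Input: O que é autismo?
    #   Output: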

    def _load_model(self):
        """Load the model and tokenizer with retry logic"""
        # Retry transient failures using the imported tenacity helpers
        # (assumption: up to 3 attempts with exponential backoff).
        @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=2, max=10))
        def load_with_retry(model_path):
            try:
                logging.info(f"Attempting to load model from {model_path}")
                tokenizer = AutoTokenizer.from_pretrained(model_path, cache_dir="./model_cache")
                model = AutoModelForSeq2SeqLM.from_pretrained(
                    model_path,
                    device_map="cpu",
                    low_cpu_mem_usage=True,
                    cache_dir="./model_cache"
                )
                logging.info(f"Successfully loaded model from {model_path}")
                return model, tokenizer
            except Exception as e:
                logging.error(f"Error loading model from {model_path}: {str(e)}")
                raise e
        # Try the primary model first
        try:
            return load_with_retry(MODEL_PATH)
        except Exception as primary_error:
            logging.error(f"Failed to load primary model ({MODEL_PATH}): {str(primary_error)}")
            # Try fallback models
            fallback_models = [
                "google/flan-t5-base",
                "google/flan-t5-small",
                "facebook/bart-base",
                "t5-small"
            ]
            for fallback_model in fallback_models:
                if fallback_model != MODEL_PATH:  # Skip if it's the same as the primary model
                    try:
                        logging.info(f"Trying fallback model: {fallback_model}")
                        return load_with_retry(fallback_model)
                    except Exception as fallback_error:
                        logging.error(f"Failed to load fallback model ({fallback_model}): {str(fallback_error)}")
            # If all models fail, try a final tiny model
            try:
                logging.info("Trying final fallback to t5-small")
                return load_with_retry("t5-small")
            except Exception as final_error:
                logging.error(f"All model loading attempts failed. Final error: {str(final_error)}")
                st.error("Failed to load any model. Please check your internet connection and try again.")
                return None, None

    def _initialize_local_model(self):
        """Initialize local model as fallback"""
        if self.model is None or self.tokenizer is None:
            self.model, self.tokenizer = self._load_model()
        if self.model is None or self.tokenizer is None:
            # Create a dummy model that returns a helpful message
            logging.error("Failed to load any model. Creating a dummy model.")
            return DummyModel()
        # Create a LocalHuggingFaceModel instance compatible with Agno
        return LocalHuggingFaceModel(self.model, self.tokenizer, max_length=512)

    def generate_answer(self, query: str) -> str:
        try:
            logging.info(f"Generating answer for query: {query}")
            # Validate input query
            if not query or not query.strip():
                logging.error("Empty query provided")
                return "Error: Please provide a non-empty query"
            # Check whether any agent fell back to the dummy model
            if isinstance(self.translator.model, DummyModel) or isinstance(self.researcher.model, DummyModel) or \
               isinstance(self.summarizer.model, DummyModel) or isinstance(self.presenter.model, DummyModel):
                logging.error("One or more models are not available")
                return """
# 🚨 Serviço Temporariamente Indisponível 🚨

Desculpe, estamos enfrentando problemas de conexão com nossos serviços de modelo de linguagem.

## Possíveis causas:
- Problemas de conexão com a internet
- Servidores do Hugging Face podem estar sobrecarregados ou temporariamente indisponíveis
- Limitações de recursos do sistema

## O que você pode fazer:
- Tente novamente mais tarde
- Verifique sua conexão com a internet
- Entre em contato com o suporte se o problema persistir

Agradecemos sua compreensão!
"""
            # Format translation prompt
            translation_prompt = self._format_prompt(
                role="Translate the following text to English",
                instructions="Provide a direct English translation of the input text.",
                query=query
            )
            logging.info(f"Translation prompt: {translation_prompt}")
            # Validate translation prompt
            if not translation_prompt or not translation_prompt.strip():
                logging.error("Empty translation prompt generated")
                return "Error: Unable to generate translation prompt"
            # Get English translation
            translation = self.translator.run(prompt=translation_prompt, stream=False)
            logging.info(f"Translation result type: {type(translation)}")
            logging.info(f"Translation result: {translation}")
            if not translation:
                logging.error("Translation failed")
                return "Error: Unable to translate the query"
            if hasattr(translation, 'content'):
                translation_content = translation.content
                logging.info(f"Translation content: {translation_content}")
            else:
                translation_content = str(translation)
                logging.info(f"Translation as string: {translation_content}")
            # Validate translation content
            if not translation_content or not translation_content.strip():
                logging.error("Empty translation content")
                return "Error: Empty translation result"
            # Format research prompt
            research_prompt = self._format_prompt(
                role="Research Assistant",
                instructions="Provide a clear and concise answer based on scientific sources.",
                query=translation_content
            )
            logging.info(f"Research prompt: {research_prompt}")
            # Validate research prompt
            if not research_prompt or not research_prompt.strip():
                logging.error("Empty research prompt generated")
                return "Error: Unable to generate research prompt"
            # Get research results
            research_results = self.researcher.run(prompt=research_prompt, stream=False)
            logging.info(f"Research results type: {type(research_results)}")
            logging.info(f"Research results: {research_results}")
            if not research_results:
                logging.error("Research failed")
                return "Error: Unable to perform research"
            if hasattr(research_results, 'content'):
                research_content = research_results.content
                logging.info(f"Research content: {research_content}")
            else:
                research_content = str(research_results)
                logging.info(f"Research as string: {research_content}")
            # Validate research content
            if not research_content or not research_content.strip():
                logging.error("Empty research content")
                return "Error: Empty research result"
logging.info(f"Research results: {research_results}") | |
# Format summary prompt | |
summary_prompt = self._format_prompt( | |
role="Summary Assistant", | |
instructions="Provide a clear and concise summary of the research results.", | |
query=research_content | |
) | |
logging.info(f"Summary prompt: {summary_prompt}") | |
# Validate summary prompt | |
if not summary_prompt or not summary_prompt.strip(): | |
logging.error("Empty summary prompt generated") | |
return "Error: Unable to generate summary prompt" | |
# Get summary | |
summary = self.summarizer.run(prompt=summary_prompt, stream=False) | |
logging.info(f"Summary type: {type(summary)}") | |
logging.info(f"Summary: {summary}") | |
if not summary: | |
logging.error("Summary failed") | |
return "Error: Unable to generate summary" | |
if hasattr(summary, 'content'): | |
summary_content = summary.content | |
logging.info(f"Summary content: {summary_content}") | |
else: | |
summary_content = str(summary) | |
logging.info(f"Summary as string: {summary_content}") | |
# Validate summary content | |
if not summary_content or not summary_content.strip(): | |
logging.error("Empty summary content") | |
return "Error: Empty summary result" | |
logging.info(f"Summary: {summary}") | |
# Format presentation prompt | |
presentation_prompt = self._format_prompt( | |
role="Presentation Assistant", | |
instructions="Provide a clear and concise presentation of the research results.", | |
query=summary_content | |
) | |
logging.info(f"Presentation prompt: {presentation_prompt}") | |
# Validate presentation prompt | |
if not presentation_prompt or not presentation_prompt.strip(): | |
logging.error("Empty presentation prompt generated") | |
return "Error: Unable to generate presentation prompt" | |
# Get presentation | |
presentation = self.presenter.run(prompt=presentation_prompt, stream=False) | |
logging.info(f"Presentation type: {type(presentation)}") | |
logging.info(f"Presentation: {presentation}") | |
if not presentation: | |
logging.error("Presentation failed") | |
return "Error: Unable to generate presentation" | |
if hasattr(presentation, 'content'): | |
presentation_content = presentation.content | |
logging.info(f"Presentation content: {presentation_content}") | |
# Check if content is empty or just whitespace | |
if not presentation_content.strip(): | |
logging.error("Presentation content is empty or whitespace") | |
return "Error: Empty presentation content" | |
return presentation_content | |
else: | |
presentation_str = str(presentation) | |
logging.info(f"Presentation as string: {presentation_str}") | |
# Check if content is empty or just whitespace | |
if not presentation_str.strip(): | |
logging.error("Presentation string is empty or whitespace") | |
return "Error: Empty presentation string" | |
return presentation_str | |
except Exception as e: | |
logging.error(f"Error generating answer: {str(e)}") | |
if hasattr(e, 'args') and len(e.args) > 0: | |
error_message = e.args[0] | |
else: | |
error_message = str(e) | |
return f"Error: {error_message}" |