# Spaces: Sleeping
import json
import logging
import os
import random
import re
from datetime import datetime
from typing import Any, Dict, Tuple

import numpy as np
import torch
from transformers import BertForSequenceClassification, BertTokenizer

# import pyttsx3 # Comentado para no usar TTS en el servidor
class MentalHealthChatbot:
    """Spanish-language support chatbot driven by a fine-tuned BERT emotion classifier.

    For every user message the bot normalises the text, checks it against a
    keyword list for emergencies, otherwise classifies it into one of
    ``EMOTION_LABELS``, and answers with a random canned reply for that label
    loaded from ``models/responses.json``. The most recent ``HISTORY_LIMIT``
    exchanges are kept in memory and snapshotted to JSON under ``/tmp``.
    """

    # Indexed by the class id returned by the model's argmax, so the order
    # must match the label ids used when the model was fine-tuned.
    EMOTION_LABELS = (
        'FELICIDAD', 'NEUTRAL', 'DEPRESIÓN', 'ANSIEDAD', 'ESTRÉS',
        'EMERGENCIA', 'CONFUSIÓN', 'IRA', 'MIEDO', 'SORPRESA', 'DISGUSTO',
    )

    # Label returned when prediction fails; its replies are also the fallback
    # when the predicted label has no replies.
    FALLBACK_EMOTION = 'CONFUSIÓN'

    # Reply used when neither the predicted label nor FALLBACK_EMOTION has replies.
    DEFAULT_RESPONSE = "Lo siento, no he entendido tu mensaje."

    # NOTE(review): matched as plain substrings of the lower-cased message, so
    # e.g. "necesito ayuda con mis tareas" or "dolor de cabeza" is flagged as an
    # emergency too -- confirm this over-triggering is intended.
    EMERGENCY_KEYWORDS = (
        'suicidar', 'morir', 'muerte', 'matar', 'dolor',
        'ayuda', 'emergencia', 'crisis', 'grave',
    )

    # Maximum number of exchanges kept in conversation_history.
    HISTORY_LIMIT = 10

    # Matches every character that is neither a word character nor whitespace.
    _PUNCTUATION_RE = re.compile(r'[^\w\s]')

    def __init__(self, model_path: str = 'models/bert_emotion_model'):
        """Load the fine-tuned BERT model, its tokenizer and the canned replies.

        Args:
            model_path: Directory containing the fine-tuned model and tokenizer.

        Raises:
            Any exception raised while loading the tokenizer, the model or
            ``responses.json`` is logged and re-raised unchanged.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Logging goes to a file under /tmp.
        self.logger = logging.getLogger(__name__)
        self.logger.setLevel(logging.INFO)
        # Build the FileHandler only when it is actually attached: constructing it
        # opens /tmp/chatbot.log immediately, so creating it unconditionally leaked
        # one open file per instance once the logger was already configured.
        if not self.logger.handlers:
            handler = logging.FileHandler('/tmp/chatbot.log')
            handler.setFormatter(
                logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
            )
            self.logger.addHandler(handler)

        # Writable directories under /tmp, to avoid read-only permissions in /app.
        self.conversations_dir = "/tmp/conversations"
        self.audio_dir = "/tmp/audio"  # only referenced by the disabled generate_audio code
        self.conversation_history = []

        try:
            self.logger.info("Cargando el tokenizador y el modelo BERT fine-tuned...")
            os.makedirs(self.conversations_dir, exist_ok=True)
            self.tokenizer = BertTokenizer.from_pretrained(model_path)
            self.model = BertForSequenceClassification.from_pretrained(model_path).to(self.device)
            # Inference only: make sure dropout is disabled regardless of load defaults.
            self.model.eval()
            self.load_responses()
            self.logger.info("Chatbot inicializado correctamente.")
        except Exception as e:
            self.logger.exception("Error al cargar el modelo: %s", e)
            raise

    def load_responses(self):
        """Load the canned replies from ``models/responses.json`` into ``self.responses``.

        generate_response looks replies up by emotion label and picks one at
        random from the resulting sequence, so the file is presumably a mapping
        of label -> list of reply strings.

        Raises:
            FileNotFoundError: if ``models/responses.json`` does not exist.
            json.JSONDecodeError: if the file is not valid JSON.
        """
        try:
            with open('models/responses.json', 'r', encoding='utf-8') as f:
                self.responses = json.load(f)
        except FileNotFoundError:
            self.logger.error(
                "Archivo 'responses.json' no encontrado. Asegúrate de que el archivo "
                "existe en la ruta especificada."
            )
            raise
        except json.JSONDecodeError as e:
            self.logger.error("Error al decodificar 'responses.json': %s", e)
            raise
        self.logger.info("Respuestas cargadas desde 'responses.json'.")

    def preprocess_text(self, text: str) -> str:
        """Lower-case *text*, remove punctuation/symbols and trim surrounding whitespace.

        Word matching in ``re`` is Unicode-aware for str patterns, so accented
        letters and 'ñ' are kept. If processing fails (e.g. *text* is not a
        string) the error is logged and *text* is returned as it was.

        NOTE(review): this normalised text is what the BERT model receives;
        confirm it matches the preprocessing used during fine-tuning.
        """
        try:
            text = text.lower()
            text = self._PUNCTUATION_RE.sub('', text)
            return text.strip()
        except Exception as e:
            self.logger.exception("Error al preprocesar el texto: %s", e)
            return text

    def detect_emergency(self, text: str) -> bool:
        """Return True if *text* contains any of ``EMERGENCY_KEYWORDS`` (case-insensitive).

        Failures are logged and treated as "no emergency".
        """
        try:
            lowered = text.lower()
            return any(keyword in lowered for keyword in self.EMERGENCY_KEYWORDS)
        except Exception as e:
            self.logger.exception("Error al detectar emergencia: %s", e)
            return False

    def get_emotion_prediction(self, text: str) -> Tuple[str, float]:
        """Classify *text* with the fine-tuned model.

        The input is truncated/padded to 128 tokens.

        Returns:
            ``(label, probability)``: the argmax label from ``EMOTION_LABELS`` and
            its softmax probability. On any failure the error is logged and
            ``(FALLBACK_EMOTION, 0.0)`` is returned.
        """
        try:
            inputs = self.tokenizer.encode_plus(
                text,
                add_special_tokens=True,
                max_length=128,
                padding='max_length',
                truncation=True,
                return_tensors='pt',
            ).to(self.device)
            with torch.no_grad():
                probs = torch.softmax(self.model(**inputs).logits, dim=1)
            predicted_class = torch.argmax(probs, dim=1).item()
            confidence = probs[0][predicted_class].item()
            # IndexError here (model has more classes than labels) is caught below.
            emotion = self.EMOTION_LABELS[predicted_class]
        except Exception as e:
            self.logger.exception("Error en la predicción de emoción: %s", e)
            return self.FALLBACK_EMOTION, 0.0
        self.logger.info("Emoción predicha: %s con confianza %.2f", emotion, confidence)
        return emotion, confidence

    def _pick_response(self, emotion: str) -> str:
        """Return a random canned reply for *emotion*, falling back to FALLBACK_EMOTION, then DEFAULT_RESPONSE."""
        # A label that is missing *or* maps to an empty list falls through; an empty
        # list previously made np.random.choice raise and produced the error reply.
        for label in (emotion, self.FALLBACK_EMOTION):
            candidates = self.responses.get(label)
            if candidates:
                return random.choice(candidates)
        return self.DEFAULT_RESPONSE

    def generate_response(self, user_input: str) -> Dict[str, Any]:
        """Build the bot's reply to *user_input* (no server-side audio is produced).

        Emergency keywords short-circuit the model with label 'EMERGENCIA' and
        confidence 1.0. The exchange is appended to the history and a snapshot
        of the history is written to disk.

        Returns:
            A dict with keys 'text', 'emotion', 'confidence' and 'timestamp'
            (ISO-8601). On an unexpected error, an apology text with emotion
            'ERROR' and confidence 0.0 is returned instead of raising.
        """
        try:
            processed_text = self.preprocess_text(user_input)
            self.logger.info("Texto procesado: %s", processed_text)

            if self.detect_emergency(processed_text):
                emotion, confidence = 'EMERGENCIA', 1.0
                self.logger.info("Emergencia detectada en el mensaje del usuario.")
            else:
                emotion, confidence = self.get_emotion_prediction(processed_text)

            response = self._pick_response(emotion)
            self.logger.info("Respuesta seleccionada: %s", response)

            # Audio generation (generate_audio) is disabled on the server, so no
            # 'audio_path' key is returned.
            self.update_conversation_history(user_input, response, emotion)
            self.save_conversation_history()
            return {
                'text': response,
                'emotion': emotion,
                'confidence': confidence,
                'timestamp': datetime.now().isoformat(),
            }
        except Exception as e:
            self.logger.exception("Error al generar respuesta: %s", e)
            return {
                'text': "Lo siento, ha ocurrido un error. ¿Podrías intentarlo de nuevo?",
                'emotion': 'ERROR',
                'confidence': 0.0,
                'timestamp': datetime.now().isoformat(),
            }

    # Server-side audio generation is disabled (TTS is not used on the server).
    # The former implementation is kept below as an inert string literal.
    """
    def generate_audio(self, text: str) -> str:
        # Generates the audio on the server (DISABLED).
        try:
            filename = f"response_{datetime.now().strftime('%Y%m%d_%H%M%S_%f')}.mp3"
            file_path = os.path.join(self.audio_dir, filename)
            os.makedirs(os.path.dirname(file_path), exist_ok=True)
            engine = pyttsx3.init()
            voices = engine.getProperty('voices')
            for voice in voices:
                if 'Spanish' in voice.name or 'Español' in voice.name:
                    engine.setProperty('voice', voice.id)
                    break
            else:
                self.logger.warning("No se encontró una voz en español. Usando la voz predeterminada.")
            rate = engine.getProperty('rate')
            engine.setProperty('rate', rate - 50)
            engine.save_to_file(text, file_path)
            engine.runAndWait()
            self.logger.info(f"Audio generado y guardado en {file_path}")
            return file_path
        except Exception as e:
            self.logger.error(f"Error al generar audio: {str(e)}")
            return None
    """

    def update_conversation_history(self, user_input: str, response: str, emotion: str):
        """Append one exchange to the in-memory history, keeping only the last HISTORY_LIMIT.

        Failures are logged, not raised.
        """
        try:
            self.conversation_history.append({
                'user_input': user_input,
                'response': response,
                'emotion': emotion,
                'timestamp': datetime.now().isoformat(),
            })
            # NOTE(review): if one instance serves several users, their exchanges
            # share this single history -- confirm that is acceptable.
            overflow = len(self.conversation_history) - self.HISTORY_LIMIT
            if overflow > 0:
                del self.conversation_history[:overflow]
            self.logger.info("Historial de conversación actualizado.")
        except Exception as e:
            self.logger.exception("Error al actualizar el historial de conversación: %s", e)

    def save_conversation_history(self):
        """Write the in-memory history to a new timestamped JSON file in conversations_dir.

        Every call creates a new snapshot file. Microseconds are part of the name
        so two replies within the same second no longer overwrite each other's
        snapshot. Best effort: failures are logged, not raised.

        NOTE(review): snapshots are never cleaned up, so /tmp usage grows with
        every reply -- confirm something prunes this directory.
        """
        try:
            filename = os.path.join(
                self.conversations_dir,
                f"chat_{datetime.now().strftime('%Y%m%d_%H%M%S_%f')}.json",
            )
            # /tmp may be wiped while the process runs; recreate the folder if needed.
            os.makedirs(self.conversations_dir, exist_ok=True)
            with open(filename, 'w', encoding='utf-8') as f:
                json.dump(self.conversation_history, f, ensure_ascii=False, indent=2)
            self.logger.info("Historial de conversación guardado en %s", filename)
        except Exception as e:
            self.logger.exception("Error al guardar el historial: %s", e)