# BioGPT Medical Chatbot with Gradio Interface - HUGGING FACE SPACES VERSION
import gradio as gr
import torch
import warnings
import numpy as np
import os
import re
import time
from datetime import datetime
from typing import List, Dict, Optional, Tuple
import json
# Install required packages if not already installed
try:
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from sentence_transformers import SentenceTransformer
import faiss
except ImportError:
print("Installing required packages...")
import subprocess
import sys
packages = [
"transformers>=4.21.0",
"torch>=1.12.0",
"sentence-transformers",
"faiss-cpu",
"accelerate",
"bitsandbytes",
"datasets",
"numpy",
"sacremoses",
"scipy"
]
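    # sacremoses is listed because BioGPT's tokenizer depends on it for
    # Moses-style tokenization and detokenization.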
for package in packages:
try:
subprocess.check_call([sys.executable, "-m", "pip", "install", package])
except Exception as e:
print(f"Failed to install {package}: {e}")
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from sentence_transformers import SentenceTransformer
import faiss
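# Note: on Hugging Face Spaces, dependencies are normally pinned in
# requirements.txt and installed at build time; the runtime pip install
# above is a best-effort fallback (e.g. for local runs).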
# Suppress warnings
warnings.filterwarnings('ignore')
class GradioBioGPTChatbot:
def __init__(self, use_gpu=False, use_8bit=False): # Default to CPU for HF Spaces
"""Initialize BioGPT chatbot for Gradio deployment"""
self.device = "cuda" if torch.cuda.is_available() and use_gpu else "cpu"
self.use_8bit = use_8bit and torch.cuda.is_available()
print(f"πŸ”§ Initializing on device: {self.device}")
# Initialize components with error handling
try:
self.setup_embeddings()
self.setup_faiss_index()
self.setup_biogpt()
except Exception as e:
print(f"❌ Initialization error: {e}")
self.model = None
self.tokenizer = None
self.embedding_model = None
# Conversation tracking
self.conversation_history = []
self.knowledge_chunks = []
self.is_data_loaded = False
def setup_embeddings(self):
"""Setup medical-optimized embeddings with error handling"""
try:
print("πŸ”„ Loading embedding model...")
self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
self.embedding_dim = self.embedding_model.get_sentence_embedding_dimension()
self.use_embeddings = True
print("βœ… Embeddings loaded successfully")
except Exception as e:
print(f"❌ Embeddings setup failed: {e}")
self.embedding_model = None
self.embedding_dim = 384
self.use_embeddings = False
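    # (all-MiniLM-L6-v2 produces 384-dimensional sentence embeddings, which
    # is why 384 is the hard-coded fallback dimension above.)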
def setup_faiss_index(self):
"""Setup FAISS for vector search with error handling"""
try:
print("πŸ”„ Setting up FAISS index...")
self.faiss_index = faiss.IndexFlatIP(self.embedding_dim)
self.faiss_ready = True
print("βœ… FAISS index ready")
except Exception as e:
print(f"❌ FAISS setup failed: {e}")
self.faiss_index = None
self.faiss_ready = False
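    # (IndexFlatIP ranks by raw inner product; the vectors added to it are
    # L2-normalized in generate_embeddings_and_index, so search scores are
    # effectively cosine similarities.)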
def setup_biogpt(self):
"""Setup BioGPT model with optimizations and fallbacks"""
print("πŸ”„ Loading BioGPT model...")
# Try BioGPT first, with fallbacks
models_to_try = [
"microsoft/BioGPT", # Smaller version first
"microsoft/DialoGPT-medium", # Fallback 1
"microsoft/DialoGPT-small", # Fallback 2
"gpt2" # Final fallback
]
for model_name in models_to_try:
try:
print(f"πŸ”„ Trying model: {model_name}")
# Load tokenizer
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
# Load model with minimal config for HF Spaces
if self.device == "cuda" and self.use_8bit:
quantization_config = BitsAndBytesConfig(
load_in_8bit=True,
llm_int8_threshold=6.0,
)
else:
quantization_config = None
self.model = AutoModelForCausalLM.from_pretrained(
model_name,
quantization_config=quantization_config,
torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
device_map="auto" if self.device == "cuda" else None,
trust_remote_code=True,
low_cpu_mem_usage=True # Important for HF Spaces
)
if self.device == "cpu":
self.model = self.model.to(self.device)
print(f"βœ… Successfully loaded: {model_name}")
break
except Exception as e:
print(f"❌ Failed to load {model_name}: {e}")
continue
else:
print("❌ All models failed to load")
self.model = None
self.tokenizer = None
def create_medical_chunks(self, text: str, chunk_size: int = 300) -> List[Dict]:
"""Create medically-optimized text chunks with smaller size for efficiency"""
chunks = []
# Split by medical sections first
medical_sections = self.split_by_medical_sections(text)
chunk_id = 0
for section in medical_sections:
if len(section.split()) > chunk_size:
# Split large sections by sentences
sentences = re.split(r'[.!?]+', section)
current_chunk = ""
for sentence in sentences:
sentence = sentence.strip()
if not sentence:
continue
if len(current_chunk.split()) + len(sentence.split()) < chunk_size:
current_chunk += sentence + ". "
else:
if current_chunk.strip():
chunks.append({
'id': chunk_id,
'text': current_chunk.strip(),
'medical_focus': self.identify_medical_focus(current_chunk)
})
chunk_id += 1
current_chunk = sentence + ". "
if current_chunk.strip():
chunks.append({
'id': chunk_id,
'text': current_chunk.strip(),
'medical_focus': self.identify_medical_focus(current_chunk)
})
chunk_id += 1
else:
if section.strip(): # Don't add empty sections
chunks.append({
'id': chunk_id,
'text': section,
'medical_focus': self.identify_medical_focus(section)
})
chunk_id += 1
return chunks
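    # (Each chunk returned above is a dict with illustrative values like:
    # {'id': 0, 'text': 'Fever in children ...',
    #  'medical_focus': 'pediatric_symptoms'}.)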
def split_by_medical_sections(self, text: str) -> List[str]:
"""Split text by medical sections"""
section_patterns = [
r'\n\s*(?:SYMPTOMS?|TREATMENT|DIAGNOSIS|CAUSES?|PREVENTION|MANAGEMENT).*?\n',
r'\n\s*\d+\.\s+',
r'\n\n+'
]
sections = [text]
for pattern in section_patterns:
new_sections = []
for section in sections:
splits = re.split(pattern, section, flags=re.IGNORECASE)
new_sections.extend([s.strip() for s in splits if len(s.strip()) > 50]) # Reduced minimum length
sections = new_sections
return sections
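    # (The patterns above split on section headers such as "SYMPTOMS",
    # numbered items, and blank-line gaps, refining the sections pass by
    # pass; fragments of 50 characters or fewer are discarded as noise.)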
def identify_medical_focus(self, text: str) -> str:
"""Identify the medical focus of a text chunk"""
text_lower = text.lower()
categories = {
'pediatric_symptoms': ['fever', 'cough', 'rash', 'vomiting', 'diarrhea', 'child', 'baby', 'infant'],
'treatments': ['treatment', 'therapy', 'medication', 'antibiotics', 'medicine'],
'diagnosis': ['diagnosis', 'diagnostic', 'symptoms', 'signs', 'condition'],
'emergency': ['emergency', 'urgent', 'serious', 'hospital', 'call doctor'],
'prevention': ['prevention', 'vaccine', 'immunization', 'avoid', 'prevent']
}
for category, keywords in categories.items():
if any(keyword in text_lower for keyword in keywords):
return category
return 'general_medical'
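    # (Categories above are checked in insertion order, so the first match
    # wins; e.g. the hypothetical input "Antibiotics are the usual
    # treatment" would be tagged 'treatments'.)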
def load_medical_data_from_file(self, file_path: str) -> Tuple[str, bool]:
"""Load medical data from uploaded file with better error handling"""
if not file_path or not os.path.exists(file_path):
return "❌ No file uploaded or file not found.", False
try:
print(f"πŸ”„ Processing file: {file_path}")
# Read file with encoding detection
            # latin-1 goes last: it decodes any byte sequence, so any
            # encoding listed after it would be unreachable.
            encodings_to_try = ['utf-8', 'utf-8-sig', 'cp1252', 'latin-1']
text = None
for encoding in encodings_to_try:
try:
with open(file_path, 'r', encoding=encoding) as f:
text = f.read()
print(f"βœ… File read successfully with {encoding} encoding")
break
except UnicodeDecodeError:
continue
if text is None:
return "❌ Could not read file. Please ensure it's a valid text file.", False
if len(text.strip()) < 100:
return "❌ File appears to be too short or empty. Please upload a substantial medical text.", False
# Create chunks
print("πŸ”„ Creating medical chunks...")
chunks = self.create_medical_chunks(text)
if not chunks:
return "❌ No valid medical content found in the file.", False
self.knowledge_chunks = chunks
print(f"βœ… Created {len(chunks)} chunks")
# Generate embeddings if available
if self.use_embeddings and self.embedding_model and self.faiss_ready:
print("πŸ”„ Generating embeddings...")
success = self.generate_embeddings_and_index(chunks)
if success:
self.is_data_loaded = True
return f"βœ… Medical data loaded successfully! {len(chunks)} chunks processed with vector search.", True
self.is_data_loaded = True
return f"βœ… Medical data loaded successfully! {len(chunks)} chunks processed (keyword search mode).", True
except Exception as e:
print(f"❌ Error processing file: {e}")
return f"❌ Error loading file: {str(e)}", False
def generate_embeddings_and_index(self, chunks: List[Dict]) -> bool:
"""Generate embeddings and add to FAISS index with error handling"""
try:
print("πŸ”„ Generating embeddings...")
texts = [chunk['text'] for chunk in chunks]
# Process in batches to avoid memory issues
batch_size = 32
all_embeddings = []
for i in range(0, len(texts), batch_size):
batch_texts = texts[i:i+batch_size]
batch_embeddings = self.embedding_model.encode(batch_texts, show_progress_bar=False)
all_embeddings.append(batch_embeddings)
            embeddings = np.vstack(all_embeddings).astype('float32')
            # IndexFlatIP scores by raw inner product; L2-normalizing the
            # vectors first makes the score equivalent to cosine similarity.
            faiss.normalize_L2(embeddings)
            self.faiss_index.add(embeddings)
            print(f"βœ… Added {len(embeddings)} embeddings to FAISS index")
return True
except Exception as e:
print(f"❌ Embedding generation failed: {e}")
return False
def retrieve_medical_context(self, query: str, n_results: int = 3) -> List[str]:
"""Retrieve relevant medical context with fallback"""
if not self.knowledge_chunks:
return []
if self.use_embeddings and self.embedding_model and self.faiss_ready:
try:
                query_embedding = self.embedding_model.encode([query]).astype('float32')
                faiss.normalize_L2(query_embedding)  # match the normalized index vectors
                distances, indices = self.faiss_index.search(query_embedding, n_results)
context_chunks = []
for i in indices[0]:
if i != -1 and i < len(self.knowledge_chunks):
context_chunks.append(self.knowledge_chunks[i]['text'])
if context_chunks:
return context_chunks
except Exception as e:
print(f"❌ Embedding search failed: {e}")
# Fallback to keyword search
return self.keyword_search_medical(query, n_results)
def keyword_search_medical(self, query: str, n_results: int) -> List[str]:
"""Medical-focused keyword search"""
if not self.knowledge_chunks:
return []
query_words = set(query.lower().split())
chunk_scores = []
for chunk_info in self.knowledge_chunks:
chunk_text = chunk_info['text']
chunk_words = set(chunk_text.lower().split())
word_overlap = len(query_words.intersection(chunk_words))
base_score = word_overlap / len(query_words) if query_words else 0
# Boost medical content
medical_boost = 0
if chunk_info.get('medical_focus') in ['pediatric_symptoms', 'treatments', 'diagnosis']:
medical_boost = 0.3
final_score = base_score + medical_boost
if final_score > 0:
chunk_scores.append((final_score, chunk_text))
chunk_scores.sort(reverse=True)
return [chunk for _, chunk in chunk_scores[:n_results]]
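    # (Scoring sketch with hypothetical numbers: for the query "fever
    # treatment" (2 words), a chunk containing both words scores 2/2 = 1.0
    # for overlap, plus a 0.3 boost when its medical_focus is a prioritized
    # category, giving a final score of 1.3.)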
def generate_biogpt_response(self, context: str, query: str) -> str:
"""Generate medical response using loaded model"""
if not self.model or not self.tokenizer:
return "Medical AI model is not available. Using fallback response based on retrieved context."
try:
# Simplified prompt for better compatibility
prompt = f"Context: {context[:600]}\n\nQuestion: {query}\n\nAnswer:"
inputs = self.tokenizer(
prompt,
return_tensors="pt",
truncation=True,
max_length=512, # Reduced for efficiency
padding=True
)
if self.device == "cuda":
inputs = {k: v.to(self.device) for k, v in inputs.items()}
with torch.no_grad():
outputs = self.model.generate(
**inputs,
max_new_tokens=100, # Reduced for efficiency
do_sample=True,
temperature=0.7,
top_p=0.9,
pad_token_id=self.tokenizer.eos_token_id,
repetition_penalty=1.1,
no_repeat_ngram_size=3
)
full_response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
if "Answer:" in full_response:
generated_response = full_response.split("Answer:")[-1].strip()
else:
generated_response = full_response[len(prompt):].strip()
return self.clean_medical_response(generated_response) if generated_response else self.fallback_response(context, query)
except Exception as e:
print(f"❌ Generation failed: {e}")
return self.fallback_response(context, query)
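    # (Sampling notes for the generation above: temperature=0.7 with
    # top_p=0.9 trades determinism for variety, while repetition_penalty and
    # no_repeat_ngram_size curb the repetitive loops small causal LMs tend
    # to produce on short prompts.)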
def clean_medical_response(self, response: str) -> str:
"""Clean and format medical response"""
if not response:
return "I couldn't generate a specific response. Please consult a healthcare professional."
# Remove incomplete sentences and clean up
sentences = re.split(r'[.!?]+', response)
clean_sentences = []
for sentence in sentences:
sentence = sentence.strip()
if len(sentence) > 15 and not sentence.endswith(('and', 'or', 'but', 'however', 'the', 'a', 'an')):
clean_sentences.append(sentence)
if len(clean_sentences) >= 2: # Limit to 2 sentences for clarity
break
if clean_sentences:
cleaned = '. '.join(clean_sentences) + '.'
else:
cleaned = response[:150] + '...' if len(response) > 150 else response
return cleaned
def fallback_response(self, context: str, query: str) -> str:
"""Fallback response when model generation fails"""
if not context:
return "I don't have specific information about this topic in my medical database. Please consult with a healthcare professional."
# Extract most relevant sentences from context
sentences = [s.strip() for s in context.split('.') if len(s.strip()) > 20]
if sentences:
# Return first 1-2 most relevant sentences
response = sentences[0]
if len(sentences) > 1 and len(response) < 100:
response += '. ' + sentences[1]
response += '.'
else:
response = context[:200] + '...' if len(context) > 200 else context
return response
def handle_conversational_interactions(self, query: str) -> Optional[str]:
"""Handle conversational interactions"""
query_lower = query.lower().strip()
# Greetings
if query_lower in ['hello', 'hi', 'hey', 'good morning', 'good afternoon']:
if not self.is_data_loaded:
return "πŸ‘‹ Hello! I'm your medical AI assistant. Please upload your medical data file first, then ask me any health-related questions!"
else:
return \"πŸ‘‹ Hello again! I'm ready to help. Ask me any medical question related to your uploaded data.\"
# Thanks
if any(thanks in query_lower for thanks in ['thank you', 'thanks', 'thx', 'appreciate']):
return "πŸ™ You're welcome! Remember to always consult healthcare professionals for medical decisions. Feel free to ask more questions!"
# Goodbyes
if any(bye in query_lower for bye in ['bye', 'goodbye', 'see you', 'farewell']):
return "πŸ‘‹ Goodbye! Take care and stay healthy! πŸ₯"
# Help/About
if any(help_word in query_lower for help_word in ['help', 'what can you do', 'how do you work']):
return """πŸ€– **Medical AI Assistant**
I can help with:
β€’ Medical information and conditions
β€’ Symptom understanding
β€’ Treatment information
β€’ When to seek medical care
**How to use:**
1. Upload your medical data file
2. Ask specific medical questions
3. Get evidence-based information
⚠️ **Important:** I provide educational information only. Always consult healthcare professionals for medical advice."""
return None
def chat_interface(self, message: str, history: List[List[str]]) -> Tuple[str, List[List[str]]]:
"""Main chat interface for Gradio"""
if not message.strip():
return "", history
# Check if data is loaded
if not self.is_data_loaded:
response = "⚠️ Please upload your medical data file first using the file upload above before asking questions."
history.append([message, response])
return "", history
# Handle conversational interactions
conversational_response = self.handle_conversational_interactions(message)
if conversational_response:
history.append([message, conversational_response])
return "", history
# Process medical query
try:
context = self.retrieve_medical_context(message)
if not context:
response = "I don't have specific information about this topic in my medical database. Please consult with a healthcare professional for personalized medical advice."
else:
main_context = '\n\n'.join(context)
medical_response = self.generate_biogpt_response(main_context, message)
response = f"🩺 **Medical Information:** {medical_response}\n\n⚠️ **Important:** This information is for educational purposes only. Always consult with qualified healthcare professionals for medical diagnosis, treatment, and personalized advice."
# Add to conversation history
self.conversation_history.append({
'query': message,
'response': response,
'timestamp': datetime.now().isoformat()
})
history.append([message, response])
return "", history
except Exception as e:
print(f"❌ Chat interface error: {e}")
error_response = "I encountered an error processing your question. Please try again or consult a healthcare professional."
history.append([message, error_response])
return "", history
# Initialize the chatbot with error handling
print("πŸš€ Initializing Medical AI Assistant...")
try:
chatbot = GradioBioGPTChatbot(use_gpu=False, use_8bit=False) # CPU-optimized for HF Spaces
print("βœ… Chatbot initialized successfully")
except Exception as e:
print(f"❌ Chatbot initialization failed: {e}")
chatbot = None
def upload_and_process_file(file):
"""Handle file upload and processing"""
if file is None:
return "❌ No file uploaded."
if chatbot is None:
return "❌ Chatbot not initialized properly. Please refresh the page."
try:
message, success = chatbot.load_medical_data_from_file(file)
return message
except Exception as e:
return f"❌ Error processing file: {str(e)}"
# Create Gradio Interface
def create_gradio_interface():
"""Create and launch Gradio interface"""
with gr.Blocks(
title="πŸ₯ Medical AI Assistant",
theme=gr.themes.Soft(),
css="""
.gradio-container {
max-width: 1200px !important;
}
.chat-message {
border-radius: 10px !important;
}
"""
) as demo:
gr.HTML("""
<div style="text-align: center; padding: 20px;">
<h1>πŸ₯ Medical AI Assistant</h1>
<p style="font-size: 18px; color: #666;">
AI-powered medical information assistant
</p>
<p style="color: #888;">
⚠️ For educational purposes only. Always consult healthcare professionals for medical advice.
</p>
</div>
""")
with gr.Row():
with gr.Column(scale=1):
gr.HTML("<h3>πŸ“ Upload Medical Data</h3>")
file_upload = gr.File(
label="Upload Medical Text File (.txt)",
file_types=[".txt"],
type="filepath"
)
upload_status = gr.Textbox(
label="Upload Status",
value="πŸ“‹ Please upload your medical data file to begin...",
interactive=False,
lines=3
)
gr.HTML("""
<div style="margin-top: 20px; padding: 15px; background-color: #f0f8ff; border-radius: 10px;">
<h4>πŸ’‘ How to Use:</h4>
<ol>
<li>Upload your medical text file (.txt format)</li>
<li>Wait for processing confirmation</li>
<li>Start asking medical questions!</li>
</ol>
<h4>πŸ“ Example Questions:</h4>
<ul>
<li>"What causes fever in children?"</li>
<li>"How to treat a persistent cough?"</li>
<li>"When should I call the doctor?"</li>
<li>"Signs of dehydration in infants?"</li>
</ul>
</div>
""")
with gr.Column(scale=2):
gr.HTML("<h3>πŸ’¬ Medical Consultation</h3>")
chatbot_interface = gr.Chatbot(
label="Medical AI Chat",
height=500,
bubble_full_width=False
)
msg_input = gr.Textbox(
label="Your Medical Question",
placeholder="Ask me about health topics, symptoms, treatments, or when to seek care...",
lines=2
)
with gr.Row():
send_btn = gr.Button("🩺 Send Question", variant="primary")
clear_btn = gr.Button("πŸ—‘οΈ Clear Chat", variant="secondary")
# Event handlers with error handling
def safe_upload_handler(file):
try:
return upload_and_process_file(file)
except Exception as e:
return f"❌ Upload error: {str(e)}"
def safe_chat_handler(message, history):
try:
if chatbot is None:
return "", history + [[message, "❌ System error. Please refresh the page."]]
return chatbot.chat_interface(message, history)
except Exception as e:
return "", history + [[message, f"❌ Error: {str(e)}"]]
file_upload.change(
fn=safe_upload_handler,
inputs=[file_upload],
outputs=[upload_status]
)
msg_input.submit(
fn=safe_chat_handler,
inputs=[msg_input, chatbot_interface],
outputs=[msg_input, chatbot_interface]
)
send_btn.click(
fn=safe_chat_handler,
inputs=[msg_input, chatbot_interface],
outputs=[msg_input, chatbot_interface]
)
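        # (Both Enter/submit and the Send button route through the same
        # safe_chat_handler, so error handling lives in one place.)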
clear_btn.click(
fn=lambda: ([], ""),
outputs=[chatbot_interface, msg_input]
)
gr.HTML("""
<div style="text-align: center; margin-top: 30px; padding: 20px; background-color: #fff3cd; border-radius: 10px;">
<h4>⚠️ Medical Disclaimer</h4>
<p>This AI assistant provides educational medical information only and is not a substitute for professional medical advice, diagnosis, or treatment. Always seek the advice of qualified healthcare providers with questions about medical conditions.</p>
</div>
""")
return demo
if __name__ == "__main__":
# Create and launch the Gradio interface
demo = create_gradio_interface()
print("🌐 Launching Gradio interface...")
print("πŸ“‹ Upload your medical data file and start chatting!")
    # Launch with HF Spaces optimized settings. queue() replaces the
    # enable_queue/show_tips launch arguments that newer Gradio releases
    # no longer accept.
    demo.queue()
    demo.launch(
        share=False,
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
        max_threads=40
    )
)