# Setup and Installation import torch print("๐Ÿ–ฅ๏ธ System Check:") print(f"CUDA available: {torch.cuda.is_available()}") if torch.cuda.is_available(): print(f"GPU device: {torch.cuda.get_device_name(0)}") print(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB") else: print("โš ๏ธ No GPU detected - BioGPT will run on CPU") print("\n๐Ÿ”ง Loading required packages...") # Import Libraries import os import re import torch import warnings import numpy as np import faiss # FAISS for vector search from transformers import ( AutoTokenizer, AutoModelForCausalLM, pipeline, BitsAndBytesConfig ) from sentence_transformers import SentenceTransformer from typing import List, Dict, Optional import time from datetime import datetime import json import pickle # Suppress warnings for cleaner output warnings.filterwarnings('ignore') print("๐Ÿ“š Libraries imported successfully!") print(f"๐Ÿ” FAISS version: {faiss.__version__}") print("๐ŸŽฏ Using FAISS for vector search") # BioGPT Medical Chatbot Class class ColabBioGPTChatbot: def __init__(self, use_gpu=True, use_8bit=True): """Initialize BioGPT chatbot optimized for deployment""" print("๐Ÿฅ Initializing Professional BioGPT Medical Chatbot...") # Force CPU for HF Spaces if needed self.device = "cuda" if torch.cuda.is_available() and use_gpu else "cpu" self.use_8bit = use_8bit and torch.cuda.is_available() print(f"๐Ÿ–ฅ๏ธ Using device: {self.device}") if self.use_8bit: print("๐Ÿ’พ Using 8-bit quantization for memory efficiency") # Setup components self.setup_embeddings() self.setup_faiss_index() self.setup_biogpt() # Conversation tracking self.conversation_history = [] self.knowledge_chunks = [] print("โœ… BioGPT Medical Chatbot ready for professional medical assistance!") def setup_embeddings(self): """Setup medical-optimized embeddings""" print("๐Ÿ”ง Loading medical embeddings...") try: # Use a smaller, more efficient model for deployment self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2') self.embedding_dim = self.embedding_model.get_sentence_embedding_dimension() print(f"โœ… Embeddings loaded (dimension: {self.embedding_dim})") self.use_embeddings = True except Exception as e: print(f"โš ๏ธ Embeddings failed: {e}") self.embedding_model = None self.embedding_dim = 384 self.use_embeddings = False def setup_faiss_index(self): """Setup faiss for CPU-based vector search""" print("๐Ÿ”ง Setting up FAISS vector database...") try: print('Using CPU FAISS index for maximum compatibility') self.faiss_index = faiss.IndexFlatIP(self.embedding_dim) self.use_gpu_faiss = False self.faiss_ready = True self.collection = self.faiss_index print("โœ… FAISS CPU index initialized successfully") except Exception as e: print(f"โŒ FAISS setup failed: {e}") self.faiss_index = None self.faiss_ready = False self.collection = None def setup_biogpt(self): """Setup BioGPT model with optimizations for deployment""" print("๐Ÿง  Loading BioGPT model...") # Try BioGPT first, fallback to smaller models if needed model_options = [ "microsoft/BioGPT-Large", "microsoft/BioGPT", # Smaller version "microsoft/DialoGPT-medium", # Fallback "gpt2" # Final fallback ] for model_name in model_options: try: print(f" Attempting to load: {model_name}") # Setup quantization config for memory efficiency if self.use_8bit and "BioGPT" in model_name: quantization_config = BitsAndBytesConfig( load_in_8bit=True, llm_int8_threshold=6.0, llm_int8_has_fp16_weight=False, ) else: quantization_config = None # Load tokenizer self.tokenizer = AutoTokenizer.from_pretrained(model_name) # Set padding token if self.tokenizer.pad_token is None: self.tokenizer.pad_token = self.tokenizer.eos_token # Load model with proper settings for deployment start_time = time.time() model_kwargs = { "torch_dtype": torch.float16 if self.device == "cuda" else torch.float32, "trust_remote_code": True, "low_cpu_mem_usage": True, # Important for deployment } if quantization_config: model_kwargs["quantization_config"] = quantization_config model_kwargs["device_map"] = "auto" self.model = AutoModelForCausalLM.from_pretrained( model_name, **model_kwargs ) # Move to device if not using device_map if self.device == "cuda" and quantization_config is None: self.model = self.model.to(self.device) load_time = time.time() - start_time print(f"โœ… {model_name} loaded successfully! ({load_time:.1f} seconds)") # Test the model self.test_model() break # Success, exit the loop except Exception as e: print(f"โŒ {model_name} loading failed: {e}") if model_name == model_options[-1]: # Last option failed print("โŒ All models failed to load") self.model = None self.tokenizer = None continue def test_model(self): """Test the loaded model with a simple query""" print("๐Ÿงช Testing model...") try: test_prompt = "Fever in children can be caused by" inputs = self.tokenizer(test_prompt, return_tensors="pt") if self.device == "cuda": inputs = {k: v.to(self.device) for k, v in inputs.items()} with torch.no_grad(): outputs = self.model.generate( **inputs, max_new_tokens=20, do_sample=True, temperature=0.7, pad_token_id=self.tokenizer.eos_token_id ) response = self.tokenizer.decode(outputs[0], skip_special_tokens=True) print(f"โœ… Model test successful!") print(f" Test response: {response}") except Exception as e: print(f"โš ๏ธ Model test failed: {e}") def load_medical_data(self, file_path: str): """Load and process medical data with progress tracking""" print(f"๐Ÿ“– Loading medical data from {file_path}...") try: with open(file_path, 'r', encoding='utf-8') as f: text = f.read() print(f"๐Ÿ“„ File loaded: {len(text):,} characters") except FileNotFoundError: print(f"โŒ File {file_path} not found!") return False except Exception as e: print(f"โŒ Error loading file: {e}") return False # Create chunks optimized for medical content print("๐Ÿ“ Creating medical-optimized chunks...") chunks = self.create_medical_chunks(text) print(f"๐Ÿ“‹ Created {len(chunks)} medical chunks") self.knowledge_chunks = chunks # Generate embeddings with progress and add to FAISS index if self.use_embeddings and self.embedding_model and self.faiss_ready: return self.generate_embeddings_with_progress(chunks) print("โœ… Medical data loaded (text search mode)") return True def create_medical_chunks(self, text: str, chunk_size: int = 400) -> List[Dict]: """Create medically-optimized text chunks""" chunks = [] # Split by medical sections first medical_sections = self.split_by_medical_sections(text) chunk_id = 0 for section in medical_sections: if len(section.split()) > chunk_size: # Split large sections by sentences sentences = re.split(r'[.!?]+', section) current_chunk = "" for sentence in sentences: sentence = sentence.strip() if not sentence: continue if len(current_chunk.split()) + len(sentence.split()) < chunk_size: current_chunk += sentence + ". " else: if current_chunk.strip(): chunks.append({ 'id': chunk_id, 'text': current_chunk.strip(), 'medical_focus': self.identify_medical_focus(current_chunk) }) chunk_id += 1 current_chunk = sentence + ". " if current_chunk.strip(): chunks.append({ 'id': chunk_id, 'text': current_chunk.strip(), 'medical_focus': self.identify_medical_focus(current_chunk) }) chunk_id += 1 else: chunks.append({ 'id': chunk_id, 'text': section, 'medical_focus': self.identify_medical_focus(section) }) chunk_id += 1 return chunks def split_by_medical_sections(self, text: str) -> List[str]: """Split text by medical sections""" # Look for medical section headers section_patterns = [ r'\n\s*(?:SYMPTOMS?|TREATMENT|DIAGNOSIS|CAUSES?|PREVENTION|MANAGEMENT).*?\n', r'\n\s*\d+\.\s+', # Numbered sections r'\n\n+' # Paragraph breaks ] sections = [text] for pattern in section_patterns: new_sections = [] for section in sections: splits = re.split(pattern, section, flags=re.IGNORECASE) new_sections.extend([s.strip() for s in splits if len(s.strip()) > 100]) sections = new_sections return sections def identify_medical_focus(self, text: str) -> str: """Identify the medical focus of a text chunk""" text_lower = text.lower() # Medical categories categories = { 'pediatric_symptoms': ['fever', 'cough', 'rash', 'vomiting', 'diarrhea'], 'treatments': ['treatment', 'therapy', 'medication', 'antibiotics'], 'diagnosis': ['diagnosis', 'diagnostic', 'symptoms', 'signs'], 'emergency': ['emergency', 'urgent', 'serious', 'hospital'], 'prevention': ['prevention', 'vaccine', 'immunization', 'avoid'] } for category, keywords in categories.items(): if any(keyword in text_lower for keyword in keywords): return category return 'general_medical' def generate_embeddings_with_progress(self, chunks: List[Dict]) -> bool: """Generate embeddings with progress tracking and add to FAISS index""" print("๐Ÿ”ฎ Generating medical embeddings and adding to FAISS index...") if not self.embedding_model or not self.faiss_index: print("โŒ Embedding model or FAISS index not available.") return False try: texts = [chunk['text'] for chunk in chunks] # Generate embeddings in batches with progress batch_size = 32 all_embeddings = [] for i in range(0, len(texts), batch_size): batch_texts = texts[i:i+batch_size] batch_embeddings = self.embedding_model.encode(batch_texts, show_progress_bar=False) all_embeddings.extend(batch_embeddings) # Show progress progress = min(i + batch_size, len(texts)) print(f" Progress: {progress}/{len(texts)} chunks processed", end='\r') print(f"\n โœ… Generated embeddings for {len(texts)} chunks") # Add embeddings to FAISS index print("๐Ÿ’พ Adding embeddings to FAISS index...") self.faiss_index.add(np.array(all_embeddings)) print("โœ… Medical embeddings added to FAISS index successfully!") return True except Exception as e: print(f"โŒ Embedding generation or FAISS add failed: {e}") return False def retrieve_medical_context(self, query: str, n_results: int = 3) -> List[str]: """Retrieve relevant medical context using embeddings or keyword search""" if self.use_embeddings and self.embedding_model and self.faiss_ready: try: # Generate query embedding query_embedding = self.embedding_model.encode([query]) # Search for similar content in FAISS index distances, indices = self.faiss_index.search(np.array(query_embedding), n_results) # Retrieve the corresponding chunks context_chunks = [self.knowledge_chunks[i]['text'] for i in indices[0] if i != -1] if context_chunks: return context_chunks except Exception as e: print(f"โš ๏ธ Embedding search failed: {e}") # Fallback to keyword search print("โš ๏ธ Falling back to keyword search.") return self.keyword_search_medical(query, n_results) def keyword_search_medical(self, query: str, n_results: int) -> List[str]: """Medical-focused keyword search""" if not self.knowledge_chunks: return [] query_words = set(query.lower().split()) chunk_scores = [] for chunk_info in self.knowledge_chunks: chunk_text = chunk_info['text'] chunk_words = set(chunk_text.lower().split()) # Calculate relevance score word_overlap = len(query_words.intersection(chunk_words)) base_score = word_overlap / len(query_words) if query_words else 0 # Boost medical content medical_boost = 0 if chunk_info.get('medical_focus') in ['pediatric_symptoms', 'treatments', 'diagnosis']: medical_boost = 0.5 final_score = base_score + medical_boost if final_score > 0: chunk_scores.append((final_score, chunk_text)) # Return top matches chunk_scores.sort(reverse=True) return [chunk for _, chunk in chunk_scores[:n_results]] def generate_biogpt_response(self, context: str, query: str) -> str: """Generate medical response using BioGPT only""" if not self.model or not self.tokenizer: return "โš ๏ธ Medical AI model not available. This chatbot requires BioGPT for accurate medical information. Please check the setup or try restarting." try: # Create medical-focused prompt prompt = f"""Medical Context: {context[:800]} Question: {query} Medical Answer:""" # Tokenize input inputs = self.tokenizer( prompt, return_tensors="pt", truncation=True, max_length=1024 ) # Move inputs to the correct device if self.device == "cuda": inputs = {k: v.to(self.device) for k, v in inputs.items()} # Generate response with torch.no_grad(): outputs = self.model.generate( **inputs, max_new_tokens=150, do_sample=True, temperature=0.7, top_p=0.9, pad_token_id=self.tokenizer.eos_token_id, repetition_penalty=1.1 ) # Decode response full_response = self.tokenizer.decode(outputs[0], skip_special_tokens=True) # Extract just the generated part if "Medical Answer:" in full_response: generated_response = full_response.split("Medical Answer:")[-1].strip() else: generated_response = full_response[len(prompt):].strip() # Clean up response cleaned_response = self.clean_medical_response(generated_response) return cleaned_response except Exception as e: print(f"โš ๏ธ BioGPT generation failed: {e}") return "โš ๏ธ Unable to generate medical response. The medical AI model encountered an error. Please try rephrasing your question or contact support." def clean_medical_response(self, response: str) -> str: """Clean and format medical response""" # Remove incomplete sentences and limit length sentences = re.split(r'[.!?]+', response) clean_sentences = [] for sentence in sentences: sentence = sentence.strip() if len(sentence) > 10 and not sentence.endswith(('and', 'or', 'but', 'however')): clean_sentences.append(sentence) if len(clean_sentences) >= 3: # Limit to 3 sentences break if clean_sentences: cleaned = '. '.join(clean_sentences) + '.' else: cleaned = response[:200] + '...' if len(response) > 200 else response return cleaned def fallback_response(self, context: str, query: str) -> str: """Fallback response when BioGPT fails""" # Extract key sentences from context sentences = [s.strip() for s in context.split('.') if len(s.strip()) > 20] if sentences: response = sentences[0] + '.' if len(sentences) > 1: response += ' ' + sentences[1] + '.' else: response = context[:300] + '...' return response def handle_conversational_interactions(self, query: str) -> Optional[str]: """Handle comprehensive conversational interactions""" query_lower = query.lower().strip() # Use more specific patterns for greetings greeting_patterns = [ r'^\s*(hello|hi|hey|hiya|howdy)\s*$', r'^\s*(good morning|good afternoon|good evening|good day)\s*$', r'^\s*(what\'s up|whats up|sup|yo)\s*$', r'^\s*(greetings|salutations)\s*$', r'^\s*(how are you|how are you doing|how\'s it going|hows it going)\s*$', r'^\s*(good to meet you|nice to meet you|pleased to meet you)\s*$' ] for pattern in greeting_patterns: if re.match(pattern, query_lower): responses = [ "๐Ÿ‘‹ Hello! I'm BioGPT, your professional medical AI assistant specialized in pediatric medicine. I'm here to provide evidence-based medical information. What health concern can I help you with today?", "๐Ÿฅ Hi there! I'm a medical AI assistant powered by BioGPT, trained on medical literature. I can help answer questions about children's health and medical conditions. How can I assist you?", "๐Ÿ‘‹ Greetings! I'm your AI medical consultant, ready to help with pediatric health questions using the latest medical knowledge. What would you like to know about?" ] return np.random.choice(responses) # Handle thanks and other conversational patterns... # (keeping the rest of the conversational handling as before) # Return None if no conversational pattern matches return None def chat(self, query: str) -> str: """Main chat function with BioGPT medical-only responses""" if not query.strip(): return "Hello! I'm BioGPT, your professional medical AI assistant. How can I help you with pediatric medical questions today?" # Handle comprehensive conversational interactions first conversational_response = self.handle_conversational_interactions(query) if conversational_response: # Add to conversation history self.conversation_history.append({ 'query': query, 'response': conversational_response, 'timestamp': datetime.now().isoformat(), 'type': 'conversational' }) return conversational_response # Check if medical model is available if not self.model or not self.tokenizer: return "โš ๏ธ **Medical AI Unavailable**: This chatbot requires BioGPT for accurate medical information. The medical model failed to load. Please contact support or try restarting the application." if not self.knowledge_chunks: return "Please load medical data first to access the medical knowledge base." print(f"๐Ÿ” Processing medical query: {query}") # Retrieve relevant medical context using FAISS or keyword search start_time = time.time() context = self.retrieve_medical_context(query) retrieval_time = time.time() - start_time if not context: return "I don't have specific information about this topic in my medical database. Please consult with a healthcare professional for personalized medical advice." print(f" ๐Ÿ“š Context retrieved ({retrieval_time:.2f}s)") # Generate response with BioGPT start_time = time.time() main_context = '\n\n'.join(context) response = self.generate_biogpt_response(main_context, query) generation_time = time.time() - start_time print(f" ๐Ÿง  Response generated ({generation_time:.2f}s)") # Format final response final_response = f"๐Ÿฉบ **Medical Information:** {response}\n\nโš ๏ธ **Important:** This information is for educational purposes only. Always consult with qualified healthcare professionals for medical diagnosis, treatment, and personalized advice." # Add to conversation history self.conversation_history.append({ 'query': query, 'response': final_response, 'timestamp': datetime.now().isoformat(), 'retrieval_time': retrieval_time, 'generation_time': generation_time, 'type': 'medical' }) return final_response def get_conversation_summary(self) -> Dict: """Get conversation statistics""" if not self.conversation_history: return {"message": "No conversations yet"} # Filter medical conversations for performance stats medical_conversations = [h for h in self.conversation_history if h.get('type') == 'medical'] if not medical_conversations: return { "total_conversations": len(self.conversation_history), "medical_conversations": 0, "conversational_interactions": len(self.conversation_history), "model_info": "BioGPT" if self.model and "BioGPT" in str(type(self.model)) else "Fallback Model", "vector_search": "FAISS CPU" if self.faiss_ready else "Keyword Search", "device": self.device } avg_retrieval_time = sum(h.get('retrieval_time', 0) for h in medical_conversations) / len(medical_conversations) avg_generation_time = sum(h.get('generation_time', 0) for h in medical_conversations) / len(medical_conversations) return { "total_conversations": len(self.conversation_history), "medical_conversations": len(medical_conversations), "conversational_interactions": len(self.conversation_history) - len(medical_conversations), "avg_retrieval_time": f"{avg_retrieval_time:.2f}s", "avg_generation_time": f"{avg_generation_time:.2f}s", "model_info": "BioGPT" if self.model and "BioGPT" in str(type(self.model)) else "Fallback Model", "vector_search": "FAISS CPU" if self.faiss_ready else "Keyword Search", "device": self.device, "quantization": "8-bit" if self.use_8bit else "16-bit/32-bit" }