"""
Model management for Frame 0 Laboratory for MIA.

BAGEL 7B integration via API calls to a hosted Gradio Space, plus a
size-based fallback analyzer used when the API is unavailable.
"""

import logging
import os
import tempfile
from typing import Any, Dict, Optional, Tuple

from PIL import Image
from gradio_client import Client, handle_file

from config import get_device_config
from utils import clean_memory, safe_execute

logger = logging.getLogger(__name__)


class BaseImageAnalyzer:
    """Base class for image analysis models.

    Subclasses implement ``initialize`` and ``analyze_image``; ``cleanup``
    defaults to calling ``utils.clean_memory``.
    """

    def __init__(self):
        self.is_initialized = False
        self.device_config = get_device_config()

    def initialize(self) -> bool:
        """Initialize the model; return True when it is ready."""
        raise NotImplementedError

    def analyze_image(self, image: Image.Image) -> Tuple[str, Dict[str, Any]]:
        """Analyze ``image`` and return ``(description, metadata)``."""
        raise NotImplementedError

    def cleanup(self) -> None:
        """Clean up model resources."""
        clean_memory()


class BagelAPIAnalyzer(BaseImageAnalyzer):
    """BAGEL 7B model via API calls to a hosted Gradio Space.

    The Gradio client is created lazily on first analysis (or by an explicit
    ``initialize()``) and released again by ``cleanup()``.
    """

    DEFAULT_PROMPT = (
        "Provide a detailed description of this image, including objects, "
        "people, setting, composition, lighting, colors, mood, and artistic "
        "style. Focus on elements that would be useful for generating a "
        "similar image."
    )

    FLUX_PROMPT = (
        "Analyze this image and generate a detailed FLUX prompt description.\n"
        "\n"
        "Focus on:\n"
        "- Photographic and artistic style\n"
        "- Composition and framing\n"
        "- Lighting conditions and mood\n"
        "- Colors and visual elements\n"
        "- Camera settings that would recreate this image\n"
        "- Technical photography details\n"
        "\n"
        "Provide a comprehensive description suitable for FLUX image generation."
    )

    def __init__(self):
        super().__init__()
        self.client = None
        self.space_url = "Malaji71/Bagel-7B-Demo"
        self.api_endpoint = "/image_understanding"

    def initialize(self) -> bool:
        """Create the Gradio client for the BAGEL Space.

        Returns:
            True if the client is ready (also when it already was), False if
            creating it raised.
        """
        if self.is_initialized:
            return True
        try:
            logger.info("Initializing BAGEL API client...")
            self.client = Client(self.space_url)
        except Exception as e:
            logger.error("BAGEL API client initialization failed: %s", e)
            return False
        self.is_initialized = True
        logger.info("BAGEL API client initialized successfully")
        return True

    def _save_temp_image(self, image: Image.Image) -> Optional[str]:
        """Write ``image`` as an RGB PNG to a new temporary file.

        Returns:
            The temporary file's path, or None if it could not be written.
            The caller owns the file and should remove it with
            ``_cleanup_temp_file``.
        """
        temp_path = None
        try:
            # delete=False so the file survives the handle; the handle is closed
            # before PIL reopens the path by name, since reopening an open
            # NamedTemporaryFile is not portable across platforms.
            with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as temp_file:
                temp_path = temp_file.name
            if image.mode != 'RGB':
                image = image.convert('RGB')
            image.save(temp_path, 'PNG')
            return temp_path
        except Exception as e:
            logger.error("Failed to save temporary image: %s", e)
            # Remove the empty/partial file created above instead of leaking it.
            self._cleanup_temp_file(temp_path)
            return None

    def _cleanup_temp_file(self, file_path: Optional[str]) -> None:
        """Best-effort removal of a temporary file; logs instead of raising."""
        if not file_path:
            return
        try:
            os.unlink(file_path)
        except FileNotFoundError:
            pass  # already gone; nothing to clean up
        except Exception as e:
            logger.warning("Failed to cleanup temp file: %s", e)

    @staticmethod
    def _extract_text(result: Any) -> str:
        """Pull the text answer out of a ``Client.predict`` result.

        For a tuple/list of at least two items, the second item is preferred
        and the first is used when the second is empty (the layout assumed
        here is ``(image_result, text_response)`` — NOTE(review): confirm
        against the Space's API). Any other non-None result is stringified.

        Returns:
            The stripped text, or ``""`` when there is no usable text.
        """
        if isinstance(result, (tuple, list)) and len(result) >= 2:
            candidate = result[1] if result[1] else result[0]
            return candidate.strip() if isinstance(candidate, str) else ""
        if result is None:
            return ""
        return str(result).strip()

    def analyze_image(self, image: Image.Image, prompt: Optional[str] = None) -> Tuple[str, Dict[str, Any]]:
        """Describe ``image`` using the BAGEL image-understanding endpoint.

        Args:
            image: Image to analyze.
            prompt: Instruction sent with the image; ``DEFAULT_PROMPT`` when None.

        Returns:
            ``(description, metadata)``. On failure (client unavailable, image
            not saved, API exception, or an empty response) the metadata holds
            an ``"error"`` key so callers such as ``ModelManager`` can fall back.
        """
        if not self.is_initialized and not self.initialize():
            return "BAGEL API not available", {"error": "API initialization failed"}

        if prompt is None:
            prompt = self.DEFAULT_PROMPT

        temp_path = None
        try:
            # The API takes a file, so the image goes through a temp PNG.
            temp_path = self._save_temp_image(image)
            if not temp_path:
                return "Image processing failed", {"error": "Could not save image"}

            logger.info("Calling BAGEL API for image analysis...")
            result = self.client.predict(
                image=handle_file(temp_path),
                prompt=prompt,
                show_thinking=False,
                do_sample=False,
                text_temperature=0.3,
                max_new_tokens=512,
                api_name=self.api_endpoint,
            )

            description = self._extract_text(result)
            if not description:
                # Report as a failure rather than inventing a description.
                logger.warning("BAGEL API returned an empty response")
                return "API analysis failed", {
                    "error": "Empty response from BAGEL API",
                    "model": "BAGEL-7B-API",
                }

            metadata = {
                "model": "BAGEL-7B-API",
                "device": "api",
                "confidence": 0.9,
                "api_endpoint": self.api_endpoint,
                "space_url": self.space_url,
                "prompt_used": prompt,
                "response_length": len(description),
            }
            logger.info("BAGEL API analysis complete: %d characters", len(description))
            return description, metadata
        except Exception as e:
            logger.error("BAGEL API analysis failed: %s", e)
            return "API analysis failed", {"error": str(e), "model": "BAGEL-7B-API"}
        finally:
            # Always remove the temporary upload file.
            self._cleanup_temp_file(temp_path)

    def analyze_for_flux_prompt(self, image: Image.Image) -> Tuple[str, Dict[str, Any]]:
        """Analyze ``image`` with a prompt tailored to FLUX prompt generation."""
        return self.analyze_image(image, self.FLUX_PROMPT)

    def cleanup(self) -> None:
        """Drop the API client; the next analysis re-creates it."""
        try:
            self.client = None
            # Reset so analyze_image calls initialize() again instead of
            # calling predict() on a None client.
            self.is_initialized = False
            super().cleanup()
            logger.info("BAGEL API resources cleaned up")
        except Exception as e:
            logger.warning("BAGEL API cleanup warning: %s", e)


class FallbackAnalyzer(BaseImageAnalyzer):
    """Simple fallback analyzer when BAGEL API is not available.

    Builds a generic description from the image's dimensions only; no model
    is called.
    """

    def initialize(self) -> bool:
        """Fallback is always ready."""
        self.is_initialized = True
        return True

    def analyze_image(self, image: Image.Image) -> Tuple[str, Dict[str, Any]]:
        """Describe ``image`` from its aspect ratio and colour mode.

        Returns:
            ``(description, metadata)``; if anything raises (e.g. a zero-height
            image), a generic description with an ``"error"`` key.
        """
        try:
            width, height = image.size
            mode = image.mode

            aspect_ratio = width / height
            if aspect_ratio > 1.5:
                orientation = "landscape"
                camera_suggestion = "wide-angle lens, landscape photography"
            elif aspect_ratio < 0.75:
                orientation = "portrait"
                camera_suggestion = "portrait lens, shallow depth of field"
            else:
                orientation = "square"
                camera_suggestion = "standard lens, balanced composition"

            description = (
                f"A {orientation} format image with professional composition. "
                "The image shows clear detail and good visual balance, suitable "
                "for high-quality reproduction. Recommended camera setup: "
                f"{camera_suggestion}, professional lighting with careful "
                "attention to exposure and color balance."
            )

            metadata = {
                "model": "Fallback",
                "device": "cpu",
                "confidence": 0.6,
                "image_size": f"{width}x{height}",
                "color_mode": mode,
                "orientation": orientation,
                "aspect_ratio": round(aspect_ratio, 2),
            }
            return description, metadata
        except Exception as e:
            logger.error("Fallback analysis failed: %s", e)
            return (
                "Professional image suitable for detailed analysis and prompt generation",
                {"error": str(e), "model": "Fallback"},
            )


class ModelManager:
    """Manager for handling image analysis models.

    Analyzers are created on first request and cached by model name.
    """

    # Known model names and the analyzer class built for each.
    _ANALYZER_TYPES = {
        "bagel-api": BagelAPIAnalyzer,
        "fallback": FallbackAnalyzer,
    }

    def __init__(self, preferred_model: str = "bagel-api"):
        self.preferred_model = preferred_model
        self.analyzers = {}
        self.current_analyzer = None

    def get_analyzer(self, model_name: Optional[str] = None) -> Optional[BaseImageAnalyzer]:
        """Return the cached analyzer for ``model_name``, creating it if needed.

        Unknown names (not cached and not in ``_ANALYZER_TYPES``) map to
        ``"fallback"``, reusing an existing fallback analyzer when there is one.
        """
        model_name = model_name or self.preferred_model
        if model_name not in self.analyzers and model_name not in self._ANALYZER_TYPES:
            logger.warning("Unknown model: %s, using fallback", model_name)
            model_name = "fallback"
        if model_name not in self.analyzers:
            self.analyzers[model_name] = self._ANALYZER_TYPES[model_name]()
        return self.analyzers[model_name]

    def analyze_image(self, image: Image.Image, model_name: Optional[str] = None,
                      analysis_type: str = "detailed") -> Tuple[str, Dict[str, Any]]:
        """Analyze ``image`` with the requested (or preferred) model.

        Args:
            image: Image to analyze.
            model_name: Model to use; defaults to ``preferred_model``.
            analysis_type: ``"flux"`` uses ``analyze_for_flux_prompt`` when the
                analyzer has it; anything else uses ``analyze_image``.

        Returns:
            ``(description, metadata)``. If the primary call fails or reports an
            ``"error"``, the fallback analyzer's result is returned instead.
        """
        analyzer = self.get_analyzer(model_name)
        if analyzer is None:
            return "No analyzer available", {"error": "Model not found"}

        # safe_execute is unpacked as (success, result); on success result is
        # the analyzer's (description, metadata) tuple.
        if analysis_type == "flux" and hasattr(analyzer, 'analyze_for_flux_prompt'):
            success, result = safe_execute(analyzer.analyze_for_flux_prompt, image)
        else:
            success, result = safe_execute(analyzer.analyze_image, image)

        if success and result[1].get("error") is None:
            return result

        logger.warning("Primary model failed, using fallback: %s", result)
        fallback_analyzer = self.get_analyzer("fallback")
        fallback_success, fallback_result = safe_execute(fallback_analyzer.analyze_image, image)
        if fallback_success:
            return fallback_result
        return "All analyzers failed", {"error": "Complete analysis failure"}

    def cleanup_all(self) -> None:
        """Clean up every cached analyzer, then drop them all from the cache."""
        for name, analyzer in self.analyzers.items():
            # One failing cleanup must not prevent the others from running.
            try:
                analyzer.cleanup()
            except Exception as e:
                logger.warning("Cleanup failed for analyzer %s: %s", name, e)
        self.analyzers.clear()
        clean_memory()
        logger.info("All analyzers cleaned up")


# Global model manager instance
model_manager = ModelManager(preferred_model="bagel-api")


def analyze_image(image: Image.Image, model_name: Optional[str] = None,
                  analysis_type: str = "detailed") -> Tuple[str, Dict[str, Any]]:
    """
    Convenience function for image analysis using BAGEL API

    Args:
        image: PIL Image to analyze
        model_name: Optional model name ("bagel-api" or "fallback")
        analysis_type: Type of analysis ("detailed" or "flux")

    Returns:
        Tuple of (description, metadata)
    """
    return model_manager.analyze_image(image, model_name, analysis_type)


# Export main components
__all__ = [
    "BaseImageAnalyzer",
    "BagelAPIAnalyzer",
    "FallbackAnalyzer",
    "ModelManager",
    "model_manager",
    "analyze_image",
]