# Spaces: Running on Zero
""" | |
Model management for Frame 0 Laboratory for MIA | |
BAGEL 7B integration via API calls | |
""" | |
import spaces | |
import logging | |
import tempfile | |
import os | |
from typing import Optional, Dict, Any, Tuple | |
from PIL import Image | |
from gradio_client import Client, handle_file | |
from config import get_device_config | |
from utils import clean_memory, safe_execute | |
logger = logging.getLogger(__name__) | |
class BaseImageAnalyzer:
    """Common interface implemented by every image analyzer.

    Concrete analyzers override :meth:`initialize` and :meth:`analyze_image`.
    The default :meth:`cleanup` only delegates to the project's
    ``clean_memory`` helper.
    """

    def __init__(self):
        # Subclasses flip this to True once initialize() has succeeded.
        self.is_initialized = False
        self.device_config = get_device_config()

    def cleanup(self) -> None:
        """Release resources held by the analyzer (default: free memory)."""
        clean_memory()

    def initialize(self) -> bool:
        """Prepare the analyzer for use; must be overridden.

        Raises:
            NotImplementedError: always, on the base class.
        """
        raise NotImplementedError

    def analyze_image(self, image: Image.Image) -> Tuple[str, Dict[str, Any]]:
        """Return ``(description, metadata)`` for *image*; must be overridden.

        Raises:
            NotImplementedError: always, on the base class.
        """
        raise NotImplementedError
class BagelAPIAnalyzer(BaseImageAnalyzer):
    """BAGEL 7B image understanding through a hosted Gradio Space.

    A ``gradio_client.Client`` for :attr:`space_url` is created lazily on
    first use. Each analysis writes the image to a temporary PNG, sends it to
    :attr:`api_endpoint`, and deletes the temporary file afterwards.
    """

    # Prompt sent by analyze_image() when the caller does not supply one.
    DEFAULT_PROMPT = (
        "Describe this image in rich detail as a single flowing paragraph. "
        "Include the visual elements, composition, lighting, colors, artistic "
        "style, mood, and atmosphere. "
        "Write it as a continuous narrative description without using numbered "
        "lists or bullet points. "
        "Focus on creating a vivid, cohesive description that captures the "
        "essence and details of the image."
    )

    # Prompt sent by analyze_for_flux_prompt().
    FLUX_PROMPT = (
        "Create a detailed, flowing description of this image suitable for FLUX "
        "generation. "
        "Write as a single coherent paragraph describing the photographic style, "
        "composition, lighting, colors, mood, and technical details. "
        "Focus on artistic and photographic elements that would help recreate "
        "this image. "
        "Avoid numbered lists or bullet points - write it as a natural, "
        "descriptive narrative."
    )

    def __init__(self):
        super().__init__()
        self.client = None  # gradio_client.Client, created by initialize()
        self.space_url = "Malaji71/Bagel-7B-Demo"
        self.api_endpoint = "/image_understanding"

    def initialize(self) -> bool:
        """Create the API client if it does not exist yet.

        Returns:
            True when the client is ready. False when ``Client(...)`` raised;
            the error is logged, not propagated.
        """
        # Also require a client so a stale is_initialized flag cannot leave
        # us calling .predict on None.
        if self.is_initialized and self.client is not None:
            return True
        try:
            logger.info("Initializing BAGEL API client...")
            self.client = Client(self.space_url)
        except Exception as e:
            logger.error(f"BAGEL API client initialization failed: {e}")
            return False
        self.is_initialized = True
        logger.info("BAGEL API client initialized successfully")
        return True

    def _save_temp_image(self, image: Image.Image) -> Optional[str]:
        """Write *image* as an RGB PNG to a new temporary file.

        Returns:
            The temporary file path, or None if the image could not be saved.
            On failure the partially created file is removed.
        """
        temp_path = None
        try:
            # delete=False keeps the file after the handle is closed. Closing
            # first lets PIL reopen the path (required on Windows).
            with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
                temp_path = temp_file.name
            if image.mode != "RGB":
                image = image.convert("RGB")  # rebinds locally; caller's image untouched
            image.save(temp_path, "PNG")
            return temp_path
        except Exception as e:
            logger.error(f"Failed to save temporary image: {e}")
            # Without this the delete=False file would be leaked.
            self._cleanup_temp_file(temp_path)
            return None

    def _cleanup_temp_file(self, file_path: Optional[str]) -> None:
        """Best-effort removal of *file_path*; failures are logged as warnings."""
        if not file_path:
            return
        try:
            if os.path.exists(file_path):
                os.unlink(file_path)
        except OSError as e:
            logger.warning(f"Failed to cleanup temp file: {e}")

    @staticmethod
    def _extract_description(result: Any) -> str:
        """Turn a raw ``Client.predict`` result into a non-empty description."""
        # NOTE(review): the endpoint is assumed to return
        # (image_result, text_response); prefer the text and fall back to the
        # first element only when the text is empty - confirm against the Space.
        if isinstance(result, (tuple, list)) and len(result) >= 2:
            description = result[1] if result[1] else result[0]
        else:
            description = str(result)
        if isinstance(description, str) and description.strip():
            return description.strip()
        # NOTE(review): an empty response is reported as success with this
        # placeholder, so callers will not treat it as a failure.
        return "Detailed image analysis completed successfully"

    def analyze_image(self, image: Image.Image, prompt: Optional[str] = None) -> Tuple[str, Dict[str, Any]]:
        """Describe *image* using the BAGEL ``/image_understanding`` endpoint.

        Args:
            image: PIL image to analyze; converted to RGB before upload if needed.
            prompt: Instruction sent along with the image. Defaults to
                :attr:`DEFAULT_PROMPT`.

        Returns:
            ``(description, metadata)``. On any failure the description is a
            short status message and ``metadata`` contains an ``"error"`` key.
        """
        if not self.initialize():
            return "BAGEL API not available", {"error": "API initialization failed"}

        if prompt is None:
            prompt = self.DEFAULT_PROMPT

        temp_path = None
        try:
            temp_path = self._save_temp_image(image)
            if not temp_path:
                return "Image processing failed", {"error": "Could not save image"}

            logger.info("Calling BAGEL API for image analysis...")
            result = self.client.predict(
                image=handle_file(temp_path),
                prompt=prompt,
                show_thinking=False,
                do_sample=False,
                text_temperature=0.3,
                max_new_tokens=512,
                api_name=self.api_endpoint,
            )
            description = self._extract_description(result)

            metadata = {
                "model": "BAGEL-7B-API",
                "device": "api",
                "confidence": 0.9,
                "api_endpoint": self.api_endpoint,
                "space_url": self.space_url,
                "prompt_used": prompt,
                "response_length": len(description),
            }
            logger.info(f"BAGEL API analysis complete: {len(description)} characters")
            return description, metadata
        except Exception as e:
            logger.error(f"BAGEL API analysis failed: {e}")
            return "API analysis failed", {"error": str(e), "model": "BAGEL-7B-API"}
        finally:
            # The temporary upload file is always removed, on success or failure.
            self._cleanup_temp_file(temp_path)

    def analyze_for_flux_prompt(self, image: Image.Image) -> Tuple[str, Dict[str, Any]]:
        """Describe *image* with :attr:`FLUX_PROMPT`, for FLUX prompt generation."""
        return self.analyze_image(image, self.FLUX_PROMPT)

    def cleanup(self) -> None:
        """Drop the API client and free memory; errors are logged, not raised."""
        try:
            self.client = None
            # Reset the flag so the next analyze_image() builds a new client
            # instead of calling .predict on None.
            self.is_initialized = False
            super().cleanup()
            logger.info("BAGEL API resources cleaned up")
        except Exception as e:
            logger.warning(f"BAGEL API cleanup warning: {e}")
class FallbackAnalyzer(BaseImageAnalyzer):
    """Heuristic analyzer used when the BAGEL API is unavailable.

    It never looks at pixel content. The description depends only on the
    image's aspect ratio, and the metadata also reports size and colour mode.
    """

    def initialize(self) -> bool:
        """Mark the analyzer ready; there is nothing to load."""
        self.is_initialized = True
        return True

    def analyze_image(self, image: Image.Image) -> Tuple[str, Dict[str, Any]]:
        """Build a generic description from the image's dimensions.

        Returns:
            ``(description, metadata)``. If anything raises (e.g. a zero
            height), a generic description and an ``"error"`` entry are
            returned instead.
        """
        try:
            width, height = image.size
            mode = image.mode
            ratio = width / height

            # Ratios in [0.75, 1.5] are all labelled "square".
            if ratio > 1.5:
                orientation, camera_suggestion = "landscape", "wide-angle lens, landscape photography"
            elif ratio < 0.75:
                orientation, camera_suggestion = "portrait", "portrait lens, shallow depth of field"
            else:
                orientation, camera_suggestion = "square", "standard lens, balanced composition"

            description = (
                f"A {orientation} format image with professional composition. "
                "The image shows clear detail and good visual balance, "
                "suitable for high-quality reproduction. "
                f"Recommended camera setup: {camera_suggestion}, "
                "professional lighting with careful attention to exposure "
                "and color balance."
            )
            metadata = dict(
                model="Fallback",
                device="cpu",
                confidence=0.6,
                image_size=f"{width}x{height}",
                color_mode=mode,
                orientation=orientation,
                aspect_ratio=round(ratio, 2),
            )
            return description, metadata
        except Exception as err:
            logger.error(f"Fallback analysis failed: {err}")
            return (
                "Professional image suitable for detailed analysis and prompt generation",
                {"error": str(err), "model": "Fallback"},
            )
class ModelManager:
    """Create, cache, and dispatch to image analyzers, with automatic fallback.

    Analyzers are instantiated on first request and cached by model name.
    The recognised names are ``"bagel-api"`` and ``"fallback"``; any other
    name resolves to the fallback analyzer.
    """

    def __init__(self, preferred_model: str = "bagel-api"):
        self.preferred_model = preferred_model  # used when no model_name is given
        self.analyzers: Dict[str, BaseImageAnalyzer] = {}  # cache keyed by model name
        # NOTE(review): not read elsewhere in the visible code; kept for
        # backward compatibility with external callers.
        self.current_analyzer = None

    def get_analyzer(self, model_name: Optional[str] = None) -> Optional[BaseImageAnalyzer]:
        """Return the cached analyzer for *model_name*, creating it on first use.

        Args:
            model_name: ``"bagel-api"`` or ``"fallback"``. Defaults to
                :attr:`preferred_model`. Unknown names log a warning and
                resolve to the (shared) fallback analyzer.
        """
        model_name = model_name or self.preferred_model
        cached = self.analyzers.get(model_name)
        if cached is not None:
            return cached

        if model_name == "bagel-api":
            analyzer: BaseImageAnalyzer = BagelAPIAnalyzer()
        elif model_name == "fallback":
            analyzer = FallbackAnalyzer()
        else:
            logger.warning(f"Unknown model: {model_name}, using fallback")
            # Reuse the cached fallback instead of replacing it with a new
            # instance on every call with an unknown name.
            return self.get_analyzer("fallback")

        self.analyzers[model_name] = analyzer
        return analyzer

    @staticmethod
    def _is_clean_result(result: Any) -> bool:
        """True if *result* is a ``(description, metadata)`` pair with no ``"error"``."""
        return (
            isinstance(result, (tuple, list))
            and len(result) == 2
            and isinstance(result[1], dict)
            and result[1].get("error") is None
        )

    def analyze_image(self, image: Image.Image, model_name: Optional[str] = None, analysis_type: str = "detailed") -> Tuple[str, Dict[str, Any]]:
        """Analyze *image* with the requested model, falling back on failure.

        Args:
            image: PIL image to analyze.
            model_name: Analyzer to use; defaults to :attr:`preferred_model`.
            analysis_type: ``"flux"`` uses ``analyze_for_flux_prompt`` when the
                analyzer provides it; anything else uses ``analyze_image``.

        Returns:
            ``(description, metadata)`` from the primary analyzer. If that call
            fails or reports an ``"error"``, the fallback analyzer's result is
            returned instead, or a failure message if that also fails.
        """
        analyzer = self.get_analyzer(model_name)
        if analyzer is None:
            return "No analyzer available", {"error": "Model not found"}

        if analysis_type == "flux" and hasattr(analyzer, "analyze_for_flux_prompt"):
            analyze = analyzer.analyze_for_flux_prompt
        else:
            analyze = analyzer.analyze_image

        success, result = safe_execute(analyze, image)
        if success and self._is_clean_result(result):
            return result

        logger.warning(f"Primary model failed, using fallback: {result}")
        fallback_analyzer = self.get_analyzer("fallback")
        fallback_success, fallback_result = safe_execute(fallback_analyzer.analyze_image, image)
        if fallback_success:
            return fallback_result
        return "All analyzers failed", {"error": "Complete analysis failure"}

    def cleanup_all(self) -> None:
        """Clean up every cached analyzer, empty the cache, and free memory.

        A failure in one analyzer's cleanup is logged and does not stop the
        remaining analyzers from being cleaned up or the cache from clearing.
        """
        for name, analyzer in self.analyzers.items():
            try:
                analyzer.cleanup()
            except Exception as e:
                logger.warning(f"Cleanup of analyzer {name!r} failed: {e}")
        self.analyzers.clear()
        self.current_analyzer = None
        clean_memory()
        logger.info("All analyzers cleaned up")
# Shared module-level manager; the analyze_image() convenience wrapper
# delegates to it. BAGEL API is the preferred analyzer.
model_manager = ModelManager("bagel-api")
def analyze_image(image: Image.Image, model_name: Optional[str] = None, analysis_type: str = "detailed") -> Tuple[str, Dict[str, Any]]:
    """
    Convenience function for image analysis using BAGEL API

    Delegates to the module-level ``model_manager``. If the chosen analyzer
    fails, that manager falls back to the simple analyzer.

    Args:
        image: PIL Image to analyze
        model_name: Optional model name ("bagel-api" or "fallback");
            defaults to the manager's preferred model
        analysis_type: Type of analysis ("detailed" or "flux")

    Returns:
        Tuple of (description, metadata)
    """
    return model_manager.analyze_image(image, model_name=model_name, analysis_type=analysis_type)
# Public API of this module: the names exported by `from <module> import *`.
__all__ = [
    "BaseImageAnalyzer", "BagelAPIAnalyzer", "FallbackAnalyzer",
    "ModelManager", "model_manager", "analyze_image",
]