Spaces:
Running
on
Zero
Running
on
Zero
""" | |
Model management for Frame 0 Laboratory for MIA | |
BAGEL 7B integration via API calls | |
""" | |
# `spaces` stays first as in the original; ZeroGPU Spaces generally expect it
# to be imported before other heavy libraries.
import spaces

import logging
import os
import re
import tempfile
from typing import Optional, Dict, Any, Tuple

from PIL import Image
from gradio_client import Client, handle_file

from config import get_device_config
from utils import clean_memory, safe_execute
logger = logging.getLogger(__name__)


class BaseImageAnalyzer:
    """Common interface shared by every image analyzer in this module.

    Concrete analyzers override ``initialize`` and ``analyze_image``. The
    default ``cleanup`` simply delegates to ``clean_memory``.
    """

    def __init__(self):
        # Device settings are captured at construction time from config.
        self.device_config = get_device_config()
        self.is_initialized = False

    def initialize(self) -> bool:
        """Prepare the analyzer for use and report success. Must be overridden."""
        raise NotImplementedError

    def analyze_image(self, image: Image.Image) -> Tuple[str, Dict[str, Any]]:
        """Return ``(description, metadata)`` for *image*. Must be overridden."""
        raise NotImplementedError

    def cleanup(self) -> None:
        """Release resources; the base implementation only calls ``clean_memory()``."""
        clean_memory()
class BagelAPIAnalyzer(BaseImageAnalyzer):
    """BAGEL 7B image understanding via API calls to a hosted Gradio Space.

    No weights are loaded locally. ``initialize`` creates a
    ``gradio_client.Client`` for ``space_url``. ``analyze_image`` uploads the
    image (written to a temporary PNG) to ``api_endpoint`` and post-processes
    the returned text.
    """

    # Default prompt. It requests a DESCRIPTION section and a CAMERA_SETUP
    # section so the camera recommendation can be pulled out separately.
    _DEFAULT_PROMPT = (
        "Analyze this image and provide a detailed description in two parts:\n"
        '1. DESCRIPTION: Write a flowing paragraph describing the visual elements, composition, lighting, colors, artistic style, mood, and atmosphere. Start directly with the subject (e.g., "A black and white illustration..." not "The image shows..."). Write as a continuous narrative without numbered lists.\n'
        "2. CAMERA_SETUP: Based on the photographic characteristics you observe, recommend the specific camera system, lens, and settings that would best capture this type of image. Consider focal length, aperture, lighting setup, and shooting style that matches what you see.\n"
        "Format your response clearly with these two sections."
    )

    # Prompt variant used by analyze_for_flux_prompt.
    _FLUX_PROMPT = (
        "Analyze this image for FLUX generation and provide two sections:\n"
        '1. DESCRIPTION: Create a detailed, flowing description suitable for FLUX generation. Write as a single coherent paragraph starting directly with the subject (avoid "The image shows..."). Describe photographic style, composition, lighting, colors, mood, and artistic elements.\n'
        "2. CAMERA_SETUP: Recommend the specific professional camera system, lens, aperture, and technical settings that would recreate this exact image. Be specific about equipment brands, focal lengths, and shooting parameters based on the visual characteristics you observe.\n"
        "Provide both sections clearly formatted."
    )

    _CAMERA_SECTION_MARKER = "CAMERA_SETUP:"

    # Inline markers tried, in order, when no CAMERA_SETUP section is present.
    _CAMERA_PATTERNS = (
        "Shot on ",
        "Camera: ",
        "Equipment: ",
        "Recommended camera:",
        "Camera setup:",
    )

    # A sentence ends at a period followed by whitespace or end of text. A
    # bare "." would cut decimal values such as "f/1.4" or "2.8" in half.
    _SENTENCE_END = re.compile(r"\.(?:\s|$)")

    def __init__(self):
        super().__init__()
        self.client = None
        self.space_url = "Malaji71/Bagel-7B-Demo"
        self.api_endpoint = "/image_understanding"

    def initialize(self) -> bool:
        """Create the Gradio client for the BAGEL Space.

        Returns:
            True if the client is ready (or already was). False if creating
            it raised; the error is logged.
        """
        if self.is_initialized:
            return True
        try:
            logger.info("Initializing BAGEL API client...")
            self.client = Client(self.space_url)
        except Exception as e:
            logger.error(f"BAGEL API client initialization failed: {e}")
            return False
        self.is_initialized = True
        logger.info("BAGEL API client initialized successfully")
        return True

    def _extract_camera_setup(self, description: str) -> Optional[str]:
        """Extract a camera setup recommendation from a BAGEL response.

        First looks for an explicit ``CAMERA_SETUP:`` section. It returns the
        text between the first marker and any following marker, with all
        whitespace collapsed and surrounding markdown ``*`` removed. If no
        section exists, or the section is empty, it scans the inline markers
        in ``_CAMERA_PATTERNS``. For the first marker whose sentence carries
        more than 10 characters beyond the marker, it returns that sentence.

        Returns:
            The recommendation text, or None if nothing usable was found.
        """
        if not isinstance(description, str) or not description:
            return None

        marker = self._CAMERA_SECTION_MARKER
        if marker in description:
            section = description.split(marker)[1]
            section = " ".join(section.split()).strip(" *")
            if section:
                return section

        for pattern in self._CAMERA_PATTERNS:
            idx = description.find(pattern)
            if idx == -1:
                continue
            sentence = self._SENTENCE_END.split(description[idx:], maxsplit=1)[0]
            sentence = " ".join(sentence.split())
            if len(sentence) > len(pattern) + 10:  # require meaningful content
                return sentence
        return None

    def _save_temp_image(self, image: Image.Image) -> Optional[str]:
        """Write *image* as an RGB PNG to a temporary file for upload.

        Returns:
            The file path, or None if saving failed. On failure the partially
            created file is removed and the error is logged.
        """
        temp_path = None
        try:
            # delete=False: the file must outlive this handle so the client
            # can read it; analyze_image removes it afterwards.
            with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
                temp_path = temp_file.name
            if image.mode != "RGB":
                image = image.convert("RGB")
            image.save(temp_path, "PNG")
            return temp_path
        except Exception as e:
            logger.error(f"Failed to save temporary image: {e}")
            self._cleanup_temp_file(temp_path)
            return None

    def _cleanup_temp_file(self, file_path: Optional[str]) -> None:
        """Delete *file_path* if given. A missing file is ignored; other OS errors are logged."""
        if not file_path:
            return
        try:
            os.unlink(file_path)
        except FileNotFoundError:
            pass
        except OSError as e:
            logger.warning(f"Failed to cleanup temp file: {e}")

    @staticmethod
    def _extract_response_text(result: Any) -> Any:
        """Select the text portion of a ``Client.predict`` result.

        The endpoint is assumed to return ``(image_result, text_response)``.
        The text element is preferred, and the first element is used only
        when the text is empty. Any other shape is converted with ``str``.
        """
        if isinstance(result, (tuple, list)) and len(result) >= 2:
            return result[1] if result[1] else result[0]
        return str(result)

    def analyze_image(self, image: Image.Image, prompt: Optional[str] = None) -> Tuple[str, Dict[str, Any]]:
        """Analyze *image* with the BAGEL API.

        Args:
            image: PIL image to analyze.
            prompt: Instruction sent to the model. Defaults to
                ``_DEFAULT_PROMPT``, which requests DESCRIPTION and
                CAMERA_SETUP sections.

        Returns:
            ``(description, metadata)``. On failure the metadata contains an
            ``"error"`` key. On success it may contain ``"camera_setup"`` and
            always contains ``"has_camera_suggestion"``.
        """
        if not self.is_initialized and not self.initialize():
            return "BAGEL API not available", {"error": "API initialization failed"}
        if prompt is None:
            prompt = self._DEFAULT_PROMPT

        temp_path = None
        try:
            temp_path = self._save_temp_image(image)
            if not temp_path:
                return "Image processing failed", {"error": "Could not save image"}

            logger.info("Calling BAGEL API for image analysis...")
            result = self.client.predict(
                image=handle_file(temp_path),
                prompt=prompt,
                show_thinking=False,
                do_sample=False,
                text_temperature=0.3,
                max_new_tokens=512,
                api_name=self.api_endpoint,
            )
            description = self._extract_response_text(result)

            # Build the base metadata first so the camera fields below extend
            # it instead of being referenced before assignment or overwritten.
            metadata: Dict[str, Any] = {
                "model": "BAGEL-7B-API",
                "device": "api",
                "confidence": 0.9,
                "api_endpoint": self.api_endpoint,
                "space_url": self.space_url,
                "prompt_used": prompt,
            }

            if isinstance(description, str) and description.strip():
                description = description.strip()
                camera_setup = self._extract_camera_setup(description)
                if camera_setup:
                    metadata["camera_setup"] = camera_setup
                metadata["has_camera_suggestion"] = bool(camera_setup)
            else:
                logger.warning("BAGEL API returned an empty or non-text response")
                description = "Detailed image analysis completed successfully"
                metadata["has_camera_suggestion"] = False

            metadata["response_length"] = len(description)
            logger.info(f"BAGEL API analysis complete: {len(description)} characters")
            return description, metadata
        except Exception as e:
            logger.error(f"BAGEL API analysis failed: {e}")
            return "API analysis failed", {"error": str(e), "model": "BAGEL-7B-API"}
        finally:
            # Always remove the uploaded temp file (no-op when it is None).
            self._cleanup_temp_file(temp_path)

    def analyze_for_flux_prompt(self, image: Image.Image) -> Tuple[str, Dict[str, Any]]:
        """Analyze *image* with the FLUX-oriented prompt (``_FLUX_PROMPT``)."""
        return self.analyze_image(image, self._FLUX_PROMPT)

    def cleanup(self) -> None:
        """Drop the API client and release memory.

        ``is_initialized`` is reset so a later ``analyze_image`` call creates a
        fresh client instead of calling ``predict`` on ``None``.
        """
        self.client = None
        self.is_initialized = False
        try:
            super().cleanup()
            logger.info("BAGEL API resources cleaned up")
        except Exception as e:
            logger.warning(f"BAGEL API cleanup warning: {e}")
class FallbackAnalyzer(BaseImageAnalyzer):
    """Dependency-free analyzer used when the BAGEL API is unavailable.

    The description is generic text chosen only from the image's aspect
    ratio. No model or network access is involved.
    """

    def initialize(self) -> bool:
        """Mark the analyzer ready; there is nothing to load."""
        self.is_initialized = True
        return True

    def analyze_image(self, image: Image.Image) -> Tuple[str, Dict[str, Any]]:
        """Describe *image* from its dimensions alone.

        Ratios above 1.5 are labelled "landscape", ratios below 0.75 are
        labelled "portrait", and everything in between is labelled "square".
        Any exception (for example a zero height) yields a canned description
        with an ``"error"`` entry in the metadata.
        """
        try:
            w, h = image.size
            color_mode = image.mode
            ratio = w / h

            if ratio > 1.5:
                orientation, lens_hint = "landscape", "wide-angle lens, landscape photography"
            elif ratio < 0.75:
                orientation, lens_hint = "portrait", "portrait lens, shallow depth of field"
            else:
                orientation, lens_hint = "square", "standard lens, balanced composition"

            text = (
                f"A {orientation} format image with professional composition. "
                "The image shows clear detail and good visual balance, suitable for high-quality reproduction. "
                f"Recommended camera setup: {lens_hint}, professional lighting with careful attention to exposure and color balance."
            )
            info = dict(
                model="Fallback",
                device="cpu",
                confidence=0.6,
                image_size=f"{w}x{h}",
                color_mode=color_mode,
                orientation=orientation,
                aspect_ratio=round(ratio, 2),
            )
            return text, info
        except Exception as err:
            logger.error(f"Fallback analysis failed: {err}")
            return "Professional image suitable for detailed analysis and prompt generation", {"error": str(err), "model": "Fallback"}
class ModelManager:
    """Create, cache and dispatch to image analyzers, falling back on failure.

    Analyzers are created lazily and cached by name. A failed or erroring
    primary analysis is retried with ``FallbackAnalyzer``.
    """

    # Names get_analyzer can construct. Anything else maps to "fallback".
    _KNOWN_MODELS = ("bagel-api", "fallback")

    def __init__(self, preferred_model: str = "bagel-api"):
        self.preferred_model = preferred_model
        self.analyzers = {}
        self.current_analyzer = None

    def get_analyzer(self, model_name: str = None) -> Optional[BaseImageAnalyzer]:
        """Return the cached analyzer for *model_name*, creating it if needed.

        ``None`` selects ``preferred_model``. An unknown name logs a warning
        and returns the shared fallback analyzer. An existing fallback
        instance is reused rather than replaced.
        """
        model_name = model_name or self.preferred_model
        if model_name in self.analyzers:
            return self.analyzers[model_name]

        if model_name not in self._KNOWN_MODELS:
            logger.warning(f"Unknown model: {model_name}, using fallback")
            model_name = "fallback"
            if model_name in self.analyzers:
                return self.analyzers[model_name]

        analyzer = BagelAPIAnalyzer() if model_name == "bagel-api" else FallbackAnalyzer()
        self.analyzers[model_name] = analyzer
        return analyzer

    @staticmethod
    def _is_successful_result(result: Any) -> bool:
        """True if *result* is a ``(description, metadata)`` pair without an ``"error"`` key."""
        return (
            isinstance(result, (tuple, list))
            and len(result) == 2
            and isinstance(result[1], dict)
            and result[1].get("error") is None
        )

    def analyze_image(self, image: Image.Image, model_name: str = None, analysis_type: str = "detailed") -> Tuple[str, Dict[str, Any]]:
        """Analyze *image* with the requested (or preferred) model.

        Args:
            image: PIL image to analyze.
            model_name: Analyzer name. ``None`` uses ``preferred_model``.
            analysis_type: ``"flux"`` uses ``analyze_for_flux_prompt`` when
                the analyzer provides it. Any other value uses ``analyze_image``.

        Returns:
            ``(description, metadata)`` from the primary analyzer. If that
            analyzer fails, returns the fallback analyzer's result. If both
            fail, returns an "All analyzers failed" error pair.
        """
        analyzer = self.get_analyzer(model_name)
        if analyzer is None:
            return "No analyzer available", {"error": "Model not found"}

        if analysis_type == "flux" and hasattr(analyzer, "analyze_for_flux_prompt"):
            method = analyzer.analyze_for_flux_prompt
        else:
            method = analyzer.analyze_image
        success, result = safe_execute(method, image)

        if success and self._is_successful_result(result):
            return result

        logger.warning(f"Primary model failed, using fallback: {result}")
        fallback_analyzer = self.get_analyzer("fallback")
        fallback_success, fallback_result = safe_execute(fallback_analyzer.analyze_image, image)
        if fallback_success:
            return fallback_result
        return "All analyzers failed", {"error": "Complete analysis failure"}

    def cleanup_all(self) -> None:
        """Clean up every cached analyzer, then release memory.

        A failure in one analyzer's cleanup is logged and does not stop the
        remaining analyzers from being cleaned up.
        """
        for name, analyzer in self.analyzers.items():
            try:
                analyzer.cleanup()
            except Exception as e:
                logger.warning(f"Cleanup failed for analyzer {name}: {e}")
        self.analyzers.clear()
        self.current_analyzer = None
        clean_memory()
        logger.info("All analyzers cleaned up")
# Process-wide manager shared by the convenience function below.
model_manager = ModelManager(preferred_model="bagel-api")


def analyze_image(image: Image.Image, model_name: str = None, analysis_type: str = "detailed") -> Tuple[str, Dict[str, Any]]:
    """Analyze *image* through the module-level ``model_manager``.

    Args:
        image: PIL image to describe.
        model_name: ``"bagel-api"`` or ``"fallback"``. ``None`` uses the
            manager's preferred model.
        analysis_type: ``"detailed"`` (default) or ``"flux"`` for
            FLUX-oriented prompting.

    Returns:
        ``(description, metadata)`` as produced by
        ``ModelManager.analyze_image``.
    """
    return model_manager.analyze_image(
        image=image,
        model_name=model_name,
        analysis_type=analysis_type,
    )


# Public API of this module.
__all__ = [
    "BaseImageAnalyzer",
    "BagelAPIAnalyzer",
    "FallbackAnalyzer",
    "ModelManager",
    "model_manager",
    "analyze_image",
]