Spaces:
Running
on
Zero
Running
on
Zero
""" | |
Model management for Frame 0 Laboratory for MIA | |
BAGEL 7B integration via API calls | |
""" | |
import spaces | |
import logging | |
import tempfile | |
import os | |
import re | |
from typing import Optional, Dict, Any, Tuple | |
from PIL import Image | |
from gradio_client import Client, handle_file | |
from config import get_device_config | |
from utils import clean_memory, safe_execute | |
# Module-level logger used by every analyzer and the manager below.
logger = logging.getLogger(__name__)
class BaseImageAnalyzer:
    """Interface shared by the image analyzers in this module.

    Subclasses override ``initialize`` and ``analyze_image``; this base only
    tracks readiness and captures the device configuration at construction.
    """

    def __init__(self):
        # Device settings from the project config; not read by the analyzers
        # defined in this module.
        self.device_config = get_device_config()
        # Subclasses flip this to True once ``initialize`` succeeds.
        self.is_initialized = False

    def initialize(self) -> bool:
        """Prepare the analyzer and report whether it is ready.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError

    def analyze_image(self, image: Image.Image) -> Tuple[str, Dict[str, Any]]:
        """Return a ``(description, metadata)`` pair for ``image``.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError

    def cleanup(self) -> None:
        """Release resources by delegating to the shared ``clean_memory`` helper."""
        clean_memory()
class BagelAPIAnalyzer(BaseImageAnalyzer):
    """BAGEL 7B image understanding via API calls to a hosted Gradio Space.

    The image is written to a temporary PNG and sent, together with a text
    prompt, to the Space's ``/image_understanding`` endpoint. The text answer
    becomes the description. A camera/lens recommendation is parsed out of
    that answer when possible and stored in the metadata.
    """

    # Prompt used by ``analyze_image`` when the caller does not pass one.
    DEFAULT_PROMPT = """Analyze this image for professional photography reproduction. Provide exactly two sections:
1. DESCRIPTION: Write a single flowing paragraph describing what you see. Start directly with the subject (e.g., "A color photograph showing..." or "A black and white image depicting..."). Include:
- Image type (photograph, illustration, artwork)
- Subject and composition
- Color palette and lighting conditions
- Mood and atmosphere
- Photographic style and format
2. CAMERA_SETUP: Based on the scene type you observe, recommend ONE specific professional camera and lens combination:
- For street/documentary scenes: Canon EOS R6 with 35mm f/1.4 lens
- For portrait photography: Canon EOS R5 with 85mm f/1.4 lens
- For landscape photography: Phase One XT with 24-70mm f/4 lens
- For action/sports: Sony A1 with 70-200mm f/2.8 lens
Give only the camera model and lens specification, nothing else."""

    # Prompt used by ``analyze_for_flux_prompt``.
    FLUX_PROMPT = """Analyze this image for professional FLUX generation. Provide exactly two sections:
1. DESCRIPTION: Create a single flowing paragraph starting directly with the subject. Be precise about:
- Image type (photograph, illustration, artwork)
- Subject matter and composition
- Color palette (specific colors, warm/cool tones, monochrome)
- Lighting conditions and photographic style
- Mood, atmosphere, and artistic elements
2. CAMERA_SETUP: Recommend ONE specific professional camera and lens for this scene type:
- Street/urban/documentary: Canon EOS R6 with 35mm f/1.4 lens
- Portrait photography: Canon EOS R5 with 85mm f/1.4 lens
- Landscape photography: Phase One XT with 24-70mm f/4 lens
- Action/sports: Sony A1 with 70-200mm f/2.8 lens
Give only the camera model and exact lens specification."""

    # Headers that introduce the camera recommendation, tried in this order.
    _CAMERA_SECTION_MARKERS = ("CAMERA_SETUP:", "2. CAMERA_SETUP")

    # A section line must be longer than this to be worth parsing.
    _MIN_SECTION_TEXT_LEN = 20

    # Lower-case brand names that mark a sentence as a possible recommendation.
    _CAMERA_BRANDS = ('canon', 'sony', 'leica', 'hasselblad', 'phase one', 'fujifilm')

    # Lower-case words that, together with a brand, suggest a recommendation.
    _RECOMMENDATION_TERMS = ('recommend', 'suggest', 'would use', 'camera', 'lens')

    # Boilerplate lead-ins stripped (in order) before looking for gear names.
    _LEADING_PHRASES = (
        re.compile(r'^(Based on.*?recommend|I would recommend|For this.*?recommend)\s*', re.IGNORECASE),
        re.compile(r'^(using a|use a|cameras? like)\s*', re.IGNORECASE),
    )

    # Camera body patterns, first match wins. The optional generation suffix
    # ("Sony A7 IV") is matched case-sensitively and only as a whole word, so
    # that under IGNORECASE ordinary words like "is"/"in" are not glued on.
    _CAMERA_PATTERNS = (
        re.compile(r'(Canon EOS [R\d]+[^\s,]*(?:(?-i:\s+[IVX]+)\b)?)', re.IGNORECASE),
        re.compile(r'(Sony A[^\s,]+(?:(?-i:\s+[IVX]+)\b)?)', re.IGNORECASE),
        re.compile(r'(Leica [^\s,]+)', re.IGNORECASE),
        re.compile(r'(Hasselblad [^\s,]+)', re.IGNORECASE),
        re.compile(r'(Phase One [^\s,]+)', re.IGNORECASE),
        re.compile(r'(Fujifilm [^\s,]+)', re.IGNORECASE),
    )

    # Lens patterns, first match wins. "\d+(?:-\d+)?mm" accepts zoom ranges
    # ("24-70mm") so they are not truncated to their long end, and
    # "\.(?=\d)" lets decimal apertures ("f/1.4") through without swallowing
    # a sentence-ending period.
    _LENS_PATTERNS = (
        re.compile(r'(\d+(?:-\d+)?mm\s*f/\d+(?:\.\d+)?(?:\s*lens)?)', re.IGNORECASE),
        re.compile(r'(with\s+(?:a\s+)?(\d+(?:-\d+)?mm(?:[^,.]|\.(?=\d))*))', re.IGNORECASE),
        re.compile(r'(paired with.*?(\d+(?:-\d+)?mm(?:[^,.]|\.(?=\d))*))', re.IGNORECASE),
    )

    # Connector words removed from the front of a lens match.
    _LENS_PREFIX = re.compile(r'^(with\s+(?:a\s+)?|paired with\s+)', re.IGNORECASE)

    # Stand-alone aperture such as "f/2.8".
    _APERTURE = re.compile(r'(f/\d+(?:\.\d+)?)')

    # Sentence boundary: terminal punctuation followed by whitespace or end of
    # text, so the "." inside "f/1.4" does not split a sentence.
    _SENTENCE_END = re.compile(r'[.!?](?=\s|$)')

    def __init__(self):
        super().__init__()
        # Gradio client; created lazily by ``initialize``.
        self.client = None
        self.space_url = "Malaji71/Bagel-7B-Demo"
        self.api_endpoint = "/image_understanding"
        # Optional Hugging Face token from the environment/secrets, used for
        # private Space access.
        self.hf_token = os.getenv("HF_TOKEN")

    def initialize(self) -> bool:
        """Create the Gradio client for ``space_url``.

        If ``HF_TOKEN`` is set, the token is tried first. If that attempt
        raises, one unauthenticated attempt is made.

        Returns:
            True once a client exists, False if every attempt failed.
        """
        if self.is_initialized:
            return True
        try:
            logger.info("Initializing BAGEL API client...")
            if self.hf_token:
                logger.info("Using HF token for private space access")
                self.client = Client(self.space_url, hf_token=self.hf_token)
            else:
                logger.info("No HF token found, accessing public space")
                self.client = Client(self.space_url)
            self.is_initialized = True
            logger.info("BAGEL API client initialized successfully")
            return True
        except Exception as e:
            logger.error(f"BAGEL API client initialization failed: {e}")

        if not self.hf_token:
            return False

        # The token may be invalid or the Space public: retry anonymously once.
        logger.info("Retrying without token...")
        try:
            self.client = Client(self.space_url)
        except Exception as e2:
            logger.error(f"Fallback initialization also failed: {e2}")
            return False
        self.is_initialized = True
        logger.info("BAGEL API client initialized successfully (fallback to public)")
        return True

    def _extract_camera_setup(self, description: str) -> Optional[str]:
        """Find a camera/lens recommendation in a BAGEL answer.

        First, the first line after each known section header is parsed. If
        no header yields a parse, the whole answer is scanned sentence by
        sentence.

        Returns:
            A "<camera>, <lens>" summary, or None if nothing was found.
        """
        try:
            for marker in self._CAMERA_SECTION_MARKERS:
                _, found, section = description.partition(marker)
                if not found:
                    continue
                first_line = section.strip().split('\n', 1)[0].strip()
                if len(first_line) > self._MIN_SECTION_TEXT_LEN:
                    parsed = self._parse_camera_recommendation(first_line)
                    # An unparseable section line must not stop the search.
                    if parsed:
                        return parsed
            return self._find_camera_recommendation(description)
        except Exception as e:
            logger.warning(f"Failed to extract camera setup: {e}")
            return None

    @staticmethod
    def _first_match(patterns, text: str) -> Optional[str]:
        """Return stripped group 1 of the first pattern in ``patterns`` that matches ``text``."""
        for pattern in patterns:
            match = pattern.search(text)
            if match:
                return match.group(1).strip()
        return None

    def _parse_camera_recommendation(self, camera_text: str) -> Optional[str]:
        """Reduce free text to a "<camera model>, <lens>" summary.

        Args:
            camera_text: A line or sentence that may name a camera body
                and/or a lens.

        Returns:
            The camera model and lens joined with ", " (either may be
            absent), or None if neither was found or parsing raised.
        """
        try:
            for pattern in self._LEADING_PHRASES:
                camera_text = pattern.sub('', camera_text)

            camera_model = self._first_match(self._CAMERA_PATTERNS, camera_text)
            if camera_model:
                # "[^\s,]+" can swallow sentence punctuation ("Leica M11.").
                camera_model = camera_model.rstrip('.;:')

            lens_info = self._first_match(self._LENS_PATTERNS, camera_text)
            if lens_info:
                lens_info = self._LENS_PREFIX.sub('', lens_info)

            # Attach a stray aperture if the lens match did not include one.
            if lens_info and 'f/' not in lens_info:
                aperture = self._APERTURE.search(camera_text)
                if aperture:
                    lens_info = f"{lens_info} {aperture.group(1)}"

            parts = [part for part in (camera_model, lens_info) if part]
            if not parts:
                return None
            result = ', '.join(parts)
            logger.info(f"Parsed camera recommendation: {result}")
            return result
        except Exception as e:
            logger.warning(f"Failed to parse camera recommendation: {e}")
            return None

    def _find_camera_recommendation(self, text: str) -> Optional[str]:
        """Return the first parseable sentence that names a brand and a recommendation term."""
        try:
            for sentence in self._SENTENCE_END.split(text):
                lowered = sentence.lower()
                if not any(brand in lowered for brand in self._CAMERA_BRANDS):
                    continue
                if not any(term in lowered for term in self._RECOMMENDATION_TERMS):
                    continue
                parsed = self._parse_camera_recommendation(sentence.strip())
                if parsed:
                    return parsed
            return None
        except Exception as e:
            logger.warning(f"Failed to find camera recommendation: {e}")
            return None

    def _save_temp_image(self, image: Image.Image) -> Optional[str]:
        """Write ``image`` (converted to RGB if needed) to a temporary PNG.

        Returns:
            The file path, or None on failure. The caller must delete the file.
            A partially created file is removed on failure.
        """
        temp_path = None
        try:
            # Close the handle before saving so the path can be reopened by
            # PIL on every platform; delete=False keeps the file on disk.
            with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as temp_file:
                temp_path = temp_file.name
            if image.mode != 'RGB':
                image = image.convert('RGB')
            image.save(temp_path, 'PNG')
            return temp_path
        except Exception as e:
            logger.error(f"Failed to save temporary image: {e}")
            self._cleanup_temp_file(temp_path)
            return None

    def _cleanup_temp_file(self, file_path: Optional[str]) -> None:
        """Delete ``file_path`` if it exists; failures are logged, not raised."""
        try:
            if file_path and os.path.exists(file_path):
                os.unlink(file_path)
        except Exception as e:
            logger.warning(f"Failed to cleanup temp file: {e}")

    @staticmethod
    def _extract_text_response(result: Any) -> Any:
        """Pick the text answer out of a ``Client.predict`` result.

        A tuple or list with at least two items is treated as
        ``(image_result, text_response)``. The second item is used, or the
        first item when the second is empty. Anything else is stringified.
        """
        if isinstance(result, (tuple, list)) and len(result) >= 2:
            return result[1] if result[1] else result[0]
        return str(result)

    def analyze_image(self, image: Image.Image, prompt: Optional[str] = None) -> Tuple[str, Dict[str, Any]]:
        """Describe ``image`` using the BAGEL Space.

        Args:
            image: PIL image to analyze.
            prompt: Instruction sent with the image. ``None`` selects
                ``DEFAULT_PROMPT``.

        Returns:
            ``(description, metadata)``. On failure the description is a short
            status message and ``metadata["error"]`` is set.
        """
        if not self.is_initialized and not self.initialize():
            return "BAGEL API not available", {"error": "API initialization failed", "model": "BAGEL-7B-API"}

        # Resolve the prompt before building metadata so "prompt_used"
        # records what was actually sent.
        if prompt is None:
            prompt = self.DEFAULT_PROMPT

        metadata: Dict[str, Any] = {
            "model": "BAGEL-7B-API",
            "device": "api",
            "confidence": 0.9,
            "api_endpoint": self.api_endpoint,
            "space_url": self.space_url,
            "prompt_used": prompt,
            "has_camera_suggestion": False,
        }
        temp_path = None
        try:
            temp_path = self._save_temp_image(image)
            if not temp_path:
                return "Image processing failed", {"error": "Could not save image", "model": "BAGEL-7B-API"}

            logger.info("Calling BAGEL API for image analysis...")
            result = self.client.predict(
                image=handle_file(temp_path),
                prompt=prompt,
                show_thinking=False,
                do_sample=False,
                text_temperature=0.2,
                max_new_tokens=512,
                api_name=self.api_endpoint,
            )
            description = self._extract_text_response(result)

            if isinstance(description, str) and description.strip():
                description = description.strip()
                camera_setup = self._extract_camera_setup(description)
                if camera_setup:
                    metadata["camera_setup"] = camera_setup
                    metadata["has_camera_suggestion"] = True
                    logger.info(f"Extracted camera setup: {camera_setup}")
                else:
                    logger.warning("No valid camera setup found in BAGEL response")
            else:
                # NOTE(review): an empty/non-text answer is still reported as a
                # success (no "error" key). Kept for compatibility.
                description = "Detailed image analysis completed successfully"

            metadata["response_length"] = len(description)
            logger.info(f"BAGEL API analysis complete: {len(description)} characters, Camera: {metadata['has_camera_suggestion']}")
            return description, metadata
        except Exception as e:
            logger.error(f"BAGEL API analysis failed: {e}")
            return "API analysis failed", {"error": str(e), "model": "BAGEL-7B-API"}
        finally:
            # Always remove the temporary upload file (no-op when None).
            self._cleanup_temp_file(temp_path)

    def analyze_for_flux_prompt(self, image: Image.Image) -> Tuple[str, Dict[str, Any]]:
        """Analyze ``image`` with ``FLUX_PROMPT`` for FLUX prompt generation."""
        return self.analyze_image(image, self.FLUX_PROMPT)

    def cleanup(self) -> None:
        """Drop the API client and mark the analyzer uninitialized.

        ``is_initialized`` is reset so that a later ``analyze_image`` call
        reconnects instead of calling ``predict`` on ``None``.
        """
        try:
            self.client = None
            self.is_initialized = False
            super().cleanup()
            logger.info("BAGEL API resources cleaned up")
        except Exception as e:
            logger.warning(f"BAGEL API cleanup warning: {e}")
class FallbackAnalyzer(BaseImageAnalyzer):
    """Model-free analyzer used when the BAGEL API is unavailable.

    The description is templated from the image's pixel dimensions only;
    the colour mode is reported in the metadata.
    """

    def initialize(self) -> bool:
        """Mark the analyzer ready; there is nothing to load."""
        self.is_initialized = True
        return True

    def analyze_image(self, image: Image.Image) -> Tuple[str, Dict[str, Any]]:
        """Describe ``image`` from its size and aspect ratio.

        Aspect ratio above 1.5 is labelled "landscape", below 0.75
        "portrait", and anything in between "square".

        Returns:
            ``(description, metadata)``. On any exception, a generic
            description and ``{"error": ..., "model": "Fallback"}`` are returned.
        """
        try:
            w, h = image.size
            color_mode = image.mode
            ratio = w / h

            if ratio > 1.5:
                shape, lens_hint = "landscape", "wide-angle lens, landscape photography"
            elif ratio < 0.75:
                shape, lens_hint = "portrait", "portrait lens, shallow depth of field"
            else:
                shape, lens_hint = "square", "standard lens, balanced composition"

            text = (
                f"A {shape} format image with professional composition. "
                "The image shows clear detail and good visual balance, "
                "suitable for high-quality reproduction. "
                f"Recommended camera setup: {lens_hint}, professional lighting "
                "with careful attention to exposure and color balance."
            )
            info = dict(
                model="Fallback",
                device="cpu",
                confidence=0.6,
                image_size=f"{w}x{h}",
                color_mode=color_mode,
                orientation=shape,
                aspect_ratio=round(ratio, 2),
            )
            return text, info
        except Exception as err:
            logger.error(f"Fallback analysis failed: {err}")
            return "Professional image suitable for detailed analysis and prompt generation", {"error": str(err), "model": "Fallback"}
class ModelManager:
    """Create image analyzers on demand and route analysis requests to them.

    Analyzers are cached by name in ``self.analyzers``. If the selected
    analyzer fails, or reports an ``"error"`` in its metadata, the request
    is retried with the ``"fallback"`` analyzer.
    """

    def __init__(self, preferred_model: str = "bagel-api"):
        # Name used when a caller does not choose a model explicitly.
        self.preferred_model = preferred_model
        # Analyzer instances keyed by model name.
        self.analyzers = {}
        # Set here but not read by the code in this module.
        self.current_analyzer = None

    def get_analyzer(self, model_name: str = None) -> Optional[BaseImageAnalyzer]:
        """Return the cached analyzer for ``model_name``, creating it if needed.

        A falsy ``model_name`` selects ``preferred_model``. An unknown name
        logs a warning and gets a fresh ``FallbackAnalyzer``, which is stored
        under ``"fallback"`` (replacing any cached one).
        """
        key = model_name or self.preferred_model
        if key in self.analyzers:
            return self.analyzers[key]

        if key == "bagel-api":
            analyzer = BagelAPIAnalyzer()
        elif key == "fallback":
            analyzer = FallbackAnalyzer()
        else:
            logger.warning(f"Unknown model: {key}, using fallback")
            key = "fallback"
            analyzer = FallbackAnalyzer()

        self.analyzers[key] = analyzer
        return analyzer

    def analyze_image(self, image: Image.Image, model_name: str = None, analysis_type: str = "detailed") -> Tuple[str, Dict[str, Any]]:
        """Analyze ``image`` and fall back to the simple analyzer on failure.

        Args:
            image: PIL image to analyze.
            model_name: Analyzer name; ``None`` selects ``preferred_model``.
            analysis_type: ``"flux"`` uses ``analyze_for_flux_prompt`` when
                the analyzer has one; any other value uses ``analyze_image``.

        Returns:
            ``(description, metadata)`` from the primary analyzer, else from
            the fallback analyzer, else a fixed failure pair.
        """
        analyzer = self.get_analyzer(model_name)
        if analyzer is None:
            return "No analyzer available", {"error": "Model not found"}

        wants_flux = analysis_type == "flux" and hasattr(analyzer, 'analyze_for_flux_prompt')
        run = analyzer.analyze_for_flux_prompt if wants_flux else analyzer.analyze_image
        ok, outcome = safe_execute(run, image)
        if ok and outcome[1].get("error") is None:
            return outcome

        # The primary analyzer raised or flagged an error: use the fallback.
        logger.warning(f"Primary model failed, using fallback: {outcome}")
        backup = self.get_analyzer("fallback")
        backup_ok, backup_outcome = safe_execute(backup.analyze_image, image)
        if not backup_ok:
            return "All analyzers failed", {"error": "Complete analysis failure"}
        return backup_outcome

    def cleanup_all(self) -> None:
        """Run ``cleanup`` on every cached analyzer, empty the cache, then free memory."""
        for cached in self.analyzers.values():
            cached.cleanup()
        self.analyzers.clear()
        clean_memory()
        logger.info("All analyzers cleaned up")
# Global model manager instance.
# Used by the module-level ``analyze_image`` helper; analyzers it creates are
# cached on this object for reuse across calls.
model_manager = ModelManager(preferred_model="bagel-api")
def analyze_image(image: Image.Image, model_name: Optional[str] = None, analysis_type: str = "detailed") -> Tuple[str, Dict[str, Any]]:
    """Analyze ``image`` through the module-level ``model_manager``.

    Args:
        image: PIL image to analyze.
        model_name: ``"bagel-api"`` or ``"fallback"``. ``None`` uses the
            manager's preferred model; unknown names are served by the
            fallback analyzer.
        analysis_type: ``"detailed"`` (default) or ``"flux"`` for the
            FLUX-oriented prompt. Any other value behaves like ``"detailed"``.

    Returns:
        Tuple of ``(description, metadata)``.
    """
    return model_manager.analyze_image(image, model_name=model_name, analysis_type=analysis_type)
# Export main components: the names bound by ``from <this module> import *``.
__all__ = [
    "BaseImageAnalyzer",
    "BagelAPIAnalyzer",
    "FallbackAnalyzer",
    "ModelManager",
    "model_manager",
    "analyze_image"
]