Phramer_AI / models.py
Malaji71's picture
Update models.py
24c3479 verified
raw
history blame
11.9 kB
"""
Model management for Frame 0 Laboratory for MIA
BAGEL 7B integration via API calls
"""
import logging
import tempfile
import os
from typing import Optional, Dict, Any, Tuple
from PIL import Image
from gradio_client import Client, handle_file
from config import get_device_config
from utils import clean_memory, safe_execute
logger = logging.getLogger(__name__)
class BaseImageAnalyzer:
    """Common interface shared by every image analysis backend.

    Subclasses are expected to override ``initialize`` and
    ``analyze_image``; the versions defined here only raise
    ``NotImplementedError``.
    """

    def __init__(self):
        """Start in the uninitialized state and record the device config."""
        # Flipped to True by a subclass once its backend is ready for use.
        self.is_initialized = False
        # Whatever ``get_device_config()`` reports for this environment.
        self.device_config = get_device_config()

    def initialize(self) -> bool:
        """Prepare the backend for use.

        Returns:
            True when the backend is ready (as implemented by subclasses).

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()

    def analyze_image(self, image: Image.Image) -> Tuple[str, Dict[str, Any]]:
        """Describe ``image``.

        Args:
            image: PIL image to analyze.

        Returns:
            A ``(description, metadata)`` pair (as implemented by subclasses).

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()

    def cleanup(self) -> None:
        """Release resources by delegating to the shared ``clean_memory`` helper."""
        clean_memory()
class BagelAPIAnalyzer(BaseImageAnalyzer):
    """Image analyzer that delegates to a BAGEL 7B Gradio Space over its API.

    The image is written to a temporary PNG file, uploaded through
    ``gradio_client`` and the textual answer of the Space is returned
    together with a metadata dictionary.
    """

    # Prompt used by ``analyze_image`` when the caller does not supply one.
    DEFAULT_PROMPT = (
        "Provide a detailed description of this image, including objects, people, "
        "setting, composition, lighting, colors, mood, and artistic style. "
        "Focus on elements that would be useful for generating a similar image."
    )

    # Prompt used by ``analyze_for_flux_prompt``.
    FLUX_PROMPT = (
        "Analyze this image and generate a detailed FLUX prompt description. Focus on:\n"
        "- Photographic and artistic style\n"
        "- Composition and framing\n"
        "- Lighting conditions and mood\n"
        "- Colors and visual elements\n"
        "- Camera settings that would recreate this image\n"
        "- Technical photography details\n"
        "Provide a comprehensive description suitable for FLUX image generation."
    )

    def __init__(self):
        """Record the target Space and endpoint; the client is created lazily."""
        super().__init__()
        self.client = None
        self.space_url = "Malaji71/Bagel-7B-Demo"
        self.api_endpoint = "/image_understanding"

    def initialize(self) -> bool:
        """Create the Gradio client for the configured Space.

        Returns:
            True if the client is (or already was) ready, False if creating
            it raised an exception (the exception is logged, not re-raised).
        """
        if self.is_initialized:
            return True
        try:
            logger.info("Initializing BAGEL API client...")
            self.client = Client(self.space_url)
            self.is_initialized = True
            logger.info("BAGEL API client initialized successfully")
            return True
        except Exception as e:
            logger.error(f"BAGEL API client initialization failed: {e}")
            return False

    def _save_temp_image(self, image: Image.Image) -> Optional[str]:
        """Write ``image`` as an RGB PNG to a new temporary file.

        Args:
            image: PIL image to persist; converted to RGB when necessary.

        Returns:
            Path of the temporary file, or None if anything failed. On
            failure the partially created file is removed so it does not
            leak; on success the caller owns the file and must delete it.
        """
        temp_path = None
        try:
            # delete=False: the file must survive closing the handle so it
            # can be reopened by path for the upload.
            with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as temp_file:
                temp_path = temp_file.name
            if image.mode != 'RGB':
                image = image.convert('RGB')
            image.save(temp_path, 'PNG')
            return temp_path
        except Exception as e:
            # The file may already exist on disk even though saving failed;
            # remove it first, since the path is not handed back to the caller.
            if temp_path is not None:
                self._cleanup_temp_file(temp_path)
            logger.error(f"Failed to save temporary image: {e}")
            return None

    def _cleanup_temp_file(self, file_path: Optional[str]) -> None:
        """Delete ``file_path`` if it exists; failures are logged, not raised."""
        try:
            if file_path and os.path.exists(file_path):
                os.unlink(file_path)
        except Exception as e:
            logger.warning(f"Failed to cleanup temp file: {e}")

    @staticmethod
    def _extract_description(result: Any) -> str:
        """Pull the text answer out of the raw API result.

        A tuple with at least two items is treated as
        ``(image_result, text_response)``: the second item is preferred and
        the first is used when the second is falsy. Anything else is
        stringified. A non-string or blank description is replaced by a
        generic placeholder sentence.
        """
        if isinstance(result, tuple) and len(result) >= 2:
            description = result[1] if result[1] else result[0]
        else:
            description = str(result)
        if isinstance(description, str) and description.strip():
            return description.strip()
        return "Detailed image analysis completed successfully"

    def analyze_image(self, image: Image.Image, prompt: Optional[str] = None) -> Tuple[str, Dict[str, Any]]:
        """Analyze ``image`` through the BAGEL API.

        Args:
            image: PIL image to describe.
            prompt: Instruction sent with the image; ``DEFAULT_PROMPT`` is
                used when None.

        Returns:
            ``(description, metadata)``. On failure the description is a
            short status message and ``metadata`` contains an ``"error"``
            key; no exception escapes this method.
        """
        if not self.is_initialized and not self.initialize():
            return "BAGEL API not available", {"error": "API initialization failed"}

        if prompt is None:
            prompt = self.DEFAULT_PROMPT

        temp_path = None
        try:
            temp_path = self._save_temp_image(image)
            if not temp_path:
                return "Image processing failed", {"error": "Could not save image"}

            logger.info("Calling BAGEL API for image analysis...")
            result = self.client.predict(
                image=handle_file(temp_path),
                prompt=prompt,
                show_thinking=False,
                do_sample=False,
                text_temperature=0.3,
                max_new_tokens=512,
                api_name=self.api_endpoint
            )

            description = self._extract_description(result)
            metadata = {
                "model": "BAGEL-7B-API",
                "device": "api",
                "confidence": 0.9,
                "api_endpoint": self.api_endpoint,
                "space_url": self.space_url,
                "prompt_used": prompt,
                "response_length": len(description)
            }
            logger.info(f"BAGEL API analysis complete: {len(description)} characters")
            return description, metadata
        except Exception as e:
            logger.error(f"BAGEL API analysis failed: {e}")
            return "API analysis failed", {"error": str(e), "model": "BAGEL-7B-API"}
        finally:
            # Always remove the uploaded temporary file, success or not.
            if temp_path:
                self._cleanup_temp_file(temp_path)

    def analyze_for_flux_prompt(self, image: Image.Image) -> Tuple[str, Dict[str, Any]]:
        """Analyze ``image`` with the FLUX-oriented prompt (``FLUX_PROMPT``).

        Returns:
            Same ``(description, metadata)`` contract as ``analyze_image``.
        """
        return self.analyze_image(image, self.FLUX_PROMPT)

    def cleanup(self) -> None:
        """Drop the API client and release shared resources.

        The analyzer is marked uninitialized so that a later
        ``analyze_image`` call recreates the client instead of calling
        ``predict`` on None.
        """
        self.client = None
        self.is_initialized = False
        try:
            super().cleanup()
            logger.info("BAGEL API resources cleaned up")
        except Exception as e:
            logger.warning(f"BAGEL API cleanup warning: {e}")
class FallbackAnalyzer(BaseImageAnalyzer):
    """Offline stand-in used when the BAGEL API cannot provide a result.

    It never inspects pixel content: the description is a template filled
    in from the image's width/height ratio.
    """

    def __init__(self):
        """Set up the shared base-class state; nothing else is required."""
        super().__init__()

    def initialize(self) -> bool:
        """Mark the analyzer as ready; there is nothing to load.

        Returns:
            Always True.
        """
        self.is_initialized = True
        return True

    def analyze_image(self, image: Image.Image) -> Tuple[str, Dict[str, Any]]:
        """Build a generic description from the image dimensions.

        Args:
            image: PIL image whose ``size`` and ``mode`` are read.

        Returns:
            ``(description, metadata)``. If reading the image properties
            fails, a generic sentence is returned and ``metadata`` carries
            an ``"error"`` key instead of the size details.
        """
        try:
            width, height = image.size
            color_mode = image.mode

            # Width/height ratio decides the wording: wider than 1.5 counts
            # as landscape, narrower than 0.75 as portrait, the rest as square.
            ratio = width / height
            if ratio > 1.5:
                orientation, lens_hint = "landscape", "wide-angle lens, landscape photography"
            elif ratio < 0.75:
                orientation, lens_hint = "portrait", "portrait lens, shallow depth of field"
            else:
                orientation, lens_hint = "square", "standard lens, balanced composition"

            description = (
                f"A {orientation} format image with professional composition. "
                "The image shows clear detail and good visual balance, suitable for high-quality reproduction. "
                f"Recommended camera setup: {lens_hint}, professional lighting with careful attention to exposure and color balance."
            )
            metadata = {
                "model": "Fallback",
                "device": "cpu",
                "confidence": 0.6,
                "image_size": f"{width}x{height}",
                "color_mode": color_mode,
                "orientation": orientation,
                "aspect_ratio": round(ratio, 2)
            }
            return description, metadata
        except Exception as e:
            logger.error(f"Fallback analysis failed: {e}")
            return "Professional image suitable for detailed analysis and prompt generation", {"error": str(e), "model": "Fallback"}
class ModelManager:
    """Creates, caches and dispatches to image analyzers by model name."""

    def __init__(self, preferred_model: str = "bagel-api"):
        """Remember the default model name and start with an empty cache.

        Args:
            preferred_model: Name used whenever a call gives no model name.
        """
        self.preferred_model = preferred_model
        # Analyzer instances keyed by model name, created on first request.
        self.analyzers = {}
        self.current_analyzer = None

    def get_analyzer(self, model_name: str = None) -> Optional[BaseImageAnalyzer]:
        """Return the analyzer for ``model_name``, creating it on first use.

        Args:
            model_name: ``"bagel-api"`` or ``"fallback"``; a falsy value
                selects ``preferred_model``. Any other name is logged and
                served by a freshly created fallback analyzer.

        Returns:
            The analyzer instance stored for the resolved name.
        """
        key = model_name or self.preferred_model
        if key in self.analyzers:
            return self.analyzers[key]

        if key == "bagel-api":
            factory = BagelAPIAnalyzer
        elif key == "fallback":
            factory = FallbackAnalyzer
        else:
            logger.warning(f"Unknown model: {key}, using fallback")
            key = "fallback"
            factory = FallbackAnalyzer

        self.analyzers[key] = factory()
        return self.analyzers[key]

    def analyze_image(self, image: Image.Image, model_name: str = None, analysis_type: str = "detailed") -> Tuple[str, Dict[str, Any]]:
        """Analyze ``image``, retrying with the fallback analyzer on failure.

        Args:
            image: PIL image to analyze.
            model_name: Analyzer to use; defaults to ``preferred_model``.
            analysis_type: ``"flux"`` selects the analyzer's
                ``analyze_for_flux_prompt`` when it has one; any other
                value uses ``analyze_image``.

        Returns:
            ``(description, metadata)`` from the chosen analyzer when it
            succeeded without an ``"error"`` entry, otherwise the fallback
            analyzer's result, otherwise a fixed failure pair.
        """
        analyzer = self.get_analyzer(model_name)
        if analyzer is None:
            return "No analyzer available", {"error": "Model not found"}

        wants_flux = analysis_type == "flux" and hasattr(analyzer, 'analyze_for_flux_prompt')
        method = analyzer.analyze_for_flux_prompt if wants_flux else analyzer.analyze_image
        success, result = safe_execute(method, image)

        if success and result[1].get("error") is None:
            return result

        # The primary analyzer raised or reported an error: use the simple one.
        logger.warning(f"Primary model failed, using fallback: {result}")
        backup = self.get_analyzer("fallback")
        fallback_success, fallback_result = safe_execute(backup.analyze_image, image)
        if not fallback_success:
            return "All analyzers failed", {"error": "Complete analysis failure"}
        return fallback_result

    def cleanup_all(self) -> None:
        """Clean up every cached analyzer, empty the cache and free memory."""
        for cached in self.analyzers.values():
            cached.cleanup()
        self.analyzers.clear()
        clean_memory()
        logger.info("All analyzers cleaned up")
# Global model manager instance, created at import time and used by the
# module-level analyze_image() convenience function; it prefers the
# "bagel-api" analyzer.
model_manager = ModelManager(preferred_model="bagel-api")
def analyze_image(image: Image.Image, model_name: str = None, analysis_type: str = "detailed") -> Tuple[str, Dict[str, Any]]:
    """
    Analyze an image through the shared module-level model manager.

    Args:
        image: PIL Image to analyze
        model_name: Optional model name ("bagel-api" or "fallback")
        analysis_type: Type of analysis ("detailed" or "flux")

    Returns:
        Tuple of (description, metadata)
    """
    outcome = model_manager.analyze_image(
        image,
        model_name=model_name,
        analysis_type=analysis_type,
    )
    return outcome
# Public API of this module: the analyzer classes, the manager class, the
# shared manager instance and the convenience function. This is the set of
# names exported by `from <module> import *`.
__all__ = [
"BaseImageAnalyzer",
"BagelAPIAnalyzer",
"FallbackAnalyzer",
"ModelManager",
"model_manager",
"analyze_image"
]