""" | |
Model management for FLUX Prompt Optimizer | |
Handles Florence-2 and Bagel model integration | |
""" | |
import logging
import requests
import spaces
import torch
from typing import Optional, Dict, Any, Tuple
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM

from config import MODEL_CONFIG, get_device_config
from utils import clean_memory, safe_execute

logger = logging.getLogger(__name__)
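
# For orientation: the shapes this module assumes for MODEL_CONFIG and
# get_device_config() from config.py, reconstructed from how they are used
# below. The concrete values shown are illustrative, not authoritative.
#
#   MODEL_CONFIG = {
#       "primary_model": "florence2",
#       "florence2": {
#           "model_id": "microsoft/Florence-2-large",  # example model id
#           "trust_remote_code": True,
#           "torch_dtype": torch.float16,
#           "max_new_tokens": 1024,
#       },
#       "bagel": {
#           "api_url": "...",   # Bagel endpoint, defined in config.py
#           "timeout": 30,
#       },
#   }
#
#   get_device_config() -> {"use_gpu": bool, "device": "cuda" or "cpu"}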


class BaseImageAnalyzer:
    """Base class for image analysis models"""

    def __init__(self):
        self.model = None
        self.processor = None
        self.device_config = get_device_config()
        self.is_initialized = False

    def initialize(self) -> bool:
        """Initialize the model"""
        raise NotImplementedError

    def analyze_image(self, image: Image.Image) -> Tuple[str, Dict[str, Any]]:
        """Analyze image and return description"""
        raise NotImplementedError

    def cleanup(self) -> None:
        """Clean up model resources"""
        if self.model is not None:
            del self.model
            self.model = None
        if self.processor is not None:
            del self.processor
            self.processor = None
        clean_memory()


class Florence2Analyzer(BaseImageAnalyzer):
    """Florence-2 model for image analysis"""

    def __init__(self):
        super().__init__()
        self.config = MODEL_CONFIG["florence2"]

    def initialize(self) -> bool:
        """Initialize Florence-2 model"""
        if self.is_initialized:
            return True
        try:
            logger.info("Initializing Florence-2 model...")
            model_id = self.config["model_id"]

            # Load processor
            self.processor = AutoProcessor.from_pretrained(
                model_id,
                trust_remote_code=self.config["trust_remote_code"]
            )

            # Load model
            self.model = AutoModelForCausalLM.from_pretrained(
                model_id,
                trust_remote_code=self.config["trust_remote_code"],
                torch_dtype=self.config["torch_dtype"] if self.device_config["use_gpu"] else torch.float32
            )

            # Move to the appropriate device
            if self.device_config["use_gpu"]:
                self.model = self.model.to(self.device_config["device"])
            else:
                self.model = self.model.to("cpu")

            self.model.eval()
            self.is_initialized = True
            logger.info(f"Florence-2 initialized on {self.device_config['device']}")
            return True
        except Exception as e:
            logger.error(f"Florence-2 initialization failed: {e}")
            self.cleanup()
            return False

    @spaces.GPU  # ZeroGPU allocation for the duration of this call (see `import spaces` above)
    def _gpu_inference(self, image: Image.Image, task_prompt: str) -> str:
        """Run inference on GPU with the `spaces` decorator"""
        try:
            # Move model to GPU for inference
            if self.device_config["use_gpu"]:
                self.model = self.model.to("cuda")

            # Prepare inputs
            inputs = self.processor(text=task_prompt, images=image, return_tensors="pt")

            # Move inputs to the same device as the model
            device = "cuda" if self.device_config["use_gpu"] else self.device_config["device"]
            inputs = {k: v.to(device) for k, v in inputs.items()}

            # Generate response
            with torch.no_grad():
                if self.device_config["use_gpu"]:
                    with torch.cuda.amp.autocast(dtype=torch.float16):
                        generated_ids = self.model.generate(
                            input_ids=inputs["input_ids"],
                            pixel_values=inputs["pixel_values"],
                            max_new_tokens=self.config["max_new_tokens"],
                            num_beams=3,
                            do_sample=False
                        )
                else:
                    generated_ids = self.model.generate(
                        input_ids=inputs["input_ids"],
                        pixel_values=inputs["pixel_values"],
                        max_new_tokens=self.config["max_new_tokens"],
                        num_beams=3,
                        do_sample=False
                    )

            # Decode response
            generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
            parsed = self.processor.post_process_generation(
                generated_text,
                task=task_prompt,
                image_size=(image.width, image.height)
            )

            # Extract caption
            if task_prompt in parsed:
                return parsed[task_prompt]
            else:
                return str(parsed) if parsed else ""
        except Exception as e:
            logger.error(f"Florence-2 GPU inference failed: {e}")
            return ""
        finally:
            # Move model back to CPU to free GPU memory
            if self.device_config["use_gpu"]:
                self.model = self.model.to("cpu")
            clean_memory()

    def analyze_image(self, image: Image.Image) -> Tuple[str, Dict[str, Any]]:
        """Analyze image using Florence-2"""
        if not self.is_initialized:
            success = self.initialize()
            if not success:
                return "Model initialization failed", {"error": "Florence-2 not available"}
        try:
            # Define analysis tasks
            tasks = {
                "detailed": "<DETAILED_CAPTION>",
                "more_detailed": "<MORE_DETAILED_CAPTION>",
                "caption": "<CAPTION>"
            }
            results = {}

            # Run analysis for each task
            for task_name, task_prompt in tasks.items():
                if self.device_config["use_gpu"]:
                    result = self._gpu_inference(image, task_prompt)
                else:
                    result = self._cpu_inference(image, task_prompt)
                results[task_name] = result

            # Choose the best available result
            if results["more_detailed"]:
                main_description = results["more_detailed"]
            elif results["detailed"]:
                main_description = results["detailed"]
            else:
                main_description = results["caption"] or "A photograph"

            # Prepare metadata
            metadata = {
                "model": "Florence-2",
                "device": self.device_config["device"],
                "all_results": results,
                "confidence": 0.85  # Florence-2 is generally reliable
            }
            logger.info(f"Florence-2 analysis complete: {len(main_description)} chars")
            return main_description, metadata
        except Exception as e:
            logger.error(f"Florence-2 analysis failed: {e}")
            return "Analysis failed", {"error": str(e)}

    def _cpu_inference(self, image: Image.Image, task_prompt: str) -> str:
        """Run inference on CPU"""
        try:
            inputs = self.processor(text=task_prompt, images=image, return_tensors="pt")
            with torch.no_grad():
                generated_ids = self.model.generate(
                    input_ids=inputs["input_ids"],
                    pixel_values=inputs["pixel_values"],
                    max_new_tokens=self.config["max_new_tokens"],
                    num_beams=2,  # Reduced for CPU
                    do_sample=False
                )
            generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
            parsed = self.processor.post_process_generation(
                generated_text,
                task=task_prompt,
                image_size=(image.width, image.height)
            )
            if task_prompt in parsed:
                return parsed[task_prompt]
            else:
                return str(parsed) if parsed else ""
        except Exception as e:
            logger.error(f"Florence-2 CPU inference failed: {e}")
            return ""


class BagelAnalyzer(BaseImageAnalyzer):
    """Bagel-7B model analyzer via API"""

    def __init__(self):
        super().__init__()
        self.config = MODEL_CONFIG["bagel"]
        self.session = requests.Session()

    def initialize(self) -> bool:
        """Initialize Bagel analyzer (API-based)"""
        try:
            # Test API connectivity
            test_response = self.session.get(
                self.config["api_url"],
                timeout=self.config["timeout"]
            )
            if test_response.status_code == 200:
                self.is_initialized = True
                logger.info("Bagel API connection established")
                return True
            else:
                logger.error(f"Bagel API not accessible: {test_response.status_code}")
                return False
        except Exception as e:
            logger.error(f"Bagel initialization failed: {e}")
            return False

    def analyze_image(self, image: Image.Image) -> Tuple[str, Dict[str, Any]]:
        """Analyze image using the Bagel-7B API"""
        if not self.is_initialized:
            success = self.initialize()
            if not success:
                return "Bagel API not available", {"error": "API connection failed"}
        try:
            # Note: this is a placeholder - the actual implementation depends on
            # the Bagel API's request format. A real implementation would:
            #   1. Convert the image to the required format (e.g. base64)
            #   2. Make the API call to the Bagel endpoint
            #   3. Parse the response
            # A commented sketch of that flow follows below.
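            #
            # Minimal sketch under assumed conventions (the JSON payload keys
            # and response schema are hypothetical until the real Bagel API
            # format is confirmed):
            #
            #   import base64, io
            #   buffer = io.BytesIO()
            #   image.save(buffer, format="PNG")
            #   encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")
            #   response = self.session.post(
            #       self.config["api_url"],
            #       json={"image": encoded, "prompt": "Describe this image in detail."},
            #       timeout=self.config["timeout"],
            #   )
            #   response.raise_for_status()
            #   description = response.json().get("description", "")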

            # For now, return a placeholder response
            description = "Detailed image analysis via Bagel-7B (API implementation needed)"
            metadata = {
                "model": "Bagel-7B",
                "method": "API",
                "confidence": 0.8
            }
            logger.info("Bagel analysis complete (placeholder)")
            return description, metadata
        except Exception as e:
            logger.error(f"Bagel analysis failed: {e}")
            return "Analysis failed", {"error": str(e)}


class ModelManager:
    """Manager for handling multiple analysis models"""

    def __init__(self, preferred_model: str = None):
        self.preferred_model = preferred_model or MODEL_CONFIG["primary_model"]
        self.analyzers = {}
        self.current_analyzer = None

    def get_analyzer(self, model_name: str = None) -> Optional[BaseImageAnalyzer]:
        """Get or create an analyzer for the specified model"""
        model_name = model_name or self.preferred_model
        if model_name not in self.analyzers:
            if model_name == "florence2":
                self.analyzers[model_name] = Florence2Analyzer()
            elif model_name == "bagel":
                self.analyzers[model_name] = BagelAnalyzer()
            else:
                logger.error(f"Unknown model: {model_name}")
                return None
        return self.analyzers[model_name]

    def analyze_image(self, image: Image.Image, model_name: str = None) -> Tuple[str, Dict[str, Any]]:
        """Analyze image with the specified or preferred model"""
        analyzer = self.get_analyzer(model_name)
        if analyzer is None:
            return "No analyzer available", {"error": "Model not found"}
        success, result = safe_execute(analyzer.analyze_image, image)
        if success:
            return result
        else:
            return "Analysis failed", {"error": result}

    def cleanup_all(self) -> None:
        """Clean up all model resources"""
        for analyzer in self.analyzers.values():
            analyzer.cleanup()
        self.analyzers.clear()
        clean_memory()


# Global model manager instance
model_manager = ModelManager()


def analyze_image(image: Image.Image, model_name: str = None) -> Tuple[str, Dict[str, Any]]:
    """
    Convenience function for image analysis

    Args:
        image: PIL Image to analyze
        model_name: Optional model name ("florence2" or "bagel")

    Returns:
        Tuple of (description, metadata)
    """
    return model_manager.analyze_image(image, model_name)


# Export main components
__all__ = [
    "BaseImageAnalyzer",
    "Florence2Analyzer",
    "BagelAnalyzer",
    "ModelManager",
    "model_manager",
    "analyze_image"
]
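

# Example usage: a minimal sketch of the convenience API. The filename
# "example.jpg" is illustrative only; any RGB-convertible image works.
if __name__ == "__main__":
    sample = Image.open("example.jpg").convert("RGB")
    description, metadata = analyze_image(sample, model_name="florence2")
    print(f"Description: {description}")
    print(f"Metadata: {metadata}")
    model_manager.cleanup_all()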