# Hugging Face Space page status at capture time: "Running on Zero".
import gc
import logging
import os
import re
import warnings
from datetime import datetime

# `spaces` stays ahead of the torch-based imports, as in the original ordering.
import spaces
import gradio as gr
import numpy as np
import torch
from clip_interrogator import Config, Interrogator
from PIL import Image
# Suppress warnings | |
warnings.filterwarnings("ignore", category=FutureWarning) | |
warnings.filterwarnings("ignore", category=UserWarning) | |
os.environ["TOKENIZERS_PARALLELISM"] = "false" | |
logging.basicConfig(level=logging.INFO) | |
logger = logging.getLogger(__name__) | |
def get_device():
    """Pick the torch device string to run on.

    Returns:
        "cuda" when CUDA is available, otherwise "mps" when the MPS backend
        reports itself available, otherwise "cpu".
    """
    if torch.cuda.is_available():
        return "cuda"
    return "mps" if torch.backends.mps.is_available() else "cpu"


# Resolved once at import time and reused by the rest of the module.
DEVICE = get_device()
class FluxRulesEngine:
    """
    Flux prompt optimization based on Pariente AI research
    Implements structured prompt generation following validated rules
    """

    def __init__(self):
        # Fragments that must not appear in a Flux prompt.
        self.forbidden_elements = ["++", "weights", "white background [en dev]"]
        # Lower-case articles; skipped when guessing a subject and accepted
        # (capitalized) as a valid prompt opener when scoring.
        self.articles = ["a", "an", "the"]
        self.quality_adjectives = [
            "majestic", "pristine", "sleek", "elegant", "dramatic",
            "cinematic", "professional", "stunning", "refined"
        ]
        # Matched case-insensitively by get_optimization_score.
        self.lighting_types = [
            "golden hour", "studio lighting", "dramatic lighting",
            "ambient lighting", "natural light", "soft lighting",
            "rim lighting", "volumetric lighting"
        ]
        # Matched case-sensitively by get_optimization_score.
        self.technical_specs = [
            "Shot on Phase One", "f/2.8 aperture", "50mm lens",
            "85mm lens", "35mm lens", "professional photography",
            "medium format", "high resolution"
        ]
        self.materials = [
            "metallic", "glass", "chrome", "leather", "fabric",
            "wood", "concrete", "steel", "ceramic"
        ]

    @staticmethod
    def _tokenize(text):
        """Return the lower-cased word tokens of *text*, punctuation stripped."""
        return re.findall(r"\w+", text.lower())

    def extract_subject(self, base_prompt):
        """Extract main subject from CLIP analysis.

        Args:
            base_prompt: Caption text to analyse.

        Returns:
            The first token that is a known subject; otherwise the first
            token that is not an article; otherwise the literal "subject".
        """
        # Tokenizing (rather than splitting on whitespace) lets "car," in a
        # comma-separated caption still match "car".
        words = self._tokenize(base_prompt)
        subjects = {
            "car", "vehicle", "automobile", "person", "man", "woman",
            "building", "house", "landscape", "mountain", "tree",
            "flower", "animal", "dog", "cat", "bird"
        }
        for word in words:
            if word in subjects:
                return word
        # Fallback: first word that is not an article, so a caption such as
        # "a painting of ..." yields "painting" instead of "a".
        for word in words:
            if word not in self.articles:
                return word
        return "subject"

    def detect_setting(self, base_prompt):
        """Detect environmental context.

        Args:
            base_prompt: Caption text to analyse.

        Returns:
            One of "studio", "outdoor", "urban", "coastal", "indoor" (first
            match in that order), or "neutral environment" when no keyword
            occurs. Keywords are matched as substrings, so "bedroom" counts
            as "room".
        """
        prompt_lower = base_prompt.lower()
        settings = {
            "studio": ["studio", "backdrop", "seamless"],
            "outdoor": ["outdoor", "outside", "landscape", "nature"],
            "urban": ["city", "street", "urban", "building"],
            "coastal": ["beach", "ocean", "coast", "sea"],
            "indoor": ["room", "interior", "inside", "home"]
        }
        for setting, keywords in settings.items():
            if any(keyword in prompt_lower for keyword in keywords):
                return setting
        return "neutral environment"

    def optimize_for_flux(self, base_prompt, style_preference="professional"):
        """Apply Flux-specific optimization rules.

        Args:
            base_prompt: Caption text produced by the image analysis.
            style_preference: "cinematic" or "commercial" select a dedicated
                closing phrase; any other value uses "professional photography".

        Returns:
            A prompt made of one natural-language main clause (article,
            adjectives, subject, action, location) followed by comma-separated
            detail clauses (atmosphere, optional materials, lighting,
            technical specs, quality).
        """
        # Strip forbidden fragments before analysis so they cannot be picked
        # up as the subject; a space is substituted so neighbours do not fuse.
        cleaned_prompt = base_prompt
        for forbidden in self.forbidden_elements:
            cleaned_prompt = cleaned_prompt.replace(forbidden, " ")

        subject = self.extract_subject(cleaned_prompt)
        setting = self.detect_setting(cleaned_prompt)

        # Descriptive adjectives are fixed (not random) for reproducibility.
        adjectives = ["elegant", "professional"]
        # The article must agree with the word that follows it, which is the
        # first adjective rather than the subject.
        article = "An" if adjectives[0][0] in "aeiou" else "A"

        # Gerund describing what the subject is doing.
        if any(kind in subject for kind in ("person", "man", "woman")):
            action = "standing"
        else:
            action = "positioned"

        context_map = {
            "studio": "in a professional studio setting",
            "outdoor": "in a natural outdoor environment",
            "urban": "on an urban street",
            "coastal": "along a dramatic coastline",
            "indoor": "in an elegant interior space"
        }
        context = context_map.get(setting, "in a carefully composed scene")

        # The main clause reads as a sentence, so its parts are joined with
        # spaces; only the trailing detail clauses are comma-separated.
        main_clause = " ".join([article, *adjectives, subject, action, context])
        clauses = [main_clause, "with subtle atmospheric effects"]

        # Whole-word match so "scarf" or "carpet" does not count as "car".
        if re.search(r"\b(?:cars?|vehicles?|metal\w*)\b", cleaned_prompt.lower()):
            clauses.append("featuring metallic surfaces")

        clauses.append("illuminated by golden hour lighting")
        clauses.append("Shot on Phase One, f/2.8 aperture")

        quality_by_style = {
            "cinematic": "cinematic composition",
            "commercial": "commercial photography quality",
        }
        clauses.append(quality_by_style.get(style_preference, "professional photography"))

        return ", ".join(clauses)

    def get_optimization_score(self, prompt):
        """Calculate optimization score for Flux compatibility.

        Points: capitalized article as first word (15), a technical spec
        present (20), a lighting type present (15), no forbidden element (15),
        at least one comma (10), and length of 15-35 words (25) or 10-45
        words (15).

        Args:
            prompt: Prompt text to score.

        Returns:
            Integer score capped at 100.
        """
        score = 0

        # Structure check: the first *word* must be an article. A plain
        # startswith("A") would also reward e.g. "Apples ...".
        words = prompt.split()
        first_word = words[0].rstrip(",") if words else ""
        if first_word in {article.capitalize() for article in self.articles}:
            score += 15

        # Technical specs presence
        if any(spec in prompt for spec in self.technical_specs):
            score += 20

        # Lighting specification
        prompt_lower = prompt.lower()
        if any(light in prompt_lower for light in self.lighting_types):
            score += 15

        # No forbidden elements
        if not any(forbidden in prompt for forbidden in self.forbidden_elements):
            score += 15

        # Proper punctuation and structure
        if "," in prompt:
            score += 10

        # Length optimization
        word_count = len(words)
        if 15 <= word_count <= 35:
            score += 25
        elif 10 <= word_count <= 45:
            score += 15

        return min(score, 100)
class FluxPromptOptimizer:
    """Turn an image into a Flux prompt via CLIP Interrogator + FluxRulesEngine.

    The interrogator is created lazily on first use; `usage_count` counts the
    generation requests that got past input validation.
    """

    def __init__(self):
        self.interrogator = None  # created by initialize_model()
        self.flux_engine = FluxRulesEngine()
        self.usage_count = 0
        self.device = DEVICE
        self.is_initialized = False

    def _release_memory(self):
        """Run the garbage collector and, on CUDA, drop cached GPU blocks."""
        gc.collect()
        if self.device == "cuda":
            torch.cuda.empty_cache()

    def initialize_model(self):
        """Create the interrogator if it does not exist yet.

        Returns:
            True when the interrogator is ready, False when construction
            raised (the error is logged, not propagated).
        """
        if self.is_initialized:
            return True
        try:
            config = Config(
                clip_model_name="ViT-L-14/openai",
                download_cache=True,
                chunk_size=2048,
                quiet=True,
                device=self.device
            )
            self.interrogator = Interrogator(config)
            self.is_initialized = True
            self._release_memory()
            return True
        except Exception as e:
            logger.error(f"Initialization error: {e}")
            return False

    def optimize_image(self, image):
        """Normalise *image* to a size-capped RGB PIL image.

        Args:
            image: None, a numpy array, a PIL image, or anything accepted by
                `Image.open`.

        Returns:
            None for None input; otherwise an RGB image whose sides are at
            most 768 px (512 px on CPU). The caller's image is not modified.
        """
        if image is None:
            return None
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        elif not isinstance(image, Image.Image):
            image = Image.open(image)
        if image.mode != 'RGB':
            image = image.convert('RGB')
        # Optimize image size for processing
        max_size = 768 if self.device != "cpu" else 512
        if image.size[0] > max_size or image.size[1] > max_size:
            # thumbnail() resizes in place; work on a copy so the object the
            # caller handed in keeps its original size.
            image = image.copy()
            image.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)
        return image

    def generate_optimized_prompt(self, image, style_preference="professional", mode="best"):
        """Analyse *image* and build an optimized Flux prompt.

        Args:
            image: Input image (see optimize_image for accepted forms).
            style_preference: Forwarded to FluxRulesEngine.optimize_for_flux.
            mode: "fast" or "classic" select the matching interrogator
                method; any other value uses the default `interrogate`.

        Returns:
            Tuple `(prompt, info_markdown, score)`. On any failure the first
            element is an error message and the score is 0; exceptions are
            logged rather than raised.
        """
        try:
            # Validate the input first so a missing image never triggers the
            # expensive model load.
            if image is None:
                return "❌ Please upload an image.", "No image provided.", 0
            if not self.is_initialized and not self.initialize_model():
                return "❌ Model initialization failed.", "Please refresh and try again.", 0

            self.usage_count += 1
            image = self.optimize_image(image)
            if image is None:
                return "❌ Image processing failed.", "Invalid image format.", 0

            start_time = datetime.now()

            # Get base analysis from CLIP
            try:
                if mode == "fast":
                    base_prompt = self.interrogator.interrogate_fast(image)
                elif mode == "classic":
                    base_prompt = self.interrogator.interrogate_classic(image)
                else:
                    base_prompt = self.interrogator.interrogate(image)
            except Exception as e:
                # Best effort: retry with the fast path, but leave a trace of
                # why the requested mode failed.
                logger.warning(f"Interrogation in '{mode}' mode failed, retrying fast mode: {e}")
                base_prompt = self.interrogator.interrogate_fast(image)

            # Apply Flux-specific optimization
            optimized_prompt = self.flux_engine.optimize_for_flux(base_prompt, style_preference)
            # Calculate optimization score
            score = self.flux_engine.get_optimization_score(optimized_prompt)

            duration = (datetime.now() - start_time).total_seconds()

            # Memory cleanup
            self._release_memory()

            # Generate analysis info
            gpu_status = "⚡ ZeroGPU" if torch.cuda.is_available() else "💻 CPU"
            # Only mark the caption as truncated when it really was.
            base_preview = base_prompt[:100] + "..." if len(base_prompt) > 100 else base_prompt
            analysis_info = "\n".join([
                "**Analysis Complete**",
                f"**Processing:** {gpu_status} • {duration:.1f}s • {mode.title()} mode",
                f"**Style:** {style_preference.title()} photography",
                f"**Optimization Score:** {score}/100",
                f"**Generation:** #{self.usage_count}",
                f"**Base Analysis:** {base_preview}",
                "**Enhancement:** Applied Flux-specific structure and terminology",
            ])
            return optimized_prompt, analysis_info, score
        except Exception as e:
            logger.error(f"Generation error: {e}")
            return f"❌ Error: {str(e)}", "Please try with a different image.", 0
# Single shared optimizer instance used by the UI callbacks.
optimizer = FluxPromptOptimizer()


def process_image_wrapper(image, style_preference, mode):
    """Simple wrapper without progress callbacks.

    Runs the shared optimizer and renders the numeric score as an HTML card.

    Args:
        image: Image value from the UI.
        style_preference: Selected photography style.
        mode: Selected analysis mode.

    Returns:
        Tuple `(prompt, info_markdown, score_html)`; on an unexpected
        exception the three slots carry an error message instead.
    """
    try:
        prompt, info, score = optimizer.generate_optimized_prompt(image, style_preference, mode)

        # Create score HTML: green from 80, amber from 60, red below.
        if score >= 80:
            color = "#22c55e"
        elif score >= 60:
            color = "#f59e0b"
        else:
            color = "#ef4444"

        score_html = f'''
<div style="text-align: center; padding: 1rem; background: linear-gradient(135deg, #f0fdf4 0%, #dcfce7 100%); border: 2px solid {color}; border-radius: 12px; margin: 1rem 0;">
<div style="font-size: 2rem; font-weight: 700; color: {color}; margin: 0;">{score}</div>
<div style="font-size: 0.875rem; color: #15803d; margin: 0; text-transform: uppercase; letter-spacing: 0.05em;">Optimization Score</div>
</div>
'''
        return prompt, info, score_html
    except Exception as e:
        logger.error(f"Wrapper error: {e}")
        return "❌ Processing failed", f"Error: {str(e)}", '<div style="text-align: center; color: red;">Error</div>'
def clear_outputs():
    """Reset the three output widgets and free memory.

    Runs the garbage collector and, when CUDA is available, empties the CUDA
    cache.

    Returns:
        Tuple `(prompt, info, score_html)`: two empty strings and the
        placeholder score card showing "--".
    """
    gc.collect()
    cuda_ready = torch.cuda.is_available()
    if cuda_ready:
        torch.cuda.empty_cache()

    placeholder_html = (
        '<div style="text-align: center; padding: 1rem;">'
        '<div style="font-size: 2rem; color: #ccc;">--</div>'
        '<div style="font-size: 0.875rem; color: #999;">Optimization Score</div>'
        '</div>'
    )
    return "", "", placeholder_html
def create_interface():
    """Build the Gradio Blocks UI for the optimizer.

    Layout: an HTML header, a left column (image upload, style and mode
    dropdowns, generate button), a right column (prompt textbox, score card,
    analysis markdown, clear button) and a footer. The generate button calls
    `process_image_wrapper`; the clear button calls `clear_outputs`.

    Returns:
        The constructed `gr.Blocks` instance (not yet launched).
    """
    # Professional CSS with elegant typography
    css = """
    @import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&display=swap');
    .gradio-container {
        max-width: 1200px !important;
        margin: 0 auto !important;
        font-family: 'Inter', -apple-system, BlinkMacSystemFont, sans-serif !important;
        background: linear-gradient(135deg, #f8fafc 0%, #f1f5f9 100%) !important;
    }
    .main-header {
        text-align: center;
        padding: 2rem 0 3rem 0;
        background: linear-gradient(135deg, #1e293b 0%, #334155 100%);
        color: white;
        margin: -2rem -2rem 2rem -2rem;
        border-radius: 0 0 24px 24px;
    }
    .main-title {
        font-size: 2.5rem !important;
        font-weight: 700 !important;
        margin: 0 0 0.5rem 0 !important;
        letter-spacing: -0.025em !important;
        background: linear-gradient(135deg, #60a5fa 0%, #3b82f6 100%);
        -webkit-background-clip: text;
        -webkit-text-fill-color: transparent;
        background-clip: text;
    }
    .subtitle {
        font-size: 1.125rem !important;
        font-weight: 400 !important;
        opacity: 0.8 !important;
        margin: 0 !important;
    }
    .prompt-output {
        font-family: 'SF Mono', 'Monaco', 'Inconsolata', 'Roboto Mono', monospace !important;
        font-size: 14px !important;
        line-height: 1.6 !important;
        background: linear-gradient(135deg, #ffffff 0%, #f8fafc 100%) !important;
        border: 1px solid #e2e8f0 !important;
        border-radius: 12px !important;
        padding: 1.5rem !important;
        box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1) !important;
    }
    """

    # Placeholder score card shown before any prompt has been generated.
    empty_score_html = (
        '<div style="text-align: center; padding: 1rem;">'
        '<div style="font-size: 2rem; color: #ccc;">--</div>'
        '<div style="font-size: 0.875rem; color: #999;">Optimization Score</div>'
        '</div>'
    )

    # Built from joined lines so source indentation can never leak into the
    # Markdown (indented Markdown would render as a code block).
    footer_markdown = "\n".join([
        "---",
        "### 🔬 Research Foundation",
        "Flux Prompt Optimizer implements validated prompt engineering research for optimal Flux model performance.",
        "The optimization engine applies structured composition rules, technical terminology, and quality markers",
        "specifically calibrated for Flux architecture.",
        "**Developed by Pariente AI** • Advanced AI Research Laboratory",
    ])

    with gr.Blocks(
        theme=gr.themes.Soft(),
        title="Flux Prompt Optimizer",
        css=css
    ) as interface:
        gr.HTML("""
        <div class="main-header">
            <div class="main-title">⚡ Flux Prompt Optimizer</div>
            <div class="subtitle">Advanced prompt generation for Flux models • Research-based optimization</div>
        </div>
        """)

        with gr.Row():
            # Left column: inputs and settings.
            with gr.Column(scale=1):
                gr.Markdown("## 📷 Image Input")
                image_input = gr.Image(
                    label="Upload your image",
                    type="pil",
                    height=320
                )
                gr.Markdown("## ⚙️ Settings")
                style_selector = gr.Dropdown(
                    choices=["professional", "cinematic", "commercial", "artistic"],
                    value="professional",
                    label="Photography Style"
                )
                mode_selector = gr.Dropdown(
                    choices=["fast", "classic", "best"],
                    value="best",
                    label="Analysis Mode"
                )
                optimize_btn = gr.Button(
                    "🚀 Generate Optimized Prompt",
                    variant="primary",
                    size="lg"
                )

            # Right column: results.
            with gr.Column(scale=1):
                gr.Markdown("## 📝 Optimized Prompt")
                prompt_output = gr.Textbox(
                    label="Generated Prompt",
                    placeholder="Your optimized Flux prompt will appear here...",
                    lines=6,
                    max_lines=10,
                    elem_classes=["prompt-output"],
                    show_copy_button=True
                )
                score_output = gr.HTML(value=empty_score_html)
                info_output = gr.Markdown(value="")
                with gr.Row():
                    clear_btn = gr.Button("🗑️ Clear", size="sm")

        gr.Markdown(footer_markdown)

        # Event handlers
        optimize_btn.click(
            fn=process_image_wrapper,
            inputs=[image_input, style_selector, mode_selector],
            outputs=[prompt_output, info_output, score_output]
        )
        clear_btn.click(
            fn=clear_outputs,
            outputs=[prompt_output, info_output, score_output]
        )

    return interface
# Script entry point: build the UI and serve it on all interfaces, port 7860.
if __name__ == "__main__":
    logger.info("🚀 Starting Flux Prompt Optimizer")
    interface = create_interface()
    interface.launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True
    )