import spaces
import gradio as gr
import torch
from PIL import Image
import numpy as np
from clip_interrogator import Config, Interrogator
import logging
import os
import warnings
from datetime import datetime
import gc
import re

# Suppress warnings
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)
os.environ["TOKENIZERS_PARALLELISM"] = "false"

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def get_device():
    if torch.cuda.is_available():
        return "cuda"
    elif torch.backends.mps.is_available():
        return "mps"
    else:
        return "cpu"

DEVICE = get_device()
class FluxRulesEngine:
    """
    Flux prompt optimization based on Pariente AI research.
    Implements structured prompt generation following validated rules.
    """

    def __init__(self):
        self.forbidden_elements = ["++", "weights", "white background"]  # "white background" rule still under development
        self.structure_order = {
            1: "article",
            2: "descriptive_adjectives",
            3: "main_subject",
            4: "verb_action",
            5: "context_location",
            6: "environmental_details",
            7: "materials_textures",
            8: "lighting_effects",
            9: "technical_specs",
            10: "quality_style"
        }
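        # Illustrative mapping of the ten slots above to a sample prompt (the example
        # wording here is ours, not taken from the research rules themselves):
        #   "A" (1), "sleek, dramatic" (2), "car" (3), "positioned" (4),
        #   "on an urban street" (5), "with subtle atmospheric effects" (6),
        #   "with metallic surfaces" (7), "illuminated by golden hour" (8),
        #   "Shot on Phase One" (9), "professional photography" (10)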
        self.articles = ["a", "an", "the"]

        self.quality_adjectives = [
            "majestic", "pristine", "sleek", "elegant", "dramatic",
            "cinematic", "professional", "stunning", "refined"
        ]

        self.lighting_types = [
            "golden hour", "studio lighting", "dramatic lighting",
            "ambient lighting", "natural light", "soft lighting",
            "rim lighting", "volumetric lighting"
        ]

        self.technical_specs = [
            "Shot on Phase One", "f/2.8 aperture", "50mm lens",
            "85mm lens", "35mm lens", "professional photography",
            "medium format", "high resolution"
        ]

        self.materials = [
            "metallic", "glass", "chrome", "leather", "fabric",
            "wood", "concrete", "steel", "ceramic"
        ]
    def extract_subject(self, base_prompt):
        """Extract the main subject from the CLIP analysis."""
        words = base_prompt.lower().split()

        # Common subjects to identify
        subjects = [
            "car", "vehicle", "automobile", "person", "man", "woman",
            "building", "house", "landscape", "mountain", "tree",
            "flower", "animal", "dog", "cat", "bird"
        ]

        for word in words:
            if word in subjects:
                return word

        # Fallback: use the first word if no known subject is found
        return words[0] if words else "subject"
    def detect_setting(self, base_prompt):
        """Detect the environmental context of the prompt."""
        prompt_lower = base_prompt.lower()

        settings = {
            "studio": ["studio", "backdrop", "seamless"],
            "outdoor": ["outdoor", "outside", "landscape", "nature"],
            "urban": ["city", "street", "urban", "building"],
            "coastal": ["beach", "ocean", "coast", "sea"],
            "indoor": ["room", "interior", "inside", "home"]
        }

        for setting, keywords in settings.items():
            if any(keyword in prompt_lower for keyword in keywords):
                return setting

        return "neutral environment"
    def optimize_for_flux(self, base_prompt, style_preference="professional"):
        """Apply Flux-specific optimization rules."""
        # Clean forbidden elements before extracting anything from the prompt
        cleaned_prompt = base_prompt
        for forbidden in self.forbidden_elements:
            cleaned_prompt = cleaned_prompt.replace(forbidden, "")

        # Extract key elements from the cleaned prompt
        subject = self.extract_subject(cleaned_prompt)
        setting = self.detect_setting(cleaned_prompt)
        # Build structured prompt
        components = []

        # 1. Article
        article = "A" if subject[0] not in "aeiou" else "An"
        components.append(article)

        # 2. Descriptive adjectives (max 2-3)
        adjectives = np.random.choice(self.quality_adjectives, size=2, replace=False)
        components.extend(adjectives)

        # 3. Main subject
        components.append(subject)

        # 4. Verb/Action
        if "person" in subject or "man" in subject or "woman" in subject:
            action = "standing"
        else:
            action = "positioned"
        components.append(action)

        # 5. Context/Location
        context_map = {
            "studio": "in a professional studio setting",
            "outdoor": "in a natural outdoor environment",
            "urban": "on an urban street",
            "coastal": "along a dramatic coastline",
            "indoor": "in an elegant interior space"
        }
        components.append(context_map.get(setting, "in a carefully composed scene"))

        # 6. Environmental details
        env_details = ["with subtle atmospheric effects", "surrounded by carefully balanced elements"]
        components.append(np.random.choice(env_details))

        # 7. Materials/Textures (if applicable)
        if any(mat in base_prompt.lower() for mat in ["car", "vehicle", "metal"]):
            material = np.random.choice(["with metallic surfaces", "featuring chrome details"])
            components.append(material)

        # 8. Lighting effects
        lighting = np.random.choice(self.lighting_types)
        components.append(f"illuminated by {lighting}")

        # 9. Technical specs
        tech_spec = np.random.choice(self.technical_specs)
        components.append(tech_spec)

        # 10. Quality/Style
        if style_preference == "cinematic":
            quality = "cinematic composition"
        elif style_preference == "commercial":
            quality = "commercial photography quality"
        else:
            quality = "professional photography"
        components.append(quality)

        # Join components with proper punctuation
        prompt = ", ".join(components)

        # Capitalize first letter
        prompt = prompt[0].upper() + prompt[1:]

        return prompt
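    # Illustrative example (ours, not part of the original code): for a CLIP result
    # such as "a red sports car parked on a city street", optimize_for_flux typically
    # produces something like:
    #   "A, sleek, dramatic, car, positioned, on an urban street, with subtle
    #    atmospheric effects, featuring chrome details, illuminated by golden hour,
    #    Shot on Phase One, professional photography"
    # Exact wording varies because adjectives, lighting and specs are sampled randomly.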
    def get_optimization_score(self, prompt):
        """Calculate an optimization score for Flux compatibility."""
        score = 0
        max_score = 100

        # Structure check (order compliance)
        if prompt.startswith(("A", "An", "The")):
            score += 15

        # Adjective count (optimal 2-3)
        adj_count = len([adj for adj in self.quality_adjectives if adj in prompt.lower()])
        if 2 <= adj_count <= 3:
            score += 15
        elif adj_count == 1:
            score += 10

        # Technical specs presence
        if any(spec in prompt for spec in self.technical_specs):
            score += 20

        # Lighting specification
        if any(light in prompt.lower() for light in self.lighting_types):
            score += 15

        # No forbidden elements
        if not any(forbidden in prompt for forbidden in self.forbidden_elements):
            score += 15

        # Proper punctuation and structure
        if "," in prompt and prompt.endswith(("photography", "composition", "quality")):
            score += 10

        # Length optimization (Flux works best with detailed but not excessive prompts)
        word_count = len(prompt.split())
        if 15 <= word_count <= 35:
            score += 10
        elif 10 <= word_count <= 45:
            score += 5

        return min(score, max_score)
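# Quick sanity check of the rules engine on its own (illustrative, not part of the
# original app flow; output and score vary from run to run due to random sampling):
#   engine = FluxRulesEngine()
#   p = engine.optimize_for_flux("a photo of a car on a city street", "cinematic")
#   print(p, engine.get_optimization_score(p))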
class FluxPromptOptimizer:
    def __init__(self):
        self.interrogator = None
        self.flux_engine = FluxRulesEngine()
        self.usage_count = 0
        self.device = DEVICE
        self.is_initialized = False

    def initialize_model(self, progress_callback=None):
        if self.is_initialized:
            return True

        try:
            if progress_callback:
                progress_callback("Initializing CLIP model...")

            config = Config(
                clip_model_name="ViT-L-14/openai",
                download_cache=True,
                chunk_size=2048,
                quiet=True,
                device=self.device
            )

            self.interrogator = Interrogator(config)
            self.is_initialized = True
            # Free memory; only touch the CUDA cache when CUDA is actually in use
            if self.device == "cuda":
                torch.cuda.empty_cache()
            else:
                gc.collect()
            return True

        except Exception as e:
            logger.error(f"Initialization error: {e}")
            return False
    def optimize_image(self, image):
        if image is None:
            return None

        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        elif not isinstance(image, Image.Image):
            image = Image.open(image)

        if image.mode != 'RGB':
            image = image.convert('RGB')

        # Optimize image size for processing
        max_size = 768 if self.device != "cpu" else 512
        if image.size[0] > max_size or image.size[1] > max_size:
            image.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)

        return image
    def generate_optimized_prompt(self, image, style_preference="professional", mode="best", progress_callback=None):
        try:
            if not self.is_initialized:
                if not self.initialize_model(progress_callback):
                    return "❌ Model initialization failed.", "", 0

            if image is None:
                return "❌ Please upload an image.", "", 0

            self.usage_count += 1

            if progress_callback:
                progress_callback("Analyzing image content...")

            image = self.optimize_image(image)
            if image is None:
                return "❌ Image processing failed.", "", 0

            if progress_callback:
                progress_callback("Extracting visual features...")

            start_time = datetime.now()

            # Get base analysis from CLIP
            try:
                if mode == "fast":
                    base_prompt = self.interrogator.interrogate_fast(image)
                elif mode == "classic":
                    base_prompt = self.interrogator.interrogate_classic(image)
                else:
                    base_prompt = self.interrogator.interrogate(image)
            except Exception as e:
                base_prompt = self.interrogator.interrogate_fast(image)

            if progress_callback:
                progress_callback("Applying Flux optimization rules...")

            # Apply Flux-specific optimization
            optimized_prompt = self.flux_engine.optimize_for_flux(base_prompt, style_preference)

            # Calculate optimization score
            score = self.flux_engine.get_optimization_score(optimized_prompt)

            end_time = datetime.now()
            duration = (end_time - start_time).total_seconds()
            # Memory cleanup (only touch the CUDA cache when CUDA is actually in use)
            if self.device == "cuda":
                torch.cuda.empty_cache()
            else:
                gc.collect()
            # Generate analysis info
            gpu_status = "⚡ ZeroGPU" if torch.cuda.is_available() else "💻 CPU"

            analysis_info = f"""
**Analysis Complete**
**Processing:** {gpu_status} • {duration:.1f}s • {mode.title()} mode
**Style:** {style_preference.title()} photography
**Optimization Score:** {score}/100
**Generation:** #{self.usage_count}
**Base Analysis:** {base_prompt[:100]}...
**Enhancement:** Applied Flux-specific structure and terminology
"""

            return optimized_prompt, analysis_info, score

        except Exception as e:
            return f"❌ Error: {str(e)}", "Please try with a different image or contact support.", 0
optimizer = FluxPromptOptimizer()
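# Programmatic use without the Gradio UI (illustrative sketch; "example.jpg" is a
# hypothetical path, not a file shipped with this app):
#   prompt, info, score = optimizer.generate_optimized_prompt(
#       Image.open("example.jpg"), style_preference="cinematic", mode="fast"
#   )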
def process_image_with_progress(image, style_preference, mode):
    def progress_callback(message):
        return message

    # First yield: placeholder status shown while the model runs
    yield "🔄 Initializing Flux Optimizer...", """
**Flux Prompt Optimizer**
Analyzing image with advanced computer vision
Applying research-based optimization rules
Generating Flux-compatible prompt structure
""", 0

    prompt, info, score = optimizer.generate_optimized_prompt(image, style_preference, mode, progress_callback)
    yield prompt, info, score
def clear_outputs():
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    # Reset the score panel to its placeholder HTML rather than a bare integer
    return "", "", '<div class="score-display"><div class="score-number">--</div><div class="score-label">Optimization Score</div></div>'
def create_interface():
    # Professional CSS with elegant typography
    css = """
    @import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&display=swap');

    .gradio-container {
        max-width: 1200px !important;
        margin: 0 auto !important;
        font-family: 'Inter', -apple-system, BlinkMacSystemFont, sans-serif !important;
        background: linear-gradient(135deg, #f8fafc 0%, #f1f5f9 100%) !important;
    }

    .main-header {
        text-align: center;
        padding: 2rem 0 3rem 0;
        background: linear-gradient(135deg, #1e293b 0%, #334155 100%);
        color: white;
        margin: -2rem -2rem 2rem -2rem;
        border-radius: 0 0 24px 24px;
    }

    .main-title {
        font-size: 2.5rem !important;
        font-weight: 700 !important;
        margin: 0 0 0.5rem 0 !important;
        letter-spacing: -0.025em !important;
        background: linear-gradient(135deg, #60a5fa 0%, #3b82f6 100%);
        -webkit-background-clip: text;
        -webkit-text-fill-color: transparent;
        background-clip: text;
    }

    .subtitle {
        font-size: 1.125rem !important;
        font-weight: 400 !important;
        opacity: 0.8 !important;
        margin: 0 !important;
    }

    .section-header {
        font-size: 1.25rem !important;
        font-weight: 600 !important;
        color: #1e293b !important;
        margin: 0 0 1rem 0 !important;
        padding-bottom: 0.5rem !important;
        border-bottom: 2px solid #e2e8f0 !important;
    }

    .prompt-output {
        font-family: 'SF Mono', 'Monaco', 'Inconsolata', 'Roboto Mono', monospace !important;
        font-size: 14px !important;
        line-height: 1.6 !important;
        background: linear-gradient(135deg, #ffffff 0%, #f8fafc 100%) !important;
        border: 1px solid #e2e8f0 !important;
        border-radius: 12px !important;
        padding: 1.5rem !important;
        box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1) !important;
    }

    .info-panel {
        background: linear-gradient(135deg, #f0f9ff 0%, #e0f2fe 100%) !important;
        border: 1px solid #0ea5e9 !important;
        border-radius: 12px !important;
        padding: 1.25rem !important;
        font-size: 0.875rem !important;
        line-height: 1.5 !important;
    }

    .score-display {
        text-align: center !important;
        padding: 1rem !important;
        background: linear-gradient(135deg, #f0fdf4 0%, #dcfce7 100%) !important;
        border: 2px solid #22c55e !important;
        border-radius: 12px !important;
        margin: 1rem 0 !important;
    }

    .score-number {
        font-size: 2rem !important;
        font-weight: 700 !important;
        color: #16a34a !important;
        margin: 0 !important;
    }

    .score-label {
        font-size: 0.875rem !important;
        color: #15803d !important;
        margin: 0 !important;
        text-transform: uppercase !important;
        letter-spacing: 0.05em !important;
    }
    """
    with gr.Blocks(
        theme=gr.themes.Soft(),
        title="Flux Prompt Optimizer",
        css=css
    ) as interface:

        gr.HTML("""
        <div class="main-header">
            <div class="main-title">⚡ Flux Prompt Optimizer</div>
            <div class="subtitle">Advanced prompt generation for Flux models • Research-based optimization</div>
        </div>
        """)
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("## 📷 Image Input", elem_classes=["section-header"])

                image_input = gr.Image(
                    label="Upload your image",
                    type="pil",
                    height=320,
                    show_label=False
                )

                gr.Markdown("## ⚙️ Optimization Settings", elem_classes=["section-header"])

                style_selector = gr.Dropdown(
                    choices=["professional", "cinematic", "commercial", "artistic"],
                    value="professional",
                    label="Photography Style",
                    info="Select the target style for prompt optimization"
                )

                mode_selector = gr.Dropdown(
                    choices=["fast", "classic", "best"],
                    value="best",
                    label="Analysis Mode",
                    info="Balance between speed and detail"
                )

                optimize_btn = gr.Button(
                    "🚀 Generate Optimized Prompt",
                    variant="primary",
                    size="lg"
                )

                gr.Markdown("""
                ### About Flux Optimization

                This tool applies research-validated rules for Flux prompt generation:

                • **Structured composition** following optimal element order
                • **Technical specifications** for professional results
                • **Lighting and material** terminology optimization
                • **Quality markers** specific to Flux model architecture
                """)
            with gr.Column(scale=1):
                gr.Markdown("## 📝 Optimized Prompt", elem_classes=["section-header"])

                prompt_output = gr.Textbox(
                    label="Generated Prompt",
                    placeholder="Your optimized Flux prompt will appear here...",
                    lines=6,
                    max_lines=10,
                    elem_classes=["prompt-output"],
                    show_copy_button=True,
                    show_label=False
                )

                # Score display
                score_output = gr.HTML(
                    value='<div class="score-display"><div class="score-number">--</div><div class="score-label">Optimization Score</div></div>'
                )

                info_output = gr.Markdown(
                    value="",
                    elem_classes=["info-panel"]
                )

                with gr.Row():
                    clear_btn = gr.Button("🗑️ Clear", size="sm")
                    copy_btn = gr.Button("📋 Copy Prompt", size="sm")
        gr.Markdown("""
        ---
        ### 🔬 Research Foundation

        Flux Prompt Optimizer implements validated prompt engineering research for optimal Flux model performance.
        The optimization engine applies structured composition rules, technical terminology, and quality markers
        specifically calibrated for Flux architecture.

        **Developed by Pariente AI** • Advanced AI Research Laboratory
        """)
        # Event handlers
        def update_score_display(score):
            color = "#22c55e" if score >= 80 else "#f59e0b" if score >= 60 else "#ef4444"
            return f'''
            <div class="score-display" style="border-color: {color};">
                <div class="score-number" style="color: {color};">{score}</div>
                <div class="score-label">Optimization Score</div>
            </div>
            '''

        def copy_prompt_to_clipboard(prompt):
            return prompt
        # Stream progress updates from the generator and render the score as HTML,
        # so the pipeline runs exactly once per click
        def run_optimization(img, style, mode):
            for prompt, info, score in process_image_with_progress(img, style, mode):
                yield prompt, info, update_score_display(score)

        optimize_btn.click(
            fn=run_optimization,
            inputs=[image_input, style_selector, mode_selector],
            outputs=[prompt_output, info_output, score_output]
        )
        clear_btn.click(
            fn=clear_outputs,
            outputs=[prompt_output, info_output, score_output]
        )

        # The textbox already exposes a built-in copy button; route the value back to
        # the textbox so this handler has a matching output component.
        copy_btn.click(
            fn=copy_prompt_to_clipboard,
            inputs=[prompt_output],
            outputs=[prompt_output]
        )
    return interface
if __name__ == "__main__":
    logger.info("🚀 Starting Flux Prompt Optimizer")
    interface = create_interface()
    interface.launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True
    )