import modal
import random
from datetime import datetime

from config.config import models, prompts
# Define the Modal image (same as in modal_app.py)
image = modal.Image.debian_slim(python_version="3.11").pip_install(
    "diffusers",
    "transformers",
    "torch>=2.0.1",
    "accelerate",
    "gradio",
    "safetensors",
    "pillow",
)
# Create a Modal app
app = modal.App("ctb-ai-img-gen-modal", image=image)

# Define a volume for caching models
volume = modal.Volume.from_name("flux-model-vol")
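
# The volume is attached to the class below (via @app.cls) so downloaded model
# weights persist under /cache across container restarts. The GPU type in the
# decorator is an assumption; adjust it to your hardware needs.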
@app.cls(gpu="A10G", volumes={"/cache": volume})
class Model:
    def __init__(self):
        self.device = "cuda"
        self.model_dir = "/cache/models"
    @modal.enter()
    def setup(self):
        import torch
        from diffusers import StableDiffusionPipeline

        self.torch_dtype = torch.bfloat16
        # Load the model, caching the downloaded weights on the mounted volume
        self.pipe = StableDiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-2-1",
            torch_dtype=self.torch_dtype,
            cache_dir=self.model_dir,
            safety_checker=None,
            feature_extractor=None,
        ).to(self.device)
        # Optimize the model
        self.pipe = self.optimize(self.pipe)
    def optimize(self, pipe):
        import torch

        # Fuse QKV projections
        pipe.unet.fuse_qkv_projections()
        pipe.vae.fuse_qkv_projections()
        # Switch to channels-last memory layout
        pipe.unet.to(memory_format=torch.channels_last)
        pipe.vae.to(memory_format=torch.channels_last)
        # Compile the UNet and the VAE decoder
        pipe.unet = torch.compile(pipe.unet, mode="max-autotune", fullgraph=True)
        pipe.vae.decode = torch.compile(pipe.vae.decode, mode="max-autotune", fullgraph=True)
        return pipe
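
    # Note: torch.compile with mode="max-autotune" pays a one-time compilation
    # cost on the first call; later calls in the same container reuse the
    # compiled graphs.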
    @modal.method()
    def generate(self, prompt_alias, team_color, model_alias, custom_prompt):
        import torch

        # Find the selected prompt and model (model_name only validates the alias)
        try:
            prompt = next(p for p in prompts if p["alias"] == prompt_alias)["text"]
            model_name = next(m for m in models if m["alias"] == model_alias)["name"]
        except StopIteration:
            return None, "ERROR: Invalid prompt or model selected."

        # Format the prompt
        enemy_color = "blue" if team_color.lower() == "red" else "red"
        prompt = prompt.format(team_color=team_color.lower(), enemy_color=enemy_color)
        if custom_prompt.strip():
            prompt += " " + custom_prompt.strip()

        # Pick a random seed; the generator below makes the run reproducible
        seed = random.randint(0, 2**32 - 1)
        # Generate the image
        try:
            image = self.pipe(
                prompt,
                guidance_scale=2.0,
                num_inference_steps=20,
                width=640,
                height=360,
                generator=torch.Generator(self.device).manual_seed(seed),
            ).images[0]
            # Save the image under a timestamped, descriptive filename
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            output_filename = f"{timestamp}_{model_alias.replace(' ', '_').lower()}_{prompt_alias.replace(' ', '_').lower()}_{team_color.lower()}.png"
            image.save(output_filename)
            return output_filename, "Image generated successfully!"
        except Exception as e:
            return None, f"ERROR: Failed to generate image. Details: {e}"

# Function to be called from the Gradio interface
def generate(prompt_alias, team_color, model_alias, custom_prompt):
    model = Model()
    return model.generate.remote(prompt_alias, team_color, model_alias, custom_prompt)
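
# A minimal usage sketch, assuming the app has been deployed with `modal deploy`
# (or is running inside `app.run()` during local testing). The alias values here
# are hypothetical; the real ones come from config/config.py:
#
#   filename, message = generate("battle_scene", "red", "sd_2_1", "")
#   print(message, filename)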