File size: 3,723 Bytes
4f48282
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import modal
import random
from datetime import datetime
import os
from config.config import models, prompts

# Python packages baked into the container image (same set as in modal_app.py)
_PIP_PACKAGES = (
    "diffusers",
    "transformers",
    "torch>=2.0.1",
    "accelerate",
    "gradio",
    "safetensors",
    "pillow",
)

# Container image: Debian slim with Python 3.11 plus the packages above
image = modal.Image.debian_slim(python_version="3.11").pip_install(*_PIP_PACKAGES)

# Modal app whose default image is the one built above
app = modal.App("ctb-ai-img-gen-modal", image=image)

# Named volume; the Model class below mounts it at /cache
volume = modal.Volume.from_name("flux-model-vol")

@app.cls(
    gpu="H100",  # Use H100 GPU for maximum performance
    container_idle_timeout=20 * 60,  # 20 minutes
    timeout=60 * 60,  # 1 hour
    volumes={"/cache": volume},
)
class Model:
    """Stable Diffusion 2.1 text-to-image worker hosted on Modal.

    The pipeline is loaded and compiled once per container in ``setup``;
    ``generate`` then renders one image per call from a configured prompt
    template.
    """

    def __init__(self):
        self.device = "cuda"
        # torch is not imported at module level (every method imports it
        # locally), so the dtype cannot be resolved here; setup() fills it in.
        self.torch_dtype = None
        self.model_dir = "/cache/models"

    @modal.enter()
    def setup(self):
        """Load the Stable Diffusion pipeline onto the GPU and optimize it."""
        import torch
        from diffusers import StableDiffusionPipeline

        self.torch_dtype = torch.bfloat16

        # Load the model. cache_dir points the download cache at the mounted
        # volume path, which was previously defined but never used.
        self.pipe = StableDiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-2-1",
            torch_dtype=self.torch_dtype,
            cache_dir=self.model_dir,
            safety_checker=None,
            feature_extractor=None,
        ).to(self.device)

        # Optimize the model
        self.pipe = self.optimize(self.pipe)

    def optimize(self, pipe):
        """Apply inference optimizations to ``pipe`` and return it.

        Fuses the QKV projections of the UNet and VAE, switches both to the
        channels-last memory format, and wraps the UNet and the VAE decoder
        with ``torch.compile``. The pipeline is modified in place.
        """
        import torch

        # Fuse QKV projections
        pipe.unet.fuse_qkv_projections()
        pipe.vae.fuse_qkv_projections()

        # Switch memory layout
        pipe.unet.to(memory_format=torch.channels_last)
        pipe.vae.to(memory_format=torch.channels_last)

        # Compile the model
        pipe.unet = torch.compile(pipe.unet, mode="max-autotune", fullgraph=True)
        pipe.vae.decode = torch.compile(pipe.vae.decode, mode="max-autotune", fullgraph=True)

        return pipe

    @modal.method()
    def generate(self, prompt_alias, team_color, model_alias, custom_prompt):
        """Render one image and save it as a PNG.

        Args:
            prompt_alias: ``alias`` of an entry in the configured ``prompts``.
            team_color: Team colour; lower-cased and substituted into the
                prompt template. "red" yields enemy colour "blue", anything
                else yields "red".
            model_alias: ``alias`` of an entry in the configured ``models``.
            custom_prompt: Optional extra text appended to the prompt; may be
                ``None`` or blank.

        Returns:
            ``(output_filename, message)`` on success, or ``(None, message)``
            with an ``ERROR: ...`` message when an alias is unknown or the
            generation fails.
        """
        import torch

        # Find the selected prompt and model
        prompt_entry = next((p for p in prompts if p["alias"] == prompt_alias), None)
        model_entry = next((m for m in models if m["alias"] == model_alias), None)
        if prompt_entry is None or model_entry is None:
            return None, "ERROR: Invalid prompt or model selected."
        # NOTE(review): model_alias is only validated here; the pipeline
        # loaded in setup() is used regardless of which model was selected.

        # Format the prompt
        team = team_color.lower()
        enemy_color = "blue" if team == "red" else "red"
        prompt = prompt_entry["text"].format(team_color=team, enemy_color=enemy_color)
        extra = (custom_prompt or "").strip()  # tolerate None from the caller
        if extra:
            prompt += " " + extra

        # Set seed
        seed = random.randint(0, 2**32 - 1)
        torch.manual_seed(seed)

        # Generate the image
        try:
            result = self.pipe(
                prompt,
                guidance_scale=2.0,
                num_inference_steps=20,
                width=640,
                height=360,
                generator=torch.Generator(self.device).manual_seed(seed),
            ).images[0]

            # Save the image
            # NOTE(review): the file is written relative to the working
            # directory of the process running this method; confirm the
            # caller can actually read the returned path.
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            output_filename = f"{timestamp}_{model_alias.replace(' ', '_').lower()}_{prompt_alias.replace(' ', '_').lower()}_{team}.png"
            result.save(output_filename)

            return output_filename, "Image generated successfully!"
        except Exception as e:
            # Boundary of the remote call: report the failure to the caller
            # as a message instead of raising.
            return None, f"ERROR: Failed to generate image. Details: {e}"

# Entry point invoked by the Gradio interface
def generate(prompt_alias, team_color, model_alias, custom_prompt):
    """Forward the request to ``Model.generate`` via ``.remote`` and return its result."""
    remote_generate = Model().generate.remote
    return remote_generate(prompt_alias, team_color, model_alias, custom_prompt)