Spaces:
Sleeping
Sleeping
File size: 5,254 Bytes
4f48282 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 |
import io
import os
import random
from datetime import datetime
from typing import Optional

import modal

from config.config import models, prompts
volume = modal.Volume.from_name("flux-model-vol")
# Define the Modal image
image = (modal.Image.debian_slim(python_version="3.9")
.pip_install(
"ninja",
"packaging",
"wheel",
"diffusers", # For Stable Diffusion
"transformers", # For Hugging Face models
"torch>=2.0.1", # PyTorch with a minimum version
"accelerate", # For distributed training/inference
"gradio", # For the Gradio interface
"safetensors", # For safe model loading
"pillow", # For image processing
"datasets", # For datasets (if needed)
)
)
app = modal.App("ctb-ai-img-gen-modal", image=image)
class Inference:
@modal.enter()
def load_pipeline(self):
import torch
from diffusers import StableDiffusionPipeline
# Debug function to check installed packages
def check_dependencies():
import importlib
packages = [
"diffusers", # For Stable Diffusion
"transformers", # For Hugging Face models
"torch", # PyTorch
"accelerate", # For distributed training/inference
"gradio>=4.44.1", # For the Gradio interface (updated to latest version)
"safetensors", # For safe model loading
"pillow", # For image processing
]
for package in packages:
try:
module = importlib.import_module(package)
print(f"✅ {package} is installed. Version: {module.__version__}")
except ImportError:
print(f"❌ {package} is NOT installed.")
# Check if the directory exists
model_dir = "/volume/FLUX.1-dev"
if not os.path.exists(model_dir):
raise FileNotFoundError(f"Model directory not found at {model_dir}")
print(f"Model directory found at {model_dir}! Proceeding with image generation...")
print("Contents of FLUX.1-dev:")
print(os.listdir(model_dir))
# Load the pipeline
self.model_dir = model_dir
self.device = "cuda"
self.torch_dtype = torch.float16
@modal.method()
def run(
self,
prompt_alias: str,
team_color: str,
model_alias: str,
custom_prompt: str,
height: int = 360,
width: int = 640,
num_inference_steps: int = 20,
guidance_scale: float = 2.0,
seed: int = -1,
) -> tuple[str, str]:
import torch
from diffusers import StableDiffusionPipeline
# Find the selected prompt and model
try:
prompt = next(p for p in prompts if p["alias"] == prompt_alias)["text"]
model_name = next(m for m in models if m["alias"] == model_alias)["name"]
except StopIteration:
return None, "ERROR: Invalid prompt or model selected."
# Determine the enemy color
enemy_color = "blue" if team_color.lower() == "red" else "red"
# Format the prompt
prompt = prompt.format(team_color=team_color.lower(), enemy_color=enemy_color)
# Append custom prompt if provided
if custom_prompt and len(custom_prompt.strip()) > 0:
prompt += " " + custom_prompt.strip()
# Set seed
seed = seed if seed != -1 else random.randint(0, 2**32 - 1)
print("seeding RNG with", seed)
torch.manual_seed(seed)
# Load the pipeline
model_path = os.path.join(self.model_dir, model_name)
pipe = StableDiffusionPipeline.from_pretrained(
model_path,
torch_dtype=self.torch_dtype,
safety_checker=None, # Disable safety checker
feature_extractor=None, # Disable feature extractor
).to(self.device)
# Generate the image
try:
image = pipe(
prompt,
guidance_scale=guidance_scale,
num_inference_steps=num_inference_steps,
width=width,
height=height,
generator=torch.Generator(self.device).manual_seed(seed)
).images[0]
# Save the image with a timestamped filename
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
output_filename = f"{timestamp}_{model_alias.replace(' ', '_').lower()}_{prompt_alias.replace(' ', '_').lower()}_{team_color.lower()}.png"
image.save(output_filename)
return output_filename, "Image generated successfully!"
except Exception as e:
return None, f"ERROR: Failed to generate image. Details: {e}"
# Function to be called from the Gradio interface
def generate(prompt_alias, team_color, model_alias, custom_prompt, height=360, width=640, num_inference_steps=20, guidance_scale=2.0, seed=-1):
from src.img_gen_modal import Inference
try:
# Generate the image
image_path, message = Inference(prompt_alias, team_color, model_alias, custom_prompt, height, width, num_inference_steps, guidance_scale, seed)
return image_path, message
except Exception as e:
return None, f"An error occurred: {e}"
|