import io
import os
import random
from datetime import datetime
from typing import Optional

import modal

from config.config import models, prompts

# Persistent Modal volume holding the model weights; mounted at /volume on the Inference class.
volume = modal.Volume.from_name("flux-model-vol")

# Define the Modal image
image = (
    modal.Image.debian_slim(python_version="3.9")
    .pip_install(
        "ninja",
        "packaging",
        "wheel",
        "diffusers",  # For Stable Diffusion
        "transformers",  # For Hugging Face models
        "torch>=2.0.1",  # PyTorch with a minimum version
        "accelerate",  # For distributed training/inference
        "gradio",  # For the Gradio interface
        "safetensors",  # For safe model loading
        "pillow",  # For image processing
        "datasets",  # For datasets (if needed)
    )
)

app = modal.App("ctb-ai-img-gen-modal", image=image)


# Without @app.cls the @modal.enter/@modal.method decorators are never registered,
# and the volume would not be mounted, so /volume/FLUX.1-dev could not exist.
# NOTE(review): the GPU type is a placeholder choice -- adjust to the model's memory needs.
@app.cls(gpu="A10G", volumes={"/volume": volume})
class Inference:
    """Modal-hosted text-to-image service using diffusers pipelines stored on `volume`."""

    @modal.enter()
    def load_pipeline(self):
        """Container start-up hook: report dependencies, validate the model dir, init state.

        Raises:
            FileNotFoundError: if /volume/FLUX.1-dev is missing from the mounted volume.
        """
        import importlib

        import torch

        # Debug function to check installed packages.
        def check_dependencies():
            # Maps the pip distribution name (shown to the user) to its import name;
            # e.g. "pillow" is imported as "PIL", and a version specifier is not a module name.
            packages = {
                "diffusers": "diffusers",  # For Stable Diffusion
                "transformers": "transformers",  # For Hugging Face models
                "torch": "torch",  # PyTorch
                "accelerate": "accelerate",  # For distributed training/inference
                "gradio": "gradio",  # For the Gradio interface
                "safetensors": "safetensors",  # For safe model loading
                "pillow": "PIL",  # For image processing
            }
            for package, module_name in packages.items():
                try:
                    module = importlib.import_module(module_name)
                except ImportError:
                    print(f"❌ {package} is NOT installed.")
                    continue
                # Not every module exposes __version__; don't let that abort the check.
                version = getattr(module, "__version__", "unknown")
                print(f"✅ {package} is installed. Version: {version}")

        check_dependencies()

        # Check if the directory exists
        model_dir = "/volume/FLUX.1-dev"
        if not os.path.exists(model_dir):
            raise FileNotFoundError(f"Model directory not found at {model_dir}")
        print(f"Model directory found at {model_dir}! Proceeding with image generation...")
        print("Contents of FLUX.1-dev:")
        print(os.listdir(model_dir))

        # Pipeline configuration; the pipeline itself is loaded lazily per model in run().
        self.model_dir = model_dir
        self.device = "cuda"
        self.torch_dtype = torch.float16
        # Single-slot cache: the most recently loaded pipeline and the model name it came from.
        self._pipe = None
        self._pipe_model = None

    def _get_pipeline(self, model_name: str):
        """Return the pipeline for `model_name`, reusing the cached one when it matches."""
        from diffusers import StableDiffusionPipeline

        if self._pipe is None or self._pipe_model != model_name:
            model_path = os.path.join(self.model_dir, model_name)
            # Drop the reference to the previous pipeline before loading the next one,
            # and so a failed load never leaves a stale pipeline marked as current.
            self._pipe = None
            # NOTE(review): the weights live under "FLUX.1-dev"; FLUX checkpoints normally
            # need diffusers' FluxPipeline rather than StableDiffusionPipeline -- confirm
            # what `model_name` entries in config actually point to.
            self._pipe = StableDiffusionPipeline.from_pretrained(
                model_path,
                torch_dtype=self.torch_dtype,
                safety_checker=None,  # Disable safety checker
                feature_extractor=None,  # Disable feature extractor
            ).to(self.device)
            self._pipe_model = model_name
        return self._pipe

    @modal.method()
    def run(
        self,
        prompt_alias: str,
        team_color: str,
        model_alias: str,
        custom_prompt: str,
        height: int = 360,
        width: int = 640,
        num_inference_steps: int = 20,
        guidance_scale: float = 2.0,
        seed: int = -1,
    ) -> tuple[Optional[str], str]:
        """Generate one image from a configured prompt/model and save it as a PNG.

        Args:
            prompt_alias: alias of an entry in `config.config.prompts`.
            team_color: team colour substituted into the prompt; the enemy colour is
                "blue" when this is "red" (case-insensitive), otherwise "red".
            model_alias: alias of an entry in `config.config.models`.
            custom_prompt: optional extra text appended to the formatted prompt.
            height, width, num_inference_steps, guidance_scale: passed to the pipeline.
            seed: RNG seed; -1 picks a random 32-bit seed.

        Returns:
            (filename, message) on success, or (None, "ERROR: ...") on failure.
        """
        import torch

        # Find the selected prompt and model
        try:
            prompt_template = next(p for p in prompts if p["alias"] == prompt_alias)["text"]
            model_name = next(m for m in models if m["alias"] == model_alias)["name"]
        except StopIteration:
            return None, "ERROR: Invalid prompt or model selected."

        team = team_color.lower()
        enemy_color = "blue" if team == "red" else "red"
        prompt = prompt_template.format(team_color=team, enemy_color=enemy_color)

        # Append custom prompt if provided (whitespace-only text is ignored).
        if custom_prompt and custom_prompt.strip():
            prompt += " " + custom_prompt.strip()

        if seed == -1:
            seed = random.randint(0, 2**32 - 1)
        print("seeding RNG with", seed)
        torch.manual_seed(seed)

        # Pipeline loading is inside the try so load failures also follow the
        # (None, "ERROR: ...") contract instead of escaping as an exception.
        try:
            pipe = self._get_pipeline(model_name)
            image = pipe(
                prompt,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps,
                width=width,
                height=height,
                generator=torch.Generator(self.device).manual_seed(seed),
            ).images[0]

            # Save the image with a timestamped filename
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            model_slug = model_alias.replace(" ", "_").lower()
            prompt_slug = prompt_alias.replace(" ", "_").lower()
            output_filename = f"{timestamp}_{model_slug}_{prompt_slug}_{team}.png"
            # NOTE(review): this writes into the remote container's working directory;
            # the returned path is not readable from the caller's machine. Consider
            # returning image bytes or writing to the volume instead.
            image.save(output_filename)
        except Exception as e:
            return None, f"ERROR: Failed to generate image. Details: {e}"
        return output_filename, "Image generated successfully!"


# Function to be called from the Gradio interface
def generate(prompt_alias, team_color, model_alias, custom_prompt, height=360, width=640, num_inference_steps=20, guidance_scale=2.0, seed=-1):
    """Run `Inference.run` remotely and return its (image_path, message) pair.

    Any exception (including Modal connectivity errors) is turned into
    (None, "An error occurred: ...") so the UI always receives a 2-tuple.
    """
    from src.img_gen_modal import Inference

    try:
        # Instantiate the Modal class (no constructor args) and invoke the remote method;
        # the original passed the arguments to the constructor and never called run().
        # NOTE(review): .remote() needs a running app context (app.run() / deployment).
        image_path, message = Inference().run.remote(
            prompt_alias=prompt_alias,
            team_color=team_color,
            model_alias=model_alias,
            custom_prompt=custom_prompt,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            seed=seed,
        )
        return image_path, message
    except Exception as e:
        return None, f"An error occurred: {e}"