import io
import random
import time
from pathlib import Path

import modal

MINUTES = 60

app = modal.App("example-text-to-image")

CACHE_DIR = "/cache"

image = (
    modal.Image.debian_slim(python_version="3.12")
    .pip_install(
        "accelerate==0.33.0",
        "diffusers==0.31.0",
        "fastapi[standard]==0.115.4",
        "huggingface-hub[hf_transfer]==0.25.2",
        "sentencepiece==0.2.0",
        "torch==2.5.1",
        "torchvision==0.20.1",
        "transformers~=4.44.0",
    )
    .env(
        {
            "HF_HUB_ENABLE_HF_TRANSFER": "1",  # faster downloads
            "HF_HUB_CACHE": CACHE_DIR,  # env var read by huggingface_hub for its cache location
        }
    )
)

with image.imports():
    import diffusers
    import torch
    from fastapi import Response

MODEL_ID = "adamo1139/stable-diffusion-3.5-large-turbo-ungated"
MODEL_REVISION_ID = "9ad870ac0b0e5e48ced156bb02f85d324b7275d2"

cache_volume = modal.Volume.from_name("hf-hub-cache", create_if_missing=True)


@app.cls(
    image=image,
    gpu="H100",
    timeout=10 * MINUTES,
    volumes={CACHE_DIR: cache_volume},
)
class Inference:
    @modal.enter()
    def load_pipeline(self):
        self.pipe = diffusers.StableDiffusion3Pipeline.from_pretrained(
            MODEL_ID,
            revision=MODEL_REVISION_ID,
            torch_dtype=torch.bfloat16,
        ).to("cuda")

    @modal.method()
    def run(
        self, prompt: str, batch_size: int = 4, seed: int | None = None
    ) -> list[bytes]:
        seed = seed if seed is not None else random.randint(0, 2**32 - 1)
        print("seeding RNG with", seed)
        torch.manual_seed(seed)

        images = self.pipe(
            prompt,
            num_images_per_prompt=batch_size,  # outputting multiple images per prompt is much cheaper than separate calls
            num_inference_steps=4,  # turbo is tuned to run in four steps
            guidance_scale=0.0,  # turbo doesn't use CFG
            max_sequence_length=512,  # T5-XXL text encoder supports longer sequences, more complex prompts
        ).images

        image_output = []
        for image in images:
            with io.BytesIO() as buf:
                image.save(buf, format="PNG")
                image_output.append(buf.getvalue())
        torch.cuda.empty_cache()  # reduce fragmentation
        return image_output

    @modal.web_endpoint(docs=True)
    def web(self, prompt: str, seed: int | None = None):
        return Response(
            content=self.run.local(  # run in the same container
                prompt, batch_size=1, seed=seed
            )[0],
            media_type="image/png",
        )


@app.local_entrypoint()
def entrypoint(
    samples: int = 4,
    prompt: str = "A princess riding on a pony",
    batch_size: int = 4,
    seed: int | None = None,
):
    print(
        f"prompt => {prompt}",
        f"samples => {samples}",
        f"batch_size => {batch_size}",
        f"seed => {seed}",
        sep="\n",
    )

    output_dir = Path("/tmp/stable-diffusion")
    output_dir.mkdir(exist_ok=True, parents=True)

    inference_service = Inference()

    for sample_idx in range(samples):
        start = time.time()
        images = inference_service.run.remote(prompt, batch_size, seed)
        duration = time.time() - start
        print(f"Run {sample_idx + 1} took {duration:.3f}s")
        if sample_idx:  # skip the first run, which includes cold-start overhead
            print(
                f"\tGenerated {len(images)} image(s) at {duration / len(images):.3f}s / image."
            )

        for batch_idx, image_bytes in enumerate(images):
            output_path = (
                output_dir
                / f"output_{slugify(prompt)[:64]}_{str(sample_idx).zfill(2)}_{str(batch_idx).zfill(2)}.png"
            )
            if not batch_idx:
                print("Saving outputs", end="\n\t")
            print(
                output_path,
                end="\n" + ("\t" if batch_idx < len(images) - 1 else ""),
            )
            output_path.write_bytes(image_bytes)
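

# The entrypoint above calls `slugify` but the script never defines or imports it.
# A minimal sketch that makes the script runnable (assumption: any character that
# isn't alphanumeric becomes a hyphen, producing a filename-safe slug):
def slugify(s: str) -> str:
    return "".join(c if c.isalnum() else "-" for c in s).strip("-")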
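

# Usage sketch (assumes this file is saved as text_to_image.py and the Modal CLI
# is installed and authenticated):
#
#   modal run text_to_image.py --prompt "A princess riding on a pony" --samples 2
#   modal serve text_to_image.py  # serves the `web` endpoint; interactive docs at /docs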