import io
import random
import time
from pathlib import Path

import modal

MINUTES = 60

app = modal.App("example-text-to-image")

CACHE_DIR = "/cache"
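# A usage sketch (not part of the original listing; the filename is assumed,
# use whatever this script is saved as):
#
#   modal run text_to_image.py --prompt "a watercolor fox"  # batch generation via the local entrypoint
#   modal serve text_to_image.py                            # hot-reloading dev server for the web endpoint
#   modal deploy text_to_image.py                           # long-lived deployment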
image = (
    modal.Image.debian_slim(python_version="3.12")
    .pip_install(
        "accelerate==0.33.0",
        "diffusers==0.31.0",
        "fastapi[standard]==0.115.4",
        "huggingface-hub[hf_transfer]==0.25.2",
        "sentencepiece==0.2.0",
        "torch==2.5.1",
        "torchvision==0.20.1",
        "transformers~=4.44.0",
    )
    .env(
        {
            "HF_HUB_ENABLE_HF_TRANSFER": "1",  # faster downloads
            "HF_HUB_CACHE": CACHE_DIR,  # huggingface_hub reads HF_HUB_CACHE, not HF_HUB_CACHE_DIR
        }
    )
)
with image.imports():  # defer these imports so they only run inside containers built from this image
    import diffusers
    import torch
    from fastapi import Response
MODEL_ID = "adamo1139/stable-diffusion-3.5-large-turbo-ungated"
MODEL_REVISION_ID = "9ad870ac0b0e5e48ced156bb02f85d324b7275d2"

cache_volume = modal.Volume.from_name("hf-hub-cache", create_if_missing=True)
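# The named Volume persists the Hugging Face cache across containers, so the
# multi-gigabyte SD3.5 weights are downloaded once rather than on every cold start.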

@app.cls(
    image=image,
    gpu="H100",  # the GPU type is a deployment choice; any card with enough memory for SD3.5 Large works
    volumes={CACHE_DIR: cache_volume},  # mount the cache Volume at the path HF_HUB_CACHE points to
    timeout=10 * MINUTES,
)
class Inference:
    @modal.enter()  # runs once per container start, so weights are loaded before any requests arrive
    def load_pipeline(self):
        self.pipe = diffusers.StableDiffusion3Pipeline.from_pretrained(
            MODEL_ID,
            revision=MODEL_REVISION_ID,
            torch_dtype=torch.bfloat16,
        ).to("cuda")
    @modal.method()  # makes the function callable remotely via .remote()
    def run(
        self, prompt: str, batch_size: int = 4, seed: int | None = None
    ) -> list[bytes]:
        seed = seed if seed is not None else random.randint(0, 2**32 - 1)
        print("seeding RNG with", seed)
        torch.manual_seed(seed)
        images = self.pipe(
            prompt,
            num_images_per_prompt=batch_size,  # outputting multiple images per prompt is much cheaper than separate calls
            num_inference_steps=4,  # turbo is tuned to run in four steps
            guidance_scale=0.0,  # turbo doesn't use CFG
            max_sequence_length=512,  # the T5-XXL text encoder supports longer sequences for more complex prompts
        ).images

        image_output = []
        for image in images:
            with io.BytesIO() as buf:
                image.save(buf, format="PNG")
                image_output.append(buf.getvalue())
        torch.cuda.empty_cache()  # reduce fragmentation
        return image_output
    @modal.fastapi_endpoint(docs=True)  # serve this method over HTTP, with interactive /docs
    def web(self, prompt: str, seed: int | None = None):
        return Response(
            content=self.run.local(  # run in the same container as the web handler
                prompt, batch_size=1, seed=seed
            )[0],
            media_type="image/png",
        )
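
# A hedged client sketch: once deployed, the endpoint above returns a PNG over
# plain HTTP. The URL shape below is an assumption; Modal prints the real URL
# when the app is served or deployed.
#
#   curl -o out.png "https://<workspace>--example-text-to-image-inference-web.modal.run/?prompt=a%20watercolor%20fox"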

@app.local_entrypoint()  # turns the keyword arguments below into CLI options for `modal run`
def entrypoint(
    samples: int = 4,
    prompt: str = "A princess riding on a pony",
    batch_size: int = 4,
    seed: int | None = None,
):
    print(
        f"prompt => {prompt}",
        f"samples => {samples}",
        f"batch_size => {batch_size}",
        f"seed => {seed}",
        sep="\n",
    )

    output_dir = Path("/tmp/stable-diffusion")
    output_dir.mkdir(exist_ok=True, parents=True)

    inference_service = Inference()

    for sample_idx in range(samples):
        start = time.time()
        images = inference_service.run.remote(prompt, batch_size, seed)
        duration = time.time() - start
        print(f"Run {sample_idx + 1} took {duration:.3f}s")
        if sample_idx:  # skip the first run, which includes cold-start time, when reporting throughput
            print(
                f"\tGenerated {len(images)} image(s) at {duration / len(images):.3f}s / image."
            )
        for batch_idx, image_bytes in enumerate(images):
            output_path = (
                output_dir
                / f"output_{slugify(prompt)[:64]}_{str(sample_idx).zfill(2)}_{str(batch_idx).zfill(2)}.png"
            )
            if not batch_idx:
                print("Saving outputs", end="\n\t")
            print(
                output_path,
                end="\n" + ("\t" if batch_idx < len(images) - 1 else ""),
            )
            output_path.write_bytes(image_bytes)
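
# The entrypoint calls `slugify`, but the listing never defines it. A minimal
# sketch consistent with the filenames above (the exact original helper is an
# assumption):
def slugify(s: str) -> str:
    # replace any character that isn't alphanumeric with a hyphen
    return "".join(c if c.isalnum() else "-" for c in s).strip("-")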