import io
import random
import time
from pathlib import Path

import modal

MINUTES = 60

app = modal.App("example-text-to-image")

CACHE_DIR = "/cache"

image = (
    modal.Image.debian_slim(python_version="3.12")
    .pip_install(
        "accelerate==0.33.0",
        "diffusers==0.31.0",
        "fastapi[standard]==0.115.4",
        "huggingface-hub[hf_transfer]==0.25.2",
        "sentencepiece==0.2.0",
        "torch==2.5.1",
        "torchvision==0.20.1",
        "transformers~=4.44.0",
    )
    .env(
        {
            "HF_HUB_ENABLE_HF_TRANSFER": "1",  # faster downloads
            "HF_HUB_CACHE": CACHE_DIR,  # huggingface_hub reads HF_HUB_CACHE, not HF_HUB_CACHE_DIR
        }
    )
)

with image.imports():  # defer these imports to the container, where the image above provides them
    import diffusers
    import torch
    from fastapi import Response

MODEL_ID = "adamo1139/stable-diffusion-3.5-large-turbo-ungated"
MODEL_REVISION_ID = "9ad870ac0b0e5e48ced156bb02f85d324b7275d2"

cache_volume = modal.Volume.from_name("hf-hub-cache", create_if_missing=True)


@app.cls(
    image=image,
    gpu="H100",  # assumption: any modern CUDA GPU works; pick what your workspace offers
    timeout=10 * MINUTES,
    volumes={CACHE_DIR: cache_volume},  # persist downloaded weights across containers
)
class Inference:
    @modal.enter()  # runs once at container startup, before any requests arrive
    def load_pipeline(self):
        self.pipe = diffusers.StableDiffusion3Pipeline.from_pretrained(
            MODEL_ID,
            revision=MODEL_REVISION_ID,
            torch_dtype=torch.bfloat16,
        ).to("cuda")

    @modal.method()  # needed so callers can use .remote() / .local() on this method
    def run(
        self, prompt: str, batch_size: int = 4, seed: int | None = None
    ) -> list[bytes]:
        seed = seed if seed is not None else random.randint(0, 2**32 - 1)
        print("seeding RNG with", seed)
        torch.manual_seed(seed)
        images = self.pipe(
            prompt,
            num_images_per_prompt=batch_size,  # outputting multiple images per prompt is much cheaper than separate calls
            num_inference_steps=4,  # turbo is tuned to run in four steps
            guidance_scale=0.0,  # turbo doesn't use CFG
            max_sequence_length=512,  # the T5-XXL text encoder supports longer sequences for more complex prompts
        ).images

        image_output = []
        for image in images:
            with io.BytesIO() as buf:
                image.save(buf, format="PNG")
                image_output.append(buf.getvalue())
        torch.cuda.empty_cache()  # reduce fragmentation
        return image_output

    @modal.fastapi_endpoint(docs=True)  # expose this method over HTTPS
    def web(self, prompt: str, seed: int | None = None):
        return Response(
            content=self.run.local(  # run in the same container, skipping a remote round-trip
                prompt, batch_size=1, seed=seed
            )[0],
            media_type="image/png",
        )
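

# Hypothetical usage sketch: once deployed with `modal deploy`, the `web` endpoint
# above gets a public URL. The hostname below is an assumption based on Modal's
# <workspace>--<app>-<class>-<method> naming; check the deploy output for the real one:
#
#   curl -o pony.png "https://<your-workspace>--example-text-to-image-inference-web.modal.run/?prompt=a%20pony"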


@app.local_entrypoint()  # lets `modal run` drive this function from the CLI
def entrypoint(
    samples: int = 4,
    prompt: str = "A princess riding on a pony",
    batch_size: int = 4,
    seed: int | None = None,
):
    print(
        f"prompt => {prompt}",
        f"samples => {samples}",
        f"batch_size => {batch_size}",
        f"seed => {seed}",
        sep="\n",
    )

    output_dir = Path("/tmp/stable-diffusion")
    output_dir.mkdir(exist_ok=True, parents=True)

    inference_service = Inference()

    for sample_idx in range(samples):
        start = time.time()
        images = inference_service.run.remote(prompt, batch_size, seed)
        duration = time.time() - start
        print(f"Run {sample_idx + 1} took {duration:.3f}s")
        if sample_idx:  # skip the first run, whose timing is skewed by cold start
            print(
                f"\tGenerated {len(images)} image(s) at {duration / len(images):.3f}s / image."
            )
        for batch_idx, image_bytes in enumerate(images):
            output_path = (
                output_dir
                / f"output_{slugify(prompt)[:64]}_{str(sample_idx).zfill(2)}_{str(batch_idx).zfill(2)}.png"
            )
            if not batch_idx:  # print the header once per batch
                print("Saving outputs", end="\n\t")
            print(
                output_path,
                end="\n" + ("\t" if batch_idx < len(images) - 1 else ""),
            )
            output_path.write_bytes(image_bytes)
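

# `entrypoint` above references `slugify`, which never appears in this file. A minimal
# sketch that produces filesystem-safe filenames (an assumption, not the original helper):
def slugify(s: str) -> str:
    return "".join(c if c.isalnum() else "-" for c in s).strip("-")


# To try it locally-driven (assuming this file is saved as text_to_image.py):
#
#   modal run text_to_image.py --prompt "a watercolor pony" --samples 2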