# CtB-AI-img-gen/examples/example-text-to-image.py
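# Text-to-image generation with Stable Diffusion 3.5 Large Turbo on Modal.
# Run the batch entrypoint with `modal run example-text-to-image.py`,
# or serve the web endpoint during development with `modal serve example-text-to-image.py`.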
import io
import random
import time
from pathlib import Path
import modal
MINUTES = 60
app = modal.App("example-text-to-image")
CACHE_DIR = "/cache"
image = (
    modal.Image.debian_slim(python_version="3.12")
    .pip_install(
        "accelerate==0.33.0",
        "diffusers==0.31.0",
        "fastapi[standard]==0.115.4",
        "huggingface-hub[hf_transfer]==0.25.2",
        "sentencepiece==0.2.0",
        "torch==2.5.1",
        "torchvision==0.20.1",
        "transformers~=4.44.0",
    )
    .env(
        {
            "HF_HUB_ENABLE_HF_TRANSFER": "1",  # faster downloads from the Hugging Face Hub
            "HF_HUB_CACHE": CACHE_DIR,  # huggingface_hub reads HF_HUB_CACHE; HF_HUB_CACHE_DIR is not a recognized variable
        }
    )
)
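# Imports inside `image.imports()` run only inside the container,
# so the local client doesn't need torch or diffusers installed.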
with image.imports():
    import diffusers
    import torch
    from fastapi import Response
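# Pin both the model and an exact revision so results stay reproducible
# even if the upstream repo changes.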
MODEL_ID = "adamo1139/stable-diffusion-3.5-large-turbo-ungated"
MODEL_REVISION_ID = "9ad870ac0b0e5e48ced156bb02f85d324b7275d2"
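# A persisted Volume caches the downloaded weights across container starts,
# so the multi-gigabyte download only happens once.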
cache_volume = modal.Volume.from_name("hf-hub-cache", create_if_missing=True)
@app.cls(
    image=image,
    gpu="H100",
    timeout=10 * MINUTES,  # allow time for the model download on first run
    volumes={CACHE_DIR: cache_volume},
)
class Inference:
    @modal.enter()
    def load_pipeline(self):
        # Runs once per container start: load weights (from cache if present) and move them to the GPU.
        self.pipe = diffusers.StableDiffusion3Pipeline.from_pretrained(
            MODEL_ID,
            revision=MODEL_REVISION_ID,
            torch_dtype=torch.bfloat16,
        ).to("cuda")
    @modal.method()
    def run(
        self, prompt: str, batch_size: int = 4, seed: int | None = None
    ) -> list[bytes]:
        seed = seed if seed is not None else random.randint(0, 2**32 - 1)
        print("seeding RNG with", seed)
        torch.manual_seed(seed)

        images = self.pipe(
            prompt,
            num_images_per_prompt=batch_size,  # generating multiple images per prompt is much cheaper than separate calls
            num_inference_steps=4,  # the turbo model is tuned to run in four steps
            guidance_scale=0.0,  # turbo doesn't use classifier-free guidance
            max_sequence_length=512,  # the T5-XXL text encoder supports longer, more complex prompts
        ).images

        image_output = []
        for image in images:
            with io.BytesIO() as buf:
                image.save(buf, format="PNG")
                image_output.append(buf.getvalue())
        torch.cuda.empty_cache()  # reduce memory fragmentation between calls

        return image_output
    @modal.web_endpoint(docs=True)  # docs=True serves interactive OpenAPI docs at /docs
    def web(self, prompt: str, seed: int | None = None):
        return Response(
            content=self.run.local(  # run in the same container as the web handler
                prompt, batch_size=1, seed=seed
            )[0],
            media_type="image/png",
        )
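
# Once deployed with `modal deploy`, the endpoint is served at a URL along the lines of
# https://<workspace>--example-text-to-image-inference-web.modal.run (hypothetical slug;
# the exact URL depends on your workspace and app name). For example:
#   curl "https://<workspace>--example-text-to-image-inference-web.modal.run?prompt=a+cat+in+a+hat" -o cat.png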
@app.local_entrypoint()
def entrypoint(
    samples: int = 4,
    prompt: str = "A princess riding on a pony",
    batch_size: int = 4,
    seed: int | None = None,
):
    print(
        f"prompt => {prompt}",
        f"samples => {samples}",
        f"batch_size => {batch_size}",
        f"seed => {seed}",
        sep="\n",
    )

    output_dir = Path("/tmp/stable-diffusion")
    output_dir.mkdir(exist_ok=True, parents=True)

    inference_service = Inference()

    for sample_idx in range(samples):
        start = time.time()
        images = inference_service.run.remote(prompt, batch_size, seed)
        duration = time.time() - start
        print(f"Run {sample_idx + 1} took {duration:.3f}s")
        if sample_idx:  # skip the first run, which includes container cold-start time
            print(
                f"\tGenerated {len(images)} image(s) at {duration / len(images):.3f}s / image."
            )
        for batch_idx, image_bytes in enumerate(images):
            output_path = (
                output_dir
                / f"output_{slugify(prompt)[:64]}_{str(sample_idx).zfill(2)}_{str(batch_idx).zfill(2)}.png"
            )
            if not batch_idx:
                print("Saving outputs", end="\n\t")
            print(
                output_path,
                end="\n" + ("\t" if batch_idx < len(images) - 1 else ""),
            )
            output_path.write_bytes(image_bytes)
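
# The entrypoint references `slugify`, which this file never defines.
# A minimal sketch of such a helper (assumed implementation, not from the original):
def slugify(s: str) -> str:
    # Keep alphanumerics, replace everything else with hyphens, and trim the ends.
    return "".join(c if c.isalnum() else "-" for c in s).strip("-")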