import io
import random
import time
from pathlib import Path

import modal

MINUTES = 60

app = modal.App("example-text-to-image")

CACHE_DIR = "/cache"
image = (
    modal.Image.debian_slim(python_version="3.12")
    .pip_install(
        "accelerate==0.33.0",
        "diffusers==0.31.0",
        "fastapi[standard]==0.115.4",
        "huggingface-hub[hf_transfer]==0.25.2",
        "sentencepiece==0.2.0",
        "torch==2.5.1",
        "torchvision==0.20.1",
        "transformers~=4.44.0",
    )
    .env(
        {
            "HF_HUB_ENABLE_HF_TRANSFER": "1",  # faster downloads
            "HF_HUB_CACHE": CACHE_DIR,  # point the Hugging Face Hub cache at the shared volume
        }
    )
)
with image.imports():
    import diffusers
    import torch
    from fastapi import Response
MODEL_ID = "adamo1139/stable-diffusion-3.5-large-turbo-ungated"
MODEL_REVISION_ID = "9ad870ac0b0e5e48ced156bb02f85d324b7275d2"
cache_volume = modal.Volume.from_name("hf-hub-cache", create_if_missing=True)

@app.cls(
    image=image,
    gpu="H100",
    timeout=10 * MINUTES,
    volumes={CACHE_DIR: cache_volume},
)
class Inference:
    @modal.enter()
    def load_pipeline(self):
        self.pipe = diffusers.StableDiffusion3Pipeline.from_pretrained(
            MODEL_ID,
            revision=MODEL_REVISION_ID,
            torch_dtype=torch.bfloat16,
        ).to("cuda")

    @modal.method()
    def run(
        self, prompt: str, batch_size: int = 4, seed: int = None
    ) -> list[bytes]:
        seed = seed if seed is not None else random.randint(0, 2**32 - 1)
        print("seeding RNG with", seed)
        torch.manual_seed(seed)

        images = self.pipe(
            prompt,
            num_images_per_prompt=batch_size,  # outputting multiple images per prompt is much cheaper than separate calls
            num_inference_steps=4,  # turbo is tuned to run in four steps
            guidance_scale=0.0,  # turbo doesn't use CFG
            max_sequence_length=512,  # T5-XXL text encoder supports longer sequences, more complex prompts
        ).images

        image_output = []
        for image in images:
            with io.BytesIO() as buf:
                image.save(buf, format="PNG")
                image_output.append(buf.getvalue())
        torch.cuda.empty_cache()  # reduce fragmentation
        return image_output

    @modal.web_endpoint(docs=True)
    def web(self, prompt: str, seed: int = None):
        return Response(
            content=self.run.local(  # run in the same container
                prompt, batch_size=1, seed=seed
            )[0],
            media_type="image/png",
        )

# `modal run` on this file invokes this entrypoint; its parameters become CLI flags
# (e.g. --prompt, --samples, --batch-size, --seed).
@app.local_entrypoint()
def entrypoint(
    samples: int = 4,
    prompt: str = "A princess riding on a pony",
    batch_size: int = 4,
    seed: int = None,
):
    print(
        f"prompt => {prompt}",
        f"samples => {samples}",
        f"batch_size => {batch_size}",
        f"seed => {seed}",
        sep="\n",
    )

    output_dir = Path("/tmp/stable-diffusion")
    output_dir.mkdir(exist_ok=True, parents=True)

    inference_service = Inference()

    for sample_idx in range(samples):
        start = time.time()
        images = inference_service.run.remote(prompt, batch_size, seed)
        duration = time.time() - start
        print(f"Run {sample_idx + 1} took {duration:.3f}s")
        if sample_idx:  # skip the first run's per-image timing, which includes cold-start overhead
            print(
                f"\tGenerated {len(images)} image(s) at {duration / len(images):.3f}s / image."
            )
        for batch_idx, image_bytes in enumerate(images):
            output_path = (
                output_dir
                / f"output_{slugify(prompt)[:64]}_{str(sample_idx).zfill(2)}_{str(batch_idx).zfill(2)}.png"
            )
            if not batch_idx:
                print("Saving outputs", end="\n\t")
            print(
                output_path,
                end="\n" + ("\t" if batch_idx < len(images) - 1 else ""),
            )
            output_path.write_bytes(image_bytes)
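
# `slugify` is used above but not defined in this excerpt. A minimal stand-in
# (an assumption, not necessarily the upstream helper) that makes the prompt
# safe to use in a filename:
def slugify(s: str) -> str:
    return "".join(c if c.isalnum() else "-" for c in s).strip("-")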