Spaces:
Runtime error
ๅนๆ็ใงๅน็็ใชๆกๆฃใขใใซ
[[open-in-colab]]
[DiffusionPipeline
]ใไฝฟใฃใฆ็นๅฎใฎในใฟใคใซใง็ปๅใ็ๆใใใใๅธๆใใ็ปๅใ็ๆใใใใใใฎใฏ้ฃใใใใจใงใใๅคใใฎๅ ดๅใ[DiffusionPipeline
]ใไฝๅบฆใๅฎ่กใใฆใใใงใชใใจๆบ่ถณใฎใใ็ปๅใฏๅพใใใพใใใใใใใไฝใใชใใจใใใใไฝใใ็ๆใใใซใฏใใใใใฎ่จ็ฎใๅฟ
่ฆใงใใ็ๆใไฝๅบฆใไฝๅบฆใๅฎ่กใใๅ ดๅใ็นใซใใใใใฎ่จ็ฎ้ใๅฟ
่ฆใซใชใใพใใ
ใใฎใใใใใคใใฉใคใณใใ่จ็ฎ๏ผ้ๅบฆ๏ผใจใกใขใช๏ผGPU RAM๏ผใฎๅน็ใๆๅคง้ใซๅผใๅบใใ็ๆใตใคใฏใซ้ใฎๆ้ใ็ญ็ธฎใใใใจใงใใใ้ซ้ใชๅๅพฉๅฆ็ใ่กใใใใใซใใใใจใ้่ฆใงใใ
ใใฎใใฅใผใใชใขใซใงใฏใ[DiffusionPipeline
]ใ็จใใฆใใใ้ใใใใ่ฏใ่จ็ฎใ่กใๆนๆณใ่ชฌๆใใพใใ
ใพใใrunwayml/stable-diffusion-v1-5
ใขใใซใใญใผใใใพใ๏ผ
from diffusers import DiffusionPipeline
model_id = "runwayml/stable-diffusion-v1-5"
pipeline = DiffusionPipeline.from_pretrained(model_id, use_safetensors=True)
ใใใงไฝฟ็จใใใใญใณใใใฎไพใฏๅนด่ใใๆฆๅฃซใฎ้ทใฎ่ๅ็ปใงใใใใ่ช็ฑใซๅคๆดใใฆใใ ใใ๏ผ
prompt = "portrait photo of a old warrior chief"
Speed
๐ก GPUใๅฉ็จใงใใชใๅ ดๅใฏใColabใฎใใใชGPUใใญใใคใใผใใ็กๆใงๅฉ็จใงใใพใ๏ผ
็ปๅ็ๆใ้ซ้ๅใใๆใ็ฐกๅใชๆนๆณใฎ1ใคใฏใPyTorchใขใธใฅใผใซใจๅใใใใซGPUไธใซใใคใใฉใคใณใ้ ็ฝฎใใใใจใงใ๏ผ
pipeline = pipeline.to("cuda")
ๅใใคใกใผใธใไฝฟใฃใฆๆน่ฏใงใใใใใซใใใซใฏใGenerator
ใไฝฟใใreproducibilityใฎ็จฎใ่จญๅฎใใพใ๏ผ
import torch
generator = torch.Generator("cuda").manual_seed(0)
ใใใง็ปๅใ็ๆใงใใพใ๏ผ
image = pipeline(prompt, generator=generator).images[0]
image

ใใฎๅฆ็ใซใฏT4 GPUใง~30็งใใใใพใใ๏ผๅฒใๅฝใฆใใใฆใใGPUใT4ใใๅชใใฆใใๅ ดๅใฏใใฃใจ้ใใใใใใพใใ๏ผใใใใฉใซใใงใฏใ[DiffusionPipeline
]ใฏๅฎๅ
จใชfloat32
็ฒพๅบฆใง็ๆใ50ในใใใๅฎ่กใใพใใfloat16`ใฎใใใชไฝใ็ฒพๅบฆใซๅคๆดใใใใๆจ่ซในใใใๆฐใๆธใใใใจใง้ซ้ๅใใใใจใใงใใพใใ
ใพใใฏ float16
ใงใขใใซใใญใผใใใฆ็ปๅใ็ๆใใฆใฟใพใใใ๏ผ
import torch
pipeline = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16, use_safetensors=True)
pipeline = pipeline.to("cuda")
generator = torch.Generator("cuda").manual_seed(0)
image = pipeline(prompt, generator=generator).images[0]
image

ไปๅใ็ปๅ็ๆใซใใใฃใๆ้ใฏใใใ11็งใงใไปฅๅใใ3ๅ่ฟใ้ใใชใใพใใ๏ผ
๐ก ใใคใใฉใคใณใฏๅธธใซ float16
ใงๅฎ่กใใใใจใๅผทใใๅงใใใพใใ
็ๆในใใใๆฐใๆธใใใจใใๆนๆณใใใใพใใใใๅน็็ใชในใฑใธใฅใผใฉใ้ธๆใใใใจใงใๅบๅๅ่ณชใ็ ็ฒใซใใใใจใชใในใใใๆฐใๆธใใใใจใใงใใพใใcompatibles
ใกใฝใใใๅผใณๅบใใใจใงใ[DiffusionPipeline
]ใฎ็พๅจใฎใขใใซใจไบๆๆงใฎใใในใฑใธใฅใผใฉใ่ฆใคใใใใจใใงใใพใ๏ผ
pipeline.scheduler.compatibles
[
diffusers.schedulers.scheduling_lms_discrete.LMSDiscreteScheduler,
diffusers.schedulers.scheduling_unipc_multistep.UniPCMultistepScheduler,
diffusers.schedulers.scheduling_k_dpm_2_discrete.KDPM2DiscreteScheduler,
diffusers.schedulers.scheduling_deis_multistep.DEISMultistepScheduler,
diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler,
diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler,
diffusers.schedulers.scheduling_ddpm.DDPMScheduler,
diffusers.schedulers.scheduling_dpmsolver_singlestep.DPMSolverSinglestepScheduler,
diffusers.schedulers.scheduling_k_dpm_2_ancestral_discrete.KDPM2AncestralDiscreteScheduler,
diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler,
diffusers.schedulers.scheduling_pndm.PNDMScheduler,
diffusers.schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteScheduler,
diffusers.schedulers.scheduling_ddim.DDIMScheduler,
]
Stable Diffusionใขใใซใฏใใใฉใซใใง[PNDMScheduler
]ใไฝฟ็จใใพใใใใฎในใฑใธใฅใผใฉใฏ้ๅธธ50ใฎๆจ่ซในใใใใๅฟ
่ฆใจใใพใใใ[20ใพใใฏ25ใฎๆจ่ซในใใใใงๆธใฟใพใใ[DPMSolverMultistepScheduler
]ใฎใใใช้ซๆง่ฝใชในใฑใธใฅใผใฉใงใฏConfigMixin.from_config
]ใกใฝใใใไฝฟ็จใใใจใๆฐใใในใฑใธใฅใผใฉใใญใผใใใใใจใใงใใพใ๏ผ
from diffusers import DPMSolverMultistepScheduler
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
ใใใง num_inference_steps
ใ20ใซ่จญๅฎใใพใ๏ผ
generator = torch.Generator("cuda").manual_seed(0)
image = pipeline(prompt, generator=generator, num_inference_steps=20).images[0]
image

ๆจ่ซๆ้ใใใใ4็งใซ็ญ็ธฎใใใใจใซๆๅใใ๏ผโก๏ธ
ใกใขใชใผ
ใใคใใฉใคใณใฎใใใฉใผใใณในใๅไธใใใใใ1ใคใฎ้ตใฏใๆถ่ฒปใกใขใชใๅฐใชใใใใใจใงใใไธๅบฆใซ็ๆใงใใ็ปๅใฎๆฐใ็ขบ่ชใใๆใ็ฐกๅใชๆนๆณใฏใOutOfMemoryError
๏ผOOM๏ผใ็บ็ใใใพใงใใใพใใพใชใใใใตใคใบใ่ฉฆใใฆใฟใใใจใงใใ
ๆ็ซ ใจ Generators
ใฎใชในใใใ็ปๅใฎใใใใ็ๆใใ้ขๆฐใไฝๆใใพใใๅ Generator
ใซใทใผใใๅฒใๅฝใฆใฆใ่ฏใ็ตๆใๅพใใใๅ ดๅใซๅๅฉ็จใงใใใใใซใใพใใ
def get_inputs(batch_size=1):
generator = [torch.Generator("cuda").manual_seed(i) for i in range(batch_size)]
prompts = batch_size * [prompt]
num_inference_steps = 20
return {"prompt": prompts, "generator": generator, "num_inference_steps": num_inference_steps}
batch_size=4
ใง้ๅงใใใฉใใ ใใกใขใชใๆถ่ฒปใใใใ็ขบ่ชใใพใ๏ผ
from diffusers.utils import make_image_grid
images = pipeline(**get_inputs(batch_size=4)).images
make_image_grid(images, 2, 2)
ๅคงๅฎน้ใฎRAMใๆญ่ผใใGPUใงใชใ้ใใไธ่จใฎใณใผใใฏใใใใOOM
ใจใฉใผใ่ฟใใใฏใใงใ๏ผใกใขใชใฎๅคงๅใฏใฏใญในใขใใณใทใงใณใฌใคใคใผใๅ ใใฆใใพใใใใฎๅฆ็ใใใใใงๅฎ่กใใไปฃใใใซใ้ๆฌกๅฎ่กใใใใจใงใกใขใชใๅคงๅน
ใซ็ฏ็ดใงใใพใใๅฟ
่ฆใชใฎใฏใ[~DiffusionPipeline.enable_attention_slicing
]้ขๆฐใไฝฟ็จใใใใจใ ใใงใ๏ผ
pipeline.enable_attention_slicing()
ไปๅบฆใฏbatch_size
ใ8ใซใใฆใฟใฆใใ ใใ๏ผ
images = pipeline(**get_inputs(batch_size=8)).images
make_image_grid(images, rows=2, cols=4)

ไปฅๅใฏ4ๆใฎ็ปๅใฎใใใใ็ๆใใใใจใใใงใใพใใใงใใใใไปใงใฏ8ๆใฎ็ปๅใฎใใใใ1ๆใใใ๏ฝ3.5็งใง็ๆใงใใพใ๏ผใใใฏใใใใใๅ่ณชใ็ ็ฒใซใใใใจใชใT4 GPUใงใงใใๆ้ใฎๅฆ็้ๅบฆใงใใ
ๅ่ณช
ๅใฎ2ใคใฎใปใฏใทใงใณใงใฏใfp16
ใไฝฟใฃใฆใใคใใฉใคใณใฎ้ๅบฆใๆ้ฉๅใใๆนๆณใใใใใใฉใผใใณ ในใชในใฑใธใฅใผใฉใผใไฝฟใฃใฆ็ๆในใใใๆฐใๆธใใๆนๆณใใขใใณใทใงใณในใฉใคในใๆๅน ใซใใฆใกใขใชๆถ่ฒป้ใๆธใใๆนๆณใซใคใใฆๅญฆใณใพใใใไปๅบฆใฏใ็ๆใใใ็ปๅใฎๅ่ณชใๅไธใใใๆนๆณใซ็ฆ็นใๅฝใฆใพใใ
ใใ่ฏใใใงใใฏใใคใณใ
ๆใๅ็ดใชในใใใใฏใใใ่ฏใใใงใใฏใใคใณใใไฝฟใใใจใงใใStable Diffusionใขใใซใฏ่ฏใๅบ็บ็นใงใใใๅ ฌๅผ็บ่กจไปฅๆฅใใใใคใใฎๆน่ฏ็ใใชใชใผในใใใฆใใพใใใใใใๆฐใใใใผใธใงใณใไฝฟใฃใใใใจใใฃใฆใ่ชๅ็ใซ่ฏใ็ตๆใๅพใใใใใใงใฏใใใพใใใๆ่ฏใฎ็ตๆใๅพใใใใซใฏใ่ชๅใงใใพใใพใชใใงใใฏใใคใณใใ่ฉฆใใฆใฟใใใใกใใฃใจใใ็ ็ฉถ๏ผใใฌใใฃใใใญใณใใใฎไฝฟ็จใชใฉ๏ผใใใใใใๅฟ ่ฆใใใใพใใ
ใใฎๅ้ใๆ้ทใใใซใคใใฆใ็นๅฎใฎในใฟใคใซใ็ใฟๅบใใใใซๅพฎ่ชฟๆดใใใใใใ่ณชใฎ้ซใใใงใใฏใใคใณใใๅขใใฆใใพใใHubใDiffusers Galleryใๆข็ดขใใฆใ่ๅณใฎใใใใฎใ่ฆใคใใฆใฟใฆใใ ใใ๏ผ
ใใ่ฏใใใคใใฉใคใณใณใณใใผใใณใ
็พๅจใฎใใคใใฉใคใณใณใณใใผใใณใใๆฐใใใใผใธใงใณใซ็ฝฎใๆใใฆใฟใใใจใใงใใพใใStability AIใๆไพใใๆๆฐใฎautodecoderใใใคใใฉใคใณใซใญใผใใใ็ปๅใ็ๆใใฆใฟใพใใใ๏ผ
from diffusers import AutoencoderKL
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16).to("cuda")
pipeline.vae = vae
images = pipeline(**get_inputs(batch_size=8)).images
make_image_grid(images, rows=2, cols=4)

ใใ่ฏใใใญใณใใใปใจใณใธใใขใชใณใฐ
็ปๅใ็ๆใใใใใซไฝฟ็จใใๆ็ซ ใฏใใใญใณใใใจใณใธใใขใชใณใฐใจๅผใฐใใๅ้ใไฝใใใใปใฉใ้ๅธธใซ้่ฆใงใใใใญใณใใใปใจใณใธใใขใชใณใฐใง่ๆ ฎใในใ็นใฏไปฅไธใฎ้ใใงใ๏ผ
- ็ๆใใใ็ปๅใใใฎ้กไผผ็ปๅใฏใใคใณใฟใผใใใไธใซใฉใฎใใใซไฟๅญใใใฆใใใ๏ผ
- ็งใๆใในใฟใคใซใซใขใใซใ่ชๅฐใใใใใซใใฉใฎใใใช่ฟฝๅ ่ฉณ็ดฐใไธใใในใใ๏ผ
ใใฎใใจใๅฟต้ ญใซ็ฝฎใใฆใใใญใณใใใซ่ฒใใใ่ณชใฎ้ซใใใฃใใผใซใๅซใใใใใซๆน่ฏใใฆใฟใพใใใ๏ผ
prompt += ", tribal panther make up, blue on red, side profile, looking away, serious eyes"
prompt += " 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta"
ๆฐใใใใญใณใใใง็ปๅใฎใใใใ็ๆใใพใใใ๏ผ
images = pipeline(**get_inputs(batch_size=8)).images
make_image_grid(images, rows=2, cols=4)

ใใชใใใใงใ๏ผ็จฎใ1
ใฎGenerator
ใซๅฏพๅฟใใ2็ช็ฎใฎ็ปๅใซใ่ขซๅไฝใฎๅนด้ฝขใซ้ขใใใใญในใใ่ฟฝๅ ใใฆใใใๅฐใๆใๅ ใใฆใฟใพใใใ๏ผ
prompts = [
"portrait photo of the oldest warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta",
"portrait photo of a old warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta",
"portrait photo of a warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta",
"portrait photo of a young warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta",
]
generator = [torch.Generator("cuda").manual_seed(1) for _ in range(len(prompts))]
images = pipeline(prompt=prompts, generator=generator, num_inference_steps=25).images
make_image_grid(images, 2, 2)

ๆฌกใฎในใใใ
ใใฎใใฅใผใใชใขใซใงใฏใ[DiffusionPipeline
]ใๆ้ฉๅใใฆ่จ็ฎๅน็ใจใกใขใชๅน็ใๅไธใใใ็ๆใใใๅบๅใฎๅ่ณชใๅไธใใใๆนๆณใๅญฆใณใพใใใใใคใใฉใคใณใใใใซ้ซ้ๅใใใใจใซ่ๅณใใใใฐใไปฅไธใฎใชใฝใผในใๅ็
งใใฆใใ ใใ๏ผ
- PyTorch 2.0ใจ
torch.compile
ใใฉใฎใใใซ็ๆ้ๅบฆใ5-300%้ซ้ๅใงใใใใๅญฆใใงใใ ใใใA100 GPUใฎๅ ดๅใ็ปๅ็ๆใฏๆๅคง50%้ใใชใใพใ๏ผ - PyTorch 2ใไฝฟใใชใๅ ดๅใฏใxFormersใใคใณในใใผใซใใใใจใใๅงใใใพใใใใฎใฉใคใใฉใชใฎใกใขใชๅน็ใฎ่ฏใใขใใณใทใงใณใกใซใใบใ ใฏ PyTorch 1.13.1 ใจ็ธๆงใ่ฏใใ้ซ้ๅใจใกใขใชๆถ่ฒป้ใฎๅๆธใๅๆใซๅฎ็พใใพใใ
- ใขใใซใฎใชใใญใผใใชใฉใใใฎไปใฎๆ้ฉๅใใฏใใใฏใฏ this guide ใงใซใใผใใใฆใใพใใ