#!/usr/bin/env python3
# tools/save_lora.py
# Last change by patrickvonplaten: "rename" (commit 27dfa17)
import torch
from warnings import warn
from diffusers import (
AutoencoderKL,
DiffusionPipeline,
)
import hashlib
# Base SDXL checkpoint and the LoRA adapters exercised by inference() below.
base = "stabilityai/stable-diffusion-xl-base-1.0"
adapter1 = 'nerijs/pixel-art-xl'
weightname1 = 'pixel-art-xl.safetensors'
adapter2 = 'Alexzyx/lora-trained-xl-colab'
# No explicit weight file for adapter2; presumably load_lora_weights then
# picks the repository's default file name (confirm in the diffusers docs).
weightname2 = None
# Prompt used for every generation.
inputs = "elephant"

# Extra keyword arguments for DiffusionPipeline.from_pretrained.
kwargs = {}
if torch.cuda.is_available():
    # Load the pipeline in half precision when a GPU is available.
    kwargs["torch_dtype"] = torch.float16
    # Disabled option: swap in an fp16-specific SDXL VAE.
    # NOTE(review): presumably kept to compare against black-image results
    # from running the stock VAE in float16 — enable to test that theory.
    # vae = AutoencoderKL.from_pretrained(
    #     "madebyollin/sdxl-vae-fp16-fix",
    #     torch_dtype=torch.float16,  # load fp16 fix VAE
    # )
    # kwargs["vae"] = vae
    # kwargs["variant"] = "fp16"

# Shared pipeline that inference() loads LoRA weights into.
model = DiffusionPipeline.from_pretrained(base, **kwargs)
if torch.cuda.is_available():
    model.to("cuda")
def inference(adapter, weightname, filename='/tmp/hello.jpg'):
    """Generate one image with a LoRA applied and print the image's md5.

    Loads ``adapter`` into the module-level ``model`` and tries to fuse it.
    Renders the module-level prompt ``inputs`` with a single inference
    step, then unfuses and unloads the LoRA again. The image is saved as
    JPEG to ``filename`` and the md5 of that file is printed. An extra line
    is printed when the digest equals the hard-coded black-image digest.

    Args:
        adapter: LoRA source passed to ``model.load_lora_weights``
            (a Hub repo id in this script's calls).
        weightname: ``weight_name`` passed to ``load_lora_weights``. This
            script also passes ``None`` (presumably "use the repository's
            default weight file" — confirm in the diffusers docs).
        filename: Path the JPEG is written to. Defaults to the path the
            script has always used; the file is overwritten on every call.

    Returns:
        The hex md5 digest of the saved JPEG file.
    """
    model.load_lora_weights(adapter, weight_name=weightname)
    try:
        try:
            model.fuse_lora(safe_fusing=True)
        except ValueError:
            warn(f"{adapter} and {weightname} is broken. LoRA is not fused.")
            # NOTE(review): a failed safe fuse may already have merged the
            # layers visited before the offending one — confirm for the
            # installed diffusers version. Unfusing before unloading reverts
            # any such partial merge, so the image below is rendered from
            # the base weights only.
            model.unfuse_lora()
            model.unload_lora_weights()
        data = model(inputs, num_inference_steps=1).images[0]
    finally:
        # Always restore the base weights, even if generation raised, so a
        # failure here does not leak this LoRA into the next call. On the
        # fuse-failure path these repeat the calls above, as the original
        # flow also did.
        model.unfuse_lora()
        model.unload_lora_weights()

    data.save(filename, format='jpeg')
    with open(filename, 'rb') as f:
        md5 = hashlib.md5(f.read()).hexdigest()
    print(f"Adapter {adapter}, md5sum {md5}")
    # Hard-coded digest that the script treats as a black output image (see
    # the message below). Presumably valid only for this pipeline's output
    # size and PIL's JPEG encoder — TODO confirm.
    if md5 == '40c78c9fd4daeff01c988c3532fdd51b':
        print(f"BLACK SCREEN IMAGE for adapter {adapter}")
    return md5
# Run the adapters in a fixed order: adapter1, adapter2, then adapter1 twice
# more. Presumably the repeats check whether adapter1's printed md5 changes
# after adapter2 has been loaded and removed — compare the output lines.
for _adapter, _weightname in (
    (adapter1, weightname1),
    (adapter2, weightname2),
    (adapter1, weightname1),
    (adapter1, weightname1),
):
    inference(_adapter, _weightname)