# Streamlit demo: BAYC text-to-image generation using a small-stable-diffusion
# model fine-tuned with LoRA attention weights.
import diffusers
import torch
import os
import time
import accelerate
import streamlit as st
from stqdm import stqdm
from diffusers import DiffusionPipeline, UNet2DConditionModel
from PIL import Image
MODEL_REPO = 'OFA-Sys/small-stable-diffusion-v0'
LoRa_DIR = 'weights'
DATASET_REPO = 'VESSL/Bored_Ape_NFT_text'
SAMPLE_IMAGE = 'weights/Sample.png'
def load_pipeline_w_lora():
    """Build the diffusion pipeline with the LoRA attention weights applied.

    Returns:
        A ``DiffusionPipeline`` whose UNet carries the fine-tuned LoRA
        attention processors, with its internal progress bar disabled.
    """
    # Fetch the base UNet weights from the Hugging Face hub.
    base_unet = UNet2DConditionModel.from_pretrained(
        MODEL_REPO,
        subfolder="unet",
        revision=None,
    )
    # Patch the fine-tuned LoRA attention processors into the UNet.
    base_unet.load_attn_procs(LoRa_DIR)
    print('loaded')
    # Assemble the full pipeline around the patched UNet.
    pipe = DiffusionPipeline.from_pretrained(
        MODEL_REPO,
        unet=base_unet,
        revision=None,
        torch_dtype=torch.float32,
    )
    # Streamlit renders its own progress (stqdm); silence the pipeline's bar.
    pipe.set_progress_bar_config(disable=True)
    return pipe
def elapsed_time(fn, *args, **kwargs):
    """Call ``fn(*args, **kwargs)`` and measure how long it takes.

    Args:
        fn: Callable to invoke.
        *args: Positional arguments forwarded to ``fn``.
        **kwargs: Keyword arguments forwarded to ``fn`` (backward-compatible
            extension; existing positional-only callers are unaffected).

    Returns:
        Tuple ``(elapsed, output)`` where ``elapsed`` is the duration in
        seconds formatted as a string with two decimal places, and ``output``
        is whatever ``fn`` returned.
    """
    # perf_counter is monotonic; time.time() can jump (NTP, DST) and is
    # unsuitable for measuring intervals.
    start = time.perf_counter()
    output = fn(*args, **kwargs)
    end = time.perf_counter()
    elapsed = f'{end - start:.2f}'
    return elapsed, output
def main():
    """Streamlit entry point: load the pipeline and serve the generation UI."""
    st.title("BAYC Text to IMAGE generator")
    st.write(f"Stable diffusion model is fine-tuned by lora using dataset {DATASET_REPO}")

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # BUG FIX: the original referenced `elapsed` before this assignment,
    # raising NameError on startup. Report the load time only after timing it.
    elapsed, pipeline = elapsed_time(load_pipeline_w_lora)
    st.write(f"Model is loaded in {elapsed} seconds!")
    pipeline = pipeline.to(device)

    sample = Image.open(SAMPLE_IMAGE)
    st.image(sample, caption="Example image with prompt <An ape with solid gold fur and beanie>")

    with st.form(key="information", clear_on_submit=True):
        prompt = st.text_input(
            label="Write prompt to generate your unique BAYC image! (e.g. An ape with golden fur)")
        # BUG FIX: st.number_input returns float by default; range() and
        # manual_seed() both require ints. Use integer steps and cast.
        num_images = int(st.number_input(
            label="Number of images to generate", min_value=1, value=1, step=1))
        seed = int(st.number_input(label="Seed for images", value=0, step=1))
        submitted = st.form_submit_button(label="Submit")

    if submitted:
        st.write(f"Generating {num_images} BAYC image with prompt <{prompt}>...")
        # Seeded generator makes results reproducible for a given seed.
        generator = torch.Generator(device=device).manual_seed(seed)
        images = []
        for img_idx in stqdm(range(num_images)):
            generated_image = pipeline(prompt, num_inference_steps=30, generator=generator).images[0]
            images.append(generated_image)
        st.write("Done!")
        st.image(images, width=150, caption=f"Generated Images with <{prompt}>")


if __name__ == '__main__':
    main()