import time

import accelerate  # makes the accelerate dependency explicit; diffusers uses it when loading model weights
import streamlit as st
import torch
from diffusers import DiffusionPipeline, UNet2DConditionModel
from PIL import Image
from stqdm import stqdm
MODEL_REPO = 'OFA-Sys/small-stable-diffusion-v0'   # base diffusion model on the Hugging Face Hub
LoRa_DIR = 'weights'                               # local directory with the fine-tuned LoRA attention weights
DATASET_REPO = 'VESSL/Bored_Ape_NFT_text'          # dataset used for LoRA fine-tuning
SAMPLE_IMAGE = 'weights/Sample.png'                # example output shown on the page
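
# NOTE (assumption): load_attn_procs() below expects LoRa_DIR to contain LoRA
# attention weights saved by a diffusers LoRA fine-tuning run (typically a
# pytorch_lora_weights.bin / .safetensors file).
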
def load_pipeline_w_lora():
    # Load the pretrained UNet from the Hugging Face Hub
    unet = UNet2DConditionModel.from_pretrained(
        MODEL_REPO,
        subfolder="unet",
        revision=None,
    )
    # Load the LoRA attention-processor weights into the UNet attention layers
    unet.load_attn_procs(LoRa_DIR)
    print('LoRA attention weights loaded')

    # Build the full pipeline around the LoRA-augmented UNet
    pipeline = DiffusionPipeline.from_pretrained(
        MODEL_REPO,
        unet=unet,
        revision=None,
        torch_dtype=torch.float32,
    )
    # Disable the pipeline's own tqdm bar; stqdm reports progress in the UI instead
    pipeline.set_progress_bar_config(disable=True)
    return pipeline
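
# Optional (not part of the original app): the pipeline could be cached across
# Streamlit reruns so the model is not reloaded on every interaction, e.g.
#
#   @st.cache_resource
#   def get_pipeline():
#       return load_pipeline_w_lora()
#
# st.cache_resource is Streamlit's cache for non-serializable objects such as models.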

def elapsed_time(fn, *args):
    # Time a function call and return (formatted seconds string, return value)
    start = time.time()
    output = fn(*args)
    end = time.time()
    elapsed = f'{end - start:.2f}'
    return elapsed, output

def main():
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # Load the pipeline (with LoRA weights) and report how long it took
    elapsed, pipeline = elapsed_time(load_pipeline_w_lora)
    pipeline = pipeline.to(device)
    st.write(f"Model is loaded in {elapsed} seconds!")

    st.title("BAYC Text-to-Image Generator")
    st.write(f"The Stable Diffusion model is fine-tuned with LoRA on the dataset {DATASET_REPO}")
    sample = Image.open(SAMPLE_IMAGE)
    st.image(sample, caption="Example image with prompt <An ape with solid gold fur and beanie>")

    with st.form(key="information", clear_on_submit=True):
        prompt = st.text_input(
            label="Write a prompt to generate your unique BAYC image! (e.g. An ape with golden fur)")
        # Integer min_value/value/step make number_input return an int instead of a float
        num_images = st.number_input(label="Number of images to generate", min_value=1, value=1, step=1)
        seed = st.number_input(label="Seed for images", min_value=0, value=0, step=1)
        submitted = st.form_submit_button(label="Submit")

    if submitted:
        st.write(f"Generating {num_images} BAYC image(s) with prompt <{prompt}>...")
        # A seeded generator makes the sampled images reproducible
        generator = torch.Generator(device=device).manual_seed(seed)
        images = []
        for img_idx in stqdm(range(num_images)):
            generated_image = pipeline(prompt, num_inference_steps=30, generator=generator).images[0]
            images.append(generated_image)
        st.write("Done!")
        st.image(images, width=150, caption=f"Generated images with <{prompt}>")

if __name__ == '__main__':
main()
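
# To run the app locally (assuming the dependencies and the weights/ directory are in place):
#   streamlit run app.py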