Spaces:
Runtime error
Runtime error
| import diffusers | |
| import torch | |
| import os | |
| import time | |
| import accelerate | |
| import streamlit as st | |
| from stqdm import stqdm | |
| from diffusers import DiffusionPipeline, UNet2DConditionModel | |
| from PIL import Image | |
# Hub repository id passed to both `from_pretrained` calls (UNet and pipeline).
MODEL_REPO = "OFA-Sys/small-stable-diffusion-v0"
# Location handed to `unet.load_attn_procs` for the LoRA attention weights.
LoRa_DIR = "weights"
# Dataset name; only interpolated into the page's description text.
DATASET_REPO = "VESSL/Bored_Ape_NFT_text"
# Path of the example image opened with PIL and shown above the form.
SAMPLE_IMAGE = "weights/Sample.png"
def load_pipeline_w_lora():
    """Build the diffusion pipeline around a UNet carrying the LoRA weights.

    The UNet is loaded from the ``unet`` subfolder of ``MODEL_REPO``, the
    attention-processor weights found under ``LoRa_DIR`` are loaded into it,
    and the resulting UNet is injected into a ``DiffusionPipeline`` created
    from the same repository in float32.

    Returns:
        The assembled pipeline, with its own progress bar disabled.
    """
    # Pretrained UNet from the model repository.
    lora_unet = UNet2DConditionModel.from_pretrained(MODEL_REPO, subfolder="unet", revision=None)

    # Attach the LoRA attention-layer weights to the UNet.
    print('LoRa layers loading...')
    lora_unet.load_attn_procs(LoRa_DIR)
    print('LoRa layers loaded')

    # Full pipeline, reusing the patched UNet instead of the repository's own.
    pipe = DiffusionPipeline.from_pretrained(
        MODEL_REPO, unet=lora_unet, revision=None, torch_dtype=torch.float32
    )
    pipe.set_progress_bar_config(disable=True)
    return pipe
def elapsed_time(fn, *args, **kwargs):
    """Call ``fn`` and report how long the call took.

    Args:
        fn: Callable to invoke.
        *args: Positional arguments forwarded to ``fn``.
        **kwargs: Keyword arguments forwarded to ``fn``.

    Returns:
        A tuple ``(elapsed, output)`` where ``elapsed`` is the duration in
        seconds formatted as a string with two decimals, and ``output`` is
        whatever ``fn`` returned.
    """
    # perf_counter is monotonic, so the measured interval cannot be skewed
    # (or go negative) when the system wall clock is adjusted mid-call.
    start = time.perf_counter()
    output = fn(*args, **kwargs)
    elapsed = f'{time.perf_counter() - start:.2f}'
    return elapsed, output
def main():
    """Render the Streamlit page and generate images for a submitted prompt.

    The page shows a title and description, loads the LoRA-patched pipeline
    (reporting the load time), moves it to CUDA when available (CPU
    otherwise), displays the sample image, and presents a form asking for a
    prompt, the number of images, and a seed. On submission it generates the
    requested number of images with a seeded generator and displays them.
    """
    st.title("BAYC Text to IMAGE generator")
    st.write(f"Stable diffusion model is fine-tuned by lora using dataset {DATASET_REPO}")

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # NOTE(review): Streamlit re-executes the script on every interaction, so
    # this load runs again on each rerun; consider caching the pipeline if the
    # installed Streamlit version offers a resource cache.
    st.write("Loading models...")
    elapsed, pipeline = elapsed_time(load_pipeline_w_lora)
    st.write(f"Model is loaded in {elapsed} seconds!")
    pipeline = pipeline.to(device)

    sample = Image.open(SAMPLE_IMAGE)
    st.image(sample, caption="Example image with prompt <An ape with solid gold fur and beanie>")

    with st.form(key="information", clear_on_submit=True):
        prompt = st.text_input(
            label="Write prompt to generate your unique BAYC image! (e.g. An ape with golden fur)")
        num_images = st.number_input(label="Number of images to generate", min_value=1, max_value=10)
        seed = st.number_input(label="Seed for images", min_value=1, max_value=10000)
        submitted = st.form_submit_button(label="Submit")

    if submitted:
        st.write(f"Generating {num_images} BAYC image with prompt <{prompt}>...")
        # One generator seeded once: successive images differ from each other
        # but the whole batch is reproducible for a given seed.
        generator = torch.Generator(device=device).manual_seed(seed)

        images = []
        for _ in stqdm(range(num_images)):
            generated_image = pipeline(prompt, num_inference_steps=30, generator=generator).images[0]
            images.append(generated_image)

        st.write("Done!")
        # st.image needs one caption per image when given a list of images;
        # a single string caption fails as soon as more than one is generated.
        captions = [f"Generated Images with <{prompt}>"] * len(images)
        st.image(images, width=150, caption=captions)
# Run the app only when this file is executed as the entry script.
if __name__ == "__main__":
    main()