# Source file of a Hugging Face Space (Space status at capture: Sleeping).
# File size: 2,907 bytes; revisions: f4ee239, d69cde5, 1b5361f.
import streamlit as st
from io import BytesIO
from typing import Literal
from diffusers import StableDiffusionPipeline
import torch
import time
# Fixed seed for sampling. torch.manual_seed seeds torch's global RNG and
# returns the default torch.Generator, which text2image hands to the pipeline.
seed = 42
generator = torch.manual_seed(seed)

# Pipeline runs per request. text2image keeps only the images of the last run,
# so any value above 1 just repeats the whole diffusion process and discards
# the earlier result (doubling the wait for the user).
NUM_ITERS_TO_RUN = 1
# Denoising steps per image.
NUM_INFERENCE_STEPS = 20
# Images generated per run; only the first one is shown.
NUM_IMAGES_PER_PROMPT = 1
@st.cache_resource(max_entries=1, show_spinner=False)
def _load_pipeline(repo_id: str, use_cuda: bool):
    """Load a Stable Diffusion pipeline once and reuse it across reruns.

    Streamlit re-executes the whole script on every interaction, so without
    this cache every click would re-read the full model weights.
    ``max_entries=1`` keeps only the most recently selected model resident, so
    switching models evicts the previous one instead of piling several up.
    ``show_spinner=False`` because the caller already shows its own spinner.

    Args:
        repo_id: Hugging Face model repository to load.
        use_cuda: Load in float16 and move to "cuda" when True; otherwise
            load in float32 and keep the pipeline on the CPU.

    Returns:
        A ``(pipeline, lock)`` pair. The cached pipeline object is shared by
        every session; callers hold ``lock`` while running inference on it.
    """
    import threading

    pipeline = StableDiffusionPipeline.from_pretrained(
        repo_id,
        torch_dtype=torch.float16 if use_cuda else torch.float32,
        use_safetensors=True,
    )
    if use_cuda:
        pipeline = pipeline.to("cuda")
    return pipeline, threading.Lock()


def text2image(
    prompt: str,
    repo_id: Literal[
        "dreamlike-art/dreamlike-photoreal-2.0",
        "hakurei/waifu-diffusion",
        "prompthero/openjourney",
        "stabilityai/stable-diffusion-2-1",
        "runwayml/stable-diffusion-v1-5",
        "nota-ai/bk-sdm-small",
        "CompVis/stable-diffusion-v1-4",
    ],
):
    """Generate an image for ``prompt`` with the model stored at ``repo_id``.

    Runs the pipeline ``NUM_ITERS_TO_RUN`` times (at least once) with the
    module-level ``generator`` and keeps the images of the last run.

    Args:
        prompt: Text prompt for the diffusion model.
        repo_id: Hugging Face repository of the model to use.

    Returns:
        ``(image, start, end)``: the first generated image of the last run and
        two ``time.time()`` stamps taken before loading the model and after
        generation finished.
    """
    start = time.time()
    use_cuda = torch.cuda.is_available()
    print("Using GPU" if use_cuda else "Using CPU")
    pipeline, lock = _load_pipeline(repo_id, use_cuda)
    # The cached pipeline object is shared between sessions, and diffusers
    # pipelines are not documented as safe for concurrent calls, so
    # generations on it are serialized.
    with lock:
        # max(1, ...) guarantees ``images`` is bound even if the constant is 0.
        for _ in range(max(1, NUM_ITERS_TO_RUN)):
            images = pipeline(
                prompt,
                num_inference_steps=NUM_INFERENCE_STEPS,
                generator=generator,
                num_images_per_prompt=NUM_IMAGES_PER_PROMPT,
            ).images
    end = time.time()
    return images[0], start, end
def app():
    """Render the text-to-image page and run one generation per form submit.

    On submit, shows the processing time, the generated image and a PNG
    download button.
    """
    st.header("Text-to-image Web App")
    st.subheader("Powered by Hugging Face")
    # Inputs live inside the form: form widgets only send their values on
    # submit, so editing the prompt or switching models no longer triggers a
    # rerun that wipes the image currently on screen.
    with st.form("my_form"):
        user_input = st.text_area(
            "Enter your text prompt below and click the button to submit."
        )
        option = st.selectbox(
            "Select model (in order of processing time)",
            (
                "nota-ai/bk-sdm-small",
                "CompVis/stable-diffusion-v1-4",
                "runwayml/stable-diffusion-v1-5",
                "prompthero/openjourney",
                "hakurei/waifu-diffusion",
                "stabilityai/stable-diffusion-2-1",
                "dreamlike-art/dreamlike-photoreal-2.0",
            ),
        )
        submit = st.form_submit_button(label="Submit text prompt")

    if not submit:
        return
    # An empty prompt would still spend up to the full generation time
    # producing an arbitrary image.
    if not user_input.strip():
        st.warning("Please enter a text prompt before submitting.")
        return

    with st.spinner(text="Generating image ... It may take up to 20 minutes."):
        im, start, end = text2image(prompt=user_input, repo_id=option)

    # Encode once to PNG bytes for the download button.
    buf = BytesIO()
    im.save(buf, format="PNG")
    byte_im = buf.getvalue()

    hours, rem = divmod(end - start, 3600)
    minutes, seconds = divmod(rem, 60)
    st.success(
        "Processing time: {:0>2}:{:0>2}:{:05.2f}.".format(
            int(hours), int(minutes), seconds
        )
    )
    st.image(im)
    st.download_button(
        label="Click here to download",
        data=byte_im,
        file_name="generated_image.png",
        mime="image/png",
    )
# Entry point when the script is executed directly (e.g. `streamlit run`).
if __name__ == "__main__":
    app()