Spaces:
Runtime error
Runtime error
add app.py
Browse files
app.py
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import torch
|
3 |
+
from diffusers import StableDiffusionPipeline
|
4 |
+
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
5 |
+
|
6 |
+
def void(*args, **kwargs):
|
7 |
+
pass
|
8 |
+
|
9 |
+
st.title("AI 元火娘")
|
10 |
+
|
11 |
+
with st.sidebar:
|
12 |
+
model = st.selectbox("Model Name", [
|
13 |
+
"wybxc/yanhuo-v1-dreambooth",
|
14 |
+
"wybxc/yanyuan-v1-dreambooth",
|
15 |
+
"wybxc/yuanhuo-v1-dreambooth",
|
16 |
+
"<Custom>"
|
17 |
+
])
|
18 |
+
if model == "<Custom>":
|
19 |
+
model = st.text_input("Model Path", "").strip()
|
20 |
+
|
21 |
+
# Caching model
|
22 |
+
if 'model' not in st.session_state:
|
23 |
+
st.session_state.model = model
|
24 |
+
if 'pipeline' not in st.session_state:
|
25 |
+
st.session_state.pipeline = None
|
26 |
+
|
27 |
+
|
28 |
+
if model != st.session_state.model or st.session_state.pipeline is None:
|
29 |
+
if model:
|
30 |
+
with st.spinner("Loading Model..."):
|
31 |
+
pipeline = StableDiffusionPipeline.from_pretrained(
|
32 |
+
model,
|
33 |
+
torch_dtype=torch.float16
|
34 |
+
)
|
35 |
+
assert type(pipeline) is StableDiffusionPipeline
|
36 |
+
if torch.cuda.is_available():
|
37 |
+
pipeline = pipeline.to("cuda")
|
38 |
+
st.session_state.model = model
|
39 |
+
st.session_state.pipeline = pipeline
|
40 |
+
else:
|
41 |
+
pipeline = None
|
42 |
+
else:
|
43 |
+
pipeline = st.session_state.pipeline
|
44 |
+
assert type(pipeline) is StableDiffusionPipeline
|
45 |
+
|
46 |
+
|
47 |
+
prompt = st.text_area("Prompt", "(yanhuo), 1girl, masterpiece, best quality, "
|
48 |
+
"white hair, ahoge, snowy street, [smile], dynamic angle, full body, "
|
49 |
+
"[blue eyes], flat chest, cinematic light")
|
50 |
+
|
51 |
+
negative_prompt = st.text_area("Negative Prompt", "lowres, bad anatomy, bad hands, "
|
52 |
+
"text, error, missing fingers, extra digit, fewer digits, cropped, "
|
53 |
+
"worst quality, low quality, normal quality, jpeg artifacts, signature, "
|
54 |
+
"watermark, username, blurry")
|
55 |
+
|
56 |
+
with st.sidebar:
|
57 |
+
height = st.slider("Height", 256, 1024, 512, 64)
|
58 |
+
width = st.slider("Width", 256, 1024, 512, 64)
|
59 |
+
steps = st.slider("Steps", 1, 100, 20, 1)
|
60 |
+
|
61 |
+
if pipeline and st.button("Generate"):
|
62 |
+
progress = st.progress(0)
|
63 |
+
result = pipeline(
|
64 |
+
prompt=prompt,
|
65 |
+
negative_prompt=negative_prompt,
|
66 |
+
height=height,
|
67 |
+
width=width,
|
68 |
+
num_inference_steps=steps,
|
69 |
+
callback=lambda s, *_: void(progress.progress(s / steps))
|
70 |
+
)
|
71 |
+
assert type(result) is StableDiffusionPipelineOutput
|
72 |
+
image = result.images[0]
|
73 |
+
|
74 |
+
progress.progress(1.0)
|
75 |
+
|
76 |
+
st.image(image)
|