wybxc commited on
Commit
6711464
·
unverified ·
1 Parent(s): 4892d46

add app.py

Browse files
Files changed (1) hide show
  1. app.py +76 -0
app.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ from diffusers import StableDiffusionPipeline
4
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
5
+
6
+ def void(*args, **kwargs):
7
+ pass
8
+
9
+ st.title("AI 元火娘")
10
+
11
+ with st.sidebar:
12
+ model = st.selectbox("Model Name", [
13
+ "wybxc/yanhuo-v1-dreambooth",
14
+ "wybxc/yanyuan-v1-dreambooth",
15
+ "wybxc/yuanhuo-v1-dreambooth",
16
+ "<Custom>"
17
+ ])
18
+ if model == "<Custom>":
19
+ model = st.text_input("Model Path", "").strip()
20
+
21
+ # Caching model
22
+ if 'model' not in st.session_state:
23
+ st.session_state.model = model
24
+ if 'pipeline' not in st.session_state:
25
+ st.session_state.pipeline = None
26
+
27
+
28
+ if model != st.session_state.model or st.session_state.pipeline is None:
29
+ if model:
30
+ with st.spinner("Loading Model..."):
31
+ pipeline = StableDiffusionPipeline.from_pretrained(
32
+ model,
33
+ torch_dtype=torch.float16
34
+ )
35
+ assert type(pipeline) is StableDiffusionPipeline
36
+ if torch.cuda.is_available():
37
+ pipeline = pipeline.to("cuda")
38
+ st.session_state.model = model
39
+ st.session_state.pipeline = pipeline
40
+ else:
41
+ pipeline = None
42
+ else:
43
+ pipeline = st.session_state.pipeline
44
+ assert type(pipeline) is StableDiffusionPipeline
45
+
46
+
47
+ prompt = st.text_area("Prompt", "(yanhuo), 1girl, masterpiece, best quality, "
48
+ "white hair, ahoge, snowy street, [smile], dynamic angle, full body, "
49
+ "[blue eyes], flat chest, cinematic light")
50
+
51
+ negative_prompt = st.text_area("Negative Prompt", "lowres, bad anatomy, bad hands, "
52
+ "text, error, missing fingers, extra digit, fewer digits, cropped, "
53
+ "worst quality, low quality, normal quality, jpeg artifacts, signature, "
54
+ "watermark, username, blurry")
55
+
56
+ with st.sidebar:
57
+ height = st.slider("Height", 256, 1024, 512, 64)
58
+ width = st.slider("Width", 256, 1024, 512, 64)
59
+ steps = st.slider("Steps", 1, 100, 20, 1)
60
+
61
+ if pipeline and st.button("Generate"):
62
+ progress = st.progress(0)
63
+ result = pipeline(
64
+ prompt=prompt,
65
+ negative_prompt=negative_prompt,
66
+ height=height,
67
+ width=width,
68
+ num_inference_steps=steps,
69
+ callback=lambda s, *_: void(progress.progress(s / steps))
70
+ )
71
+ assert type(result) is StableDiffusionPipelineOutput
72
+ image = result.images[0]
73
+
74
+ progress.progress(1.0)
75
+
76
+ st.image(image)