Spaces:
Runtime error
Runtime error
File size: 5,791 Bytes
b7518ad ba1e299 a54a213 ba1e299 a54a213 e4b5feb f0e64d7 910bf72 f0e64d7 7943576 f0e64d7 a54a213 1b3b3ce a54a213 7cfb5aa 7943576 a54a213 693beab a54a213 fd729fa 1b3b3ce ba1e299 e5509f2 4ef7cbc fd729fa d922a0c b9602b5 d922a0c 1b3b3ce 693beab 3957af8 876d9a1 491d9dc 693beab 491d9dc b9602b5 910bf72 4ef7cbc 789fb56 1c987b2 2d98384 4ef7cbc b43ef12 4a15936 1c987b2 4ef7cbc 4a15936 4ef7cbc 4a15936 4ef7cbc 1c987b2 b9602b5 4ef7cbc 4a15936 4ef7cbc 1c987b2 b9602b5 1c987b2 4a15936 b43ef12 ab4b236 876d9a1 693beab 491d9dc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 |
import streamlit as st
import time
import numpy as np
from PIL import Image
# Constants.

# Hugging Face Hub repo that hosts the diffusion model archive, and the local
# directory setup() checks for (and joins checkpoint/config paths onto) once
# that archive has been extracted.
HF_REPO_NAME_DIFFUSION = 'nostalgebraist/nostalgebraist-autoresponder-diffusion'
model_path_diffusion = 'nostalgebraist-autoresponder-diffusion'

# Initial timestep respacing passed to the sres1 / sres2 models when setup()
# loads them; handler() overrides both from the UI sliders before each run.
# Alternative values previously noted here: sres1 '90,60,60,20,20', sres2 '250'.
timestep_respacing_sres1 = '20'
timestep_respacing_sres2 = '20'

# Base keyword arguments for pipeline.sample(); handler() copies this dict,
# adds 'text', and overrides 'guidance_scale' for each request.
DIFFUSION_DEFAULTS = {
    'batch_size': 1,
    'n_samples': 1,
    'clf_free_guidance': True,
    'clf_free_guidance_sres': False,
    'guidance_scale': 1,
    'guidance_scale_sres': 0,
    'yield_intermediates': True,
}
@st.experimental_singleton
def setup():
    """Install improved-diffusion, fetch the model archive, and build the sampling pipeline.

    Decorated with ``st.experimental_singleton`` so Streamlit caches the
    returned pipeline and reuses it across script reruns.

    Steps:
      1. Unless ``./improved-diffusion`` already exists: clone
         nostalgebraist/improved-diffusion, check out branch ``nbar-space``,
         pip-install it in editable mode, and pip-install the extra packages.
      2. Unless ``model_path_diffusion`` already exists: download
         ``model.tar`` from ``HF_REPO_NAME_DIFFUSION`` and extract it into the
         current directory.
      3. Load the ``sres1`` and ``sres2`` checkpoints/configs from
         ``model_path_diffusion`` and wrap them in a ``SamplingPipeline``.

    Shell-step failures are printed rather than raised, as with the previous
    ``os.system`` calls. A missing package or file then surfaces as an error
    from the import or model loading that follows.

    Returns:
        ``improved_diffusion.pipeline.SamplingPipeline(sres1_model, sres2_model)``.
    """
    import os
    import subprocess
    import sys

    repo_url = "https://github.com/nostalgebraist/improved-diffusion.git"
    # Directory that `git clone repo_url` creates. Note the hyphen; the
    # importable package inside it is `improved_diffusion`.
    repo_dir = "improved-diffusion"
    # Install into the interpreter running this app, not whatever `pip` is on PATH.
    pip_install = [sys.executable, "-m", "pip", "install"]

    def run(cmd, cwd=None):
        """Run argv list *cmd*; print and return a non-zero status instead of raising."""
        try:
            status = subprocess.run(cmd, cwd=cwd).returncode
        except OSError as err:  # e.g. the executable is not installed
            print(f"setup: could not run {cmd}: {err}")
            return 1
        if status != 0:
            print(f"setup: {cmd} exited with status {status}")
        return status

    # Guard on the directory the clone actually creates. Testing for
    # 'improved_diffusion' never matched, so the clone and every install
    # re-ran on each call.
    if not os.path.exists(repo_dir):
        run(["git", "clone", repo_url])
        if os.path.isdir(repo_dir):
            # Same short-circuit as `cd ... && git fetch && git checkout && pip install -e .`:
            # stop at the first failing step.
            for step in (
                ["git", "fetch", "origin", "nbar-space"],
                ["git", "checkout", "nbar-space"],
                pip_install + ["-e", "."],
            ):
                if run(step, cwd=repo_dir) != 0:
                    break
        run(pip_install + ["tokenizers", "x-transformers==0.22.0", "axial-positional-embedding"])
        run(pip_install + ["einops==0.3.2"])

    # Make the checked-out package importable directly from this process.
    sys.path.append(repo_dir)

    import improved_diffusion.pipeline
    from transformer_utils.util.tfm_utils import get_local_path_from_huggingface_cdn

    if not os.path.exists(model_path_diffusion):
        model_tar_path = get_local_path_from_huggingface_cdn(
            HF_REPO_NAME_DIFFUSION, 'model.tar'
        )
        # Pass an argv list instead of a shell string so spaces or shell
        # metacharacters in the cache path cannot break the command. As with
        # `tar -xf X && rm X`, delete the archive only after a successful extract.
        if run(["tar", "-xf", model_tar_path]) == 0:
            try:
                os.remove(model_tar_path)
            except OSError as err:
                print(f"setup: could not remove {model_tar_path}: {err}")

    def load(stage, timestep_respacing):
        """Load ``<stage>.pt`` with its ``config_<stage>.json`` from model_path_diffusion."""
        return improved_diffusion.pipeline.SamplingModel.from_config(
            checkpoint_path=os.path.join(model_path_diffusion, f"{stage}.pt"),
            config_path=os.path.join(model_path_diffusion, f"config_{stage}.json"),
            timestep_respacing=timestep_respacing,
        )

    sampling_model_sres1 = load("sres1", timestep_respacing_sres1)
    sampling_model_sres2 = load("sres2", timestep_respacing_sres2)
    pipeline = improved_diffusion.pipeline.SamplingPipeline(sampling_model_sres1, sampling_model_sres2)
    return pipeline
def handler(text, ts1, ts2, gs1):
    """Configure the cached pipeline for one request and start sampling.

    Args:
        text: Prompt text; only the first 380 characters are used.
        ts1: Step count for ``pipeline.base_model``, passed to
            ``set_timestep_respacing`` as a string.
        ts2: Step count for ``pipeline.super_res_model``, passed the same way.
        gs1: Guidance scale; overrides ``DIFFUSION_DEFAULTS['guidance_scale']``.

    Returns:
        Whatever ``pipeline.sample(...)`` returns. The app iterates it as
        ``(s, xs)`` pairs (``DIFFUSION_DEFAULTS`` sets ``yield_intermediates=True``).
    """
    pipeline = setup()

    # Shallow copy of the defaults with the per-request values layered on top.
    sample_kwargs = {
        **DIFFUSION_DEFAULTS,
        'text': text[:380],
        'guidance_scale': gs1,
    }
    print(f"running: {sample_kwargs}")

    for model, steps in ((pipeline.base_model, ts1), (pipeline.super_res_model, ts2)):
        model.set_timestep_respacing(str(steps))

    return pipeline.sample(**sample_kwargs)
st.title('nostalgebraist-autoresponder image generation demo')

# --- Settings ---------------------------------------------------------------
st.header('Settings')

help_ts1 = "How long to run the base model. Larger values make the image more realistic / better. Smaller values are faster."
help_ts2 = "How long to run the upsampling model. Larger values sometimes make the big image crisper and more detailed. Smaller values are faster."
help_gs1 = "Guidance scale. Larger values make the image more likely to contain the text you wrote. If this is zero, the first part will be faster."

ts1 = st.slider('Steps (base)', min_value=5, max_value=500, value=10, help=help_ts1)
ts2 = st.slider('Steps (upsampling)', min_value=5, max_value=500, value=10, help=help_ts2)
# Options are 0.0, 0.5, ..., 4.0.
gs1 = st.select_slider('Guidance scale (base)', [0.5*i for i in range(9)], value=0., help=help_gs1)

# --- Prompt -----------------------------------------------------------------
st.header('Prompt')

button_dril = st.button('Fill @dril tweet example text')

# A Streamlit button reads True only on the rerun triggered by its own click.
# Store the example text in session_state so it survives later reruns (e.g.
# the one triggered by "Generate"). Without that write, `fill_value` fell back
# to "" on the next interaction. That changed the text area's `value=`, and
# because Streamlit derives a widget's identity from its parameters, the text
# area reset and the prompt was lost.
if button_dril:
    st.session_state['fill_value'] = 'wint\nFollowing\n@dril\nthe wise man bowed his head solemnly and\nspoke: "theres actually zero difference\nbetween good & bad things. you imbecile.\nyou fucking moron'
fill_value = st.session_state.get('fill_value', "")

text = st.text_area('Enter your text here (or leave blank for a textless image)', max_chars=380, height=230,
                    value=fill_value)

button_go = st.button('Generate')
# Clicking this triggers a rerun. Streamlit then interrupts the script run
# that is currently generating, which is what actually stops sampling.
button_stop = st.button('Stop')

st.write("During generation, the two images show different ways of looking at the same process.\n- The left image starts with 100% noise and gradually turns into 100% signal.\n- The right image shows the model's current 'guess' about what the left image will look like when all the noise has been removed.")

# Placeholders filled in by the generation loop: a status message, then the
# base-model frames and the upsampling frames.
generating_marker = st.empty()
low_res = st.empty()
high_res = st.empty()
if button_go:
    with generating_marker.container():
        st.write('**Generating...**')
        st.write('**Prompt:**')
        st.write(text)

    # Frames seen so far and per-frame durations, tracked separately per phase.
    count_low_res, count_high_res = 0, 0
    times_low, times_high = [], []

    # handler() calls setup(), which may clone, install, download and load
    # models on first use, before it returns the sample stream. Start the
    # clock only after that, so one-off setup time does not inflate the first
    # frame's duration and the seconds/frame average shown below.
    samples = handler(text, ts1, ts2, gs1)
    t = time.time()
    for s, xs in samples:
        # Take the first item of each batch (DIFFUSION_DEFAULTS uses batch_size=1).
        # `s` is shown on the left and `xs` on the right.
        s = Image.fromarray(s[0])
        xs = Image.fromarray(xs[0])

        t2 = time.time()
        delta = t2 - t
        t = t2

        # Frames 256 px wide are treated as upsampling output; all others as base-model output.
        is_high_res = s.size[0] == 256
        if is_high_res:
            target = high_res
            count_high_res += 1
            count = count_high_res
            total = ts2
            times_high.append(delta)
            times = times_high
            prefix = "Part 2 of 2 (upsampling)"
        else:
            target = low_res
            count_low_res += 1
            count = count_low_res
            total = ts1
            times_low.append(delta)
            times = times_low
            prefix = "Part 1 of 2 (base model)"

        rate = sum(times) / len(times)
        with target.container():
            st.image([s, xs])
            st.write(f'{prefix} | {count:02d} / {total} frames | {rate:.2f} seconds/frame')
        # No `button_stop` check is needed here. It is read once per script run
        # and cannot be True in the same run as `button_go`. Clicking "Stop"
        # starts a new rerun instead, which interrupts this loop.

    # Clear the "Generating..." status once sampling finishes.
    with generating_marker.container():
        st.write('')
|