Spaces:
Runtime error
Runtime error
File size: 7,452 Bytes
b7518ad 380e15d a54a213 67311fc a54a213 ba1e299 a54a213 e4b5feb f0e64d7 910bf72 f0e64d7 7943576 f0e64d7 a54a213 1b3b3ce 5b09180 67311fc 5b09180 67311fc 7cfb5aa 7943576 67311fc 22978b0 7943576 a54a213 ef0202a 693beab a54a213 d77e462 2d1fef4 ef0202a d77e462 e6d4f87 fd729fa d922a0c b9602b5 ef0202a 7fd7245 b9602b5 12981b8 d922a0c 1b3b3ce bb738df 5a8cc6b ef0202a bb738df 5a8cc6b 8e3a852 bb738df 693beab 53273e5 876d9a1 491d9dc 693beab 67311fc 491d9dc b9602b5 3753600 910bf72 4ef7cbc 789fb56 1c987b2 2d98384 4ef7cbc 67311fc b43ef12 4a15936 1c987b2 4ef7cbc 4a15936 4ef7cbc 4a15936 4ef7cbc 1c987b2 b9602b5 4ef7cbc 4a15936 4ef7cbc 1c987b2 b9602b5 1c987b2 4a15936 b43ef12 ab4b236 876d9a1 693beab 3eccda4 693beab 491d9dc 3eccda4 491d9dc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 |
import streamlit as st
import time, uuid
from datetime import datetime
import numpy as np
from PIL import Image
# First script run for this browser session: assign a random session id and
# start the per-session generation counter (both are printed by `log`).
if 'session_id' not in st.session_state:
    st.session_state['session_id'] = str(uuid.uuid4())
    st.session_state['n_gen'] = 0
# --- constants ---------------------------------------------------------------

# Hugging Face repo that hosts `model.tar`, and the local directory `setup()`
# checks for (and reads the checkpoints/configs from) after extracting it.
HF_REPO_NAME_DIFFUSION = 'nostalgebraist/nostalgebraist-autoresponder-diffusion'
model_path_diffusion = 'nostalgebraist-autoresponder-diffusion'

# Timestep respacing used when the two sampling models are first loaded; the
# UI sliders re-set it on every request. Previously-used alternatives, kept for
# reference: base '90,60,60,20,20', upsampler '250'.
timestep_respacing_sres1 = '20'
timestep_respacing_sres2 = '20'

# Baseline keyword arguments for `pipeline.sample`. `handler` copies this and
# overlays the request's `text` and `guidance_scale` on top.
DIFFUSION_DEFAULTS = {
    'batch_size': 1,
    'n_samples': 1,
    'clf_free_guidance': True,
    'clf_free_guidance_sres': False,
    'guidance_scale': 1,
    'guidance_scale_sres': 0,
    'yield_intermediates': True,
}
@st.experimental_singleton
def setup():
    """Install the diffusion code, fetch the model weights, and build the sampling pipeline.

    ``st.experimental_singleton`` caches the return value, so this expensive work
    runs once per server process and every session shares the same pipeline.

    Returns:
        An ``improved_diffusion.pipeline.SamplingPipeline`` wrapping the base
        (``sres1``) and upsampling (``sres2``) sampling models.

    Raises:
        subprocess.CalledProcessError: if cloning, installing, or extracting fails.
    """
    import os, subprocess, sys

    repo_dir = "improved-diffusion"
    # `git clone` creates a *hyphenated* directory. Checking for 'improved_diffusion'
    # (underscore) never matched, so the clone and installs re-ran on every start.
    # NOTE: if a previous install was interrupted, delete `repo_dir` to force a retry.
    if not os.path.exists(repo_dir):
        pip_install = [sys.executable, "-m", "pip", "install"]
        subprocess.run(
            ["git", "clone", "https://github.com/nostalgebraist/improved-diffusion.git", repo_dir],
            check=True,
        )
        subprocess.run(["git", "fetch", "origin", "nbar-space"], cwd=repo_dir, check=True)
        subprocess.run(["git", "checkout", "nbar-space"], cwd=repo_dir, check=True)
        subprocess.run(pip_install + ["-e", "."], cwd=repo_dir, check=True)
        subprocess.run(
            pip_install + ["tokenizers", "x-transformers==0.22.0", "axial-positional-embedding"],
            check=True,
        )
        subprocess.run(pip_install + ["einops==0.3.2"], check=True)

    # Put the checkout on sys.path so `improved_diffusion` is importable in this
    # process (an editable install made at runtime may not be picked up until restart).
    sys.path.append(repo_dir)
    import improved_diffusion.pipeline
    from transformer_utils.util.tfm_utils import get_local_path_from_huggingface_cdn

    if not os.path.exists(model_path_diffusion):
        # model.tar is expected to unpack into `model_path_diffusion`.
        model_tar_path = get_local_path_from_huggingface_cdn(HF_REPO_NAME_DIFFUSION, 'model.tar')
        # Argument list instead of a shell string: no quoting/injection issues with
        # the path. The archive is removed only after a successful extraction.
        subprocess.run(["tar", "-xf", model_tar_path], check=True)
        os.remove(model_tar_path)

    def _load(stage, timestep_respacing):
        """Load one sampling model from ``<stage>.pt`` / ``config_<stage>.json``."""
        return improved_diffusion.pipeline.SamplingModel.from_config(
            checkpoint_path=os.path.join(model_path_diffusion, f"{stage}.pt"),
            config_path=os.path.join(model_path_diffusion, f"config_{stage}.json"),
            timestep_respacing=timestep_respacing,
        )

    sampling_model_sres1 = _load("sres1", timestep_respacing_sres1)
    sampling_model_sres2 = _load("sres2", timestep_respacing_sres2)
    return improved_diffusion.pipeline.SamplingPipeline(sampling_model_sres1, sampling_model_sres2)
def now_str():
    """Return the current UTC time as ``'YYYY-MM-DD HH-MM-SS'`` for log lines."""
    # `datetime.utcnow()` is deprecated since Python 3.12; an aware UTC datetime
    # formats identically with this pattern.
    from datetime import timezone
    return datetime.now(timezone.utc).strftime('%Y-%m-%d %H-%M-%S')
def log(msg, st_state):
    """Print ``msg`` to stdout, prefixed with a UTC timestamp, session id and generation count.

    Args:
        msg: Text to log (printed on its own tab-indented line).
        st_state: Session-state mapping. ``session_id`` and ``n_gen`` are read from
            it when present; otherwise ``None`` is printed in their place.
    """
    # Read both fields from `st_state` (not the global `st.session_state`) so the
    # function actually honours the state object it is given.
    session_id = st_state.session_id if 'session_id' in st_state else None
    n_gen = st_state.n_gen if 'n_gen' in st_state else None
    print(f"{now_str()} {session_id} ({n_gen}th gen):\n\t{msg}\n")
def handler(text, ts1, ts2, gs1, st_state):
    """Start one text-to-image generation and return the pipeline's sample stream.

    Args:
        text: Prompt text. Only the first 380 characters are used, matching the
            text area's ``max_chars``.
        ts1: Number of sampling steps for the base model.
        ts2: Number of sampling steps for the upsampling model.
        gs1: Guidance scale passed to ``pipeline.sample`` as ``guidance_scale``.
        st_state: Streamlit session state, used only for logging.

    Returns:
        Whatever ``pipeline.sample(**args)`` returns. The caller iterates it as
        ``(s, xs)`` pairs and renders ``s[0]`` and ``xs[0]`` as images.
    """
    pipeline = setup()
    data = {'text': text[:380], 'guidance_scale': gs1}
    args = {**DIFFUSION_DEFAULTS, **data}
    # Log the requested step counts alongside the sampling args
    # ('ts1' previously logged ts2 by mistake).
    log_data = {'ts1': ts1, 'ts2': ts2, **args}
    log(repr(log_data), st_state)
    # NOTE(review): `setup()` returns a process-wide singleton, so concurrent sessions
    # share these models and can overwrite each other's respacing mid-run.
    pipeline.base_model.set_timestep_respacing(str(ts1))
    pipeline.super_res_model.set_timestep_respacing(str(ts2))
    return pipeline.sample(**args)
# Gate for the example-text button (always enabled).
FRESH = True

# --- page header ---------------------------------------------------------------
st.title('nostalgebraist-autoresponder image generation demo')
st.write("#### For a **much faster experience**, try the [Colab notebook](https://colab.research.google.com/drive/17BOTYmLv4fdurr8y5dcaGKy8JVY_A62a?usp=sharing) instead!")
st.write("A demo of the image models used in the tumblr bot [nostalgebraist-autoresponder](https://nostalgebraist-autoresponder.tumblr.com/).\n\nBy [nostalgebraist](https://nostalgebraist.tumblr.com/)")
st.write('##### What is this thing? How does it work?')
st.write("See [this post](https://nostalgebraist.tumblr.com/post/672300992964050944/franks-image-generation-model-explained) for an explanation.")

# --- prompt --------------------------------------------------------------------
st.header('Prompt')
button_dril = st.button('Fill @dril tweet example text')
if button_dril and FRESH:
    st.session_state.fill_value = 'wint\nFollowing\n@dril\nthe wise man bowed his head solemnly and\nspoke: "theres actually zero difference\nbetween good & bad things. you imbecile.\nyou fucking moron'

# The text area's initial value lives in session state, so the example text
# survives the reruns that follow the button click.
fill_value = st.session_state.fill_value if 'fill_value' in st.session_state else ""
st.session_state.fill_value = fill_value
text = st.text_area(
    'Enter your text here (or leave blank for a textless image)',
    max_chars=380,
    height=230,
    value=fill_value,
)
# --- generation settings -------------------------------------------------------
st.header('Settings')
# Fixed user-facing typo: "unles" -> "unless".
st.write("The bot uses 250 base steps and 250 upsampling steps, with custom spacing (not available here) for the base part.\n\nSince this demo doesn't have a GPU, you'll probably want to use fewer than 250 steps unless you have a lot of patience.")

help_ts1 = "How long to run the base model. Larger values make the image more realistic / better. Smaller values are faster."
help_ts2 = "How long to run the upsampling model. Larger values sometimes make the big image crisper and more detailed. Smaller values are faster."
help_gs1 = "Guidance scale. Larger values make the image more likely to contain the text you wrote. If this is zero, the first part will be faster."

ts1 = st.slider('Steps (base)', min_value=5, max_value=500, value=50, step=5, help=help_ts1)
ts2 = st.slider('Steps (upsampling)', min_value=5, max_value=500, value=50, step=5, help=help_ts2)
# Choices 0.0, 0.5, ..., 4.0.
gs1 = st.select_slider('Guidance scale (base)', [0.5*i for i in range(9)], value=1.0, help=help_gs1)

button_go = st.button('Generate')
button_stop = st.button('Stop')
# Fixed user-facing typo: "100 signal" -> "100% signal".
st.write("During generation, the two images show different ways of looking at the same process.\n- The left image starts with 100% noise and gradually turns into 100% signal.\n- The right image shows the model's current 'guess' about what the left image will look like when all the noise has been removed.")

# Placeholders filled in during generation: a status line, then one slot per stage.
generating_marker = st.empty()
low_res = st.empty()
high_res = st.empty()
# --- generation ----------------------------------------------------------------
if button_go:
    st.session_state.n_gen += 1
    with generating_marker.container():
        st.write('**Generating...**')
        st.write('**Prompt:**')
        st.write(repr(text))

    # Per-frame wall-clock durations, keyed by "is this an upsampling frame?".
    stage_durations = {False: [], True: []}
    last_tick = time.time()
    for frame, guess in handler(text, ts1, ts2, gs1, st.session_state):
        frame_img = Image.fromarray(frame[0])
        guess_img = Image.fromarray(guess[0])
        now = time.time()
        elapsed, last_tick = now - last_tick, now

        # Frames 256 px wide come from the upsampler; anything else is the base model.
        upsampling = frame_img.size[0] == 256
        durations = stage_durations[upsampling]
        durations.append(elapsed)
        if upsampling:
            slot, total, label = high_res, ts2, "Part 2 of 2 (upsampling)"
        else:
            slot, total, label = low_res, ts1, "Part 1 of 2 (base model)"
        frames_done = len(durations)
        avg_seconds = sum(durations) / frames_done

        with slot.container():
            st.image([frame_img, guess_img])
            st.write(f'{label} | {frames_done:02d} / {total} frames | {avg_seconds:.2f} seconds/frame')

        # NOTE(review): `button_stop` was captured when this script run began. In
        # Streamlit a click normally triggers a rerun rather than flipping this value
        # mid-loop, so this branch likely never fires — confirm.
        if button_stop:
            log('gen stopped', st.session_state)
            break

    # Clear the "Generating..." status once sampling ends.
    with generating_marker.container():
        log('gen complete', st.session_state)
        st.write('')
|