Spaces:
Runtime error
Runtime error
File size: 3,994 Bytes
b7518ad ba1e299 a54a213 ba1e299 a54a213 e4b5feb f0e64d7 910bf72 f0e64d7 7943576 f0e64d7 a54a213 7cfb5aa 7943576 a54a213 ba1e299 a54a213 ba1e299 4a15936 4ef7cbc a54a213 910bf72 4ef7cbc b43ef12 4a15936 4ef7cbc 4a15936 4ef7cbc 4a15936 4ef7cbc 4a15936 4ef7cbc 4a15936 b43ef12 ab4b236 4ef7cbc a54a213 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 |
import streamlit as st
import time
import numpy as np
from PIL import Image
# --- Constants ----------------------------------------------------------------
# Hugging Face repo that setup() downloads 'model.tar' from, and the local
# directory setup() reads the sres1/sres2 checkpoints and configs from.
HF_REPO_NAME_DIFFUSION = 'nostalgebraist/nostalgebraist-autoresponder-diffusion'
model_path_diffusion = 'nostalgebraist-autoresponder-diffusion'

# Timestep respacing used when the two sampling models are first loaded;
# handler() re-sets both from the UI sliders on every request.
# The trailing comments are presumably alternative schedules kept for reference.
timestep_respacing_sres1 = '20'  # '90,60,60,20,20'
timestep_respacing_sres2 = '20'  # '250'

# Base keyword arguments for pipeline.sample(); handler() copies this dict and
# then sets 'text' and overrides 'guidance_scale' per request.
DIFFUSION_DEFAULTS = {
    'batch_size': 1,
    'n_samples': 1,
    'clf_free_guidance': True,
    'clf_free_guidance_sres': False,
    'guidance_scale': 1,
    'guidance_scale_sres': 0,
    'yield_intermediates': True,
}
@st.experimental_singleton
def setup():
    """Install the diffusion code, fetch the model weights, and build the pipeline.

    Decorated with ``st.experimental_singleton`` so Streamlit caches the
    returned pipeline and reuses it across script reruns.

    Steps:
      1. If the ``improved-diffusion`` checkout is missing, clone
         nostalgebraist's fork, switch to the ``nbar-space`` branch, and
         pip-install it together with its pinned dependencies.
      2. If ``model_path_diffusion`` is missing, download ``model.tar`` from
         ``HF_REPO_NAME_DIFFUSION``, unpack it into the working directory, and
         delete the archive.
      3. Load the base (``sres1``) and super-resolution (``sres2``) sampling
         models and wrap them in a ``SamplingPipeline``.

    Returns:
        improved_diffusion.pipeline.SamplingPipeline: the two-stage pipeline.

    Raises:
        subprocess.CalledProcessError: if unpacking ``model.tar`` fails.
    """
    import os
    import subprocess
    import sys

    repo_dir = "improved-diffusion"
    # `git clone` creates 'improved-diffusion' (hyphen). The old check looked
    # for 'improved_diffusion' (underscore), never matched, and re-ran the
    # whole clone/checkout/pip-install sequence on every cold start.
    if not os.path.exists(repo_dir):
        os.system("git clone https://github.com/nostalgebraist/improved-diffusion.git")
        os.system("cd improved-diffusion && git fetch origin nbar-space && git checkout nbar-space && pip install -e .")
        os.system("pip install tokenizers x-transformers==0.22.0 axial-positional-embedding")
        os.system("pip install einops==0.3.2")
    # Make the checkout importable even if the editable install is not picked
    # up by the already-running interpreter.
    sys.path.append(repo_dir)

    # Imported here because these packages only become available after the
    # installation step above.
    import improved_diffusion.pipeline
    from transformer_utils.util.tfm_utils import get_local_path_from_huggingface_cdn

    if not os.path.exists(model_path_diffusion):
        model_tar_name = 'model.tar'
        model_tar_path = get_local_path_from_huggingface_cdn(
            HF_REPO_NAME_DIFFUSION, model_tar_name
        )
        # Argument list instead of a shell string, so the path is never
        # shell-interpreted. check=True reports a failed extraction here
        # rather than as a missing-file error during model loading. The
        # archive is removed only after a successful extraction, as the
        # original `tar ... && rm ...` did.
        subprocess.run(["tar", "-xf", model_tar_path], check=True)
        os.remove(model_tar_path)

    checkpoint_path_sres1 = os.path.join(model_path_diffusion, "sres1.pt")
    config_path_sres1 = os.path.join(model_path_diffusion, "config_sres1.json")

    checkpoint_path_sres2 = os.path.join(model_path_diffusion, "sres2.pt")
    config_path_sres2 = os.path.join(model_path_diffusion, "config_sres2.json")

    # load
    sampling_model_sres1 = improved_diffusion.pipeline.SamplingModel.from_config(
        checkpoint_path=checkpoint_path_sres1,
        config_path=config_path_sres1,
        timestep_respacing=timestep_respacing_sres1
    )

    sampling_model_sres2 = improved_diffusion.pipeline.SamplingModel.from_config(
        checkpoint_path=checkpoint_path_sres2,
        config_path=config_path_sres2,
        timestep_respacing=timestep_respacing_sres2
    )

    pipeline = improved_diffusion.pipeline.SamplingPipeline(sampling_model_sres1, sampling_model_sres2)
    return pipeline
def handler(text, ts1, ts2, gs1):
    """Configure the cached pipeline for one request and start sampling.

    Args:
        text: prompt text; only the first 380 characters are used.
        ts1: step count for the base model; passed as a string to
            ``set_timestep_respacing``.
        ts2: step count for the super-resolution model; passed the same way.
        gs1: value for the ``guidance_scale`` sampling argument.

    Returns:
        Whatever ``pipeline.sample`` returns. The UI iterates it as
        ``(s, xs)`` pairs, presumably because ``yield_intermediates=True``.
    """
    pipeline = setup()

    # Start from a fresh copy so the module-level defaults are never mutated.
    args = dict(DIFFUSION_DEFAULTS)
    args['text'] = text[:380]
    args['guidance_scale'] = gs1
    print(f"running: {args}")

    pipeline.base_model.set_timestep_respacing(str(ts1))
    pipeline.super_res_model.set_timestep_respacing(str(ts2))

    return pipeline.sample(**args)
# --- Streamlit UI ------------------------------------------------------------
# Streamlit re-executes this script top to bottom on every widget interaction.

text = st.text_area('Enter your text here (or leave blank for a textless image)', max_chars=380)

# NOTE(review): these help strings and the button label below look like
# placeholders — TODO replace them with real user-facing descriptions.
help_ts1 = "foo"
help_ts2 = "bar " * 40
help_gs1 = "aaff"

# Step counts for the base and upsampling stages. handler() passes them to
# set_timestep_respacing on the base and super-res models respectively.
ts1 = st.slider('Steps (base)', min_value=5, max_value=500, value=10, help=help_ts1)
# This slider previously passed help_ts1, so help_ts2 was never shown.
ts2 = st.slider('Steps (upsampling)', min_value=5, max_value=500, value=10, help=help_ts2)
gs1 = st.slider('Guidance scale (base)', min_value=0., max_value=4., value=0., help=help_gs1)

# Frames whose width equals this are treated as super-res (upsampling) output.
# NOTE(review): assumes the base stage never yields 256-px-wide frames — confirm.
HIGH_RES_WIDTH = 256

if st.button('rweerew'):
    # One placeholder per stage, so each new frame replaces that stage's
    # previous frame instead of appending below it.
    low_res = st.empty()
    high_res = st.empty()

    count_low_res, count_high_res = 0, 0
    for s, xs in handler(text, ts1, ts2, gs1):
        # Only element 0 is shown — presumably the batch axis
        # (batch_size=1 in DIFFUSION_DEFAULTS).
        s = Image.fromarray(s[0])
        xs = Image.fromarray(xs[0])

        if s.size[0] == HIGH_RES_WIDTH:
            target = high_res
            count_high_res += 1
            count, total = count_high_res, ts2
        else:
            target = low_res
            count_low_res += 1
            count, total = count_low_res, ts1

        with target.container():
            st.image([s, xs])
            st.write(f'{count} / {total}')
|