nostalgebraist commited on
Commit
67311fc
·
1 Parent(s): 8e3a852
Files changed (1) hide show
  1. app.py +20 -4
app.py CHANGED
@@ -1,9 +1,13 @@
1
  import streamlit as st
2
 
3
- import time
4
  import numpy as np
5
  from PIL import Image
6
 
 
 
 
 
7
  # constants
8
  HF_REPO_NAME_DIFFUSION = 'nostalgebraist/nostalgebraist-autoresponder-diffusion'
9
  model_path_diffusion = 'nostalgebraist-autoresponder-diffusion'
@@ -63,14 +67,23 @@ def setup():
63
  return pipeline
64
 
65
 
66
- def handler(text, ts1, ts2, gs1):
 
 
 
 
 
 
67
  pipeline = setup()
68
 
69
  data = {'text': text[:380], 'guidance_scale': gs1}
70
  args = {k: v for k, v in DIFFUSION_DEFAULTS.items()}
71
  args.update(data)
72
 
73
- print(f"running: {args}")
 
 
 
74
 
75
  pipeline.base_model.set_timestep_respacing(str(ts1))
76
  pipeline.super_res_model.set_timestep_respacing(str(ts2))
@@ -121,6 +134,7 @@ low_res = st.empty()
121
  high_res = st.empty()
122
 
123
  if button_go:
 
124
  with generating_marker.container():
125
  st.write('**Generating...**')
126
  st.write('**Prompt:**')
@@ -131,7 +145,7 @@ if button_go:
131
 
132
  t = time.time()
133
 
134
- for s, xs in handler(text, ts1, ts2, gs1):
135
  s = Image.fromarray(s[0])
136
  xs = Image.fromarray(xs[0])
137
 
@@ -165,7 +179,9 @@ if button_go:
165
  st.write(f'{prefix} | {count:02d} / {total} frames | {rate:.2f} seconds/frame')
166
 
167
  if button_stop:
 
168
  break
169
 
170
  with generating_marker.container():
 
171
  st.write('')
 
1
  import streamlit as st
2
 
3
+ import time, uuid
4
  import numpy as np
5
  from PIL import Image
6
 
7
+ if 'session_id' not in st.session_state:
8
+ st.session_state.session_id = str(uuid.uuid4())
9
+ st.session_state.n_gen = 0
10
+
11
  # constants
12
  HF_REPO_NAME_DIFFUSION = 'nostalgebraist/nostalgebraist-autoresponder-diffusion'
13
  model_path_diffusion = 'nostalgebraist-autoresponder-diffusion'
 
67
  return pipeline
68
 
69
 
70
+ def log(msg, st_state):
71
+ session_id = st_state.session_id if 'session_id' in st_state else None
72
+ n_gen = st.session_state.n_gen if 'n_gen' in st_state else None
73
+
74
+ print(f"{session_id} ({n_gen}th gen):\n\t{msg}\n")
75
+
76
+ def handler(text, ts1, ts2, gs1, st_state):
77
  pipeline = setup()
78
 
79
  data = {'text': text[:380], 'guidance_scale': gs1}
80
  args = {k: v for k, v in DIFFUSION_DEFAULTS.items()}
81
  args.update(data)
82
 
83
+ log_data = {'ts1': ts2, 'ts2': ts2}
84
+ log_data.update(args)
85
+
86
+ log(repr(log_data), st_state))
87
 
88
  pipeline.base_model.set_timestep_respacing(str(ts1))
89
  pipeline.super_res_model.set_timestep_respacing(str(ts2))
 
134
  high_res = st.empty()
135
 
136
  if button_go:
137
+ st.session_state.n_gen = st.session_state.n_gen + 1
138
  with generating_marker.container():
139
  st.write('**Generating...**')
140
  st.write('**Prompt:**')
 
145
 
146
  t = time.time()
147
 
148
+ for s, xs in handler(text, ts1, ts2, gs1, st.session_state):
149
  s = Image.fromarray(s[0])
150
  xs = Image.fromarray(xs[0])
151
 
 
179
  st.write(f'{prefix} | {count:02d} / {total} frames | {rate:.2f} seconds/frame')
180
 
181
  if button_stop:
182
+ log('gen stopped', st_state))
183
  break
184
 
185
  with generating_marker.container():
186
+ log('gen complete', st_state))
187
  st.write('')