nostalgebraist commited on
Commit
f0e64d7
·
1 Parent(s): 5bac5ef
Files changed (1) hide show
  1. app.py +36 -30
app.py CHANGED
@@ -1,15 +1,8 @@
1
- import os, subprocess, sys
2
- os.system("git clone https://github.com/nostalgebraist/improved-diffusion.git && cd improved-diffusion && git fetch origin nbar-space && git checkout nbar-dev && pip install -e .")
3
- os.system("pip install tokenizers x-transformers==0.22.0 axial-positional-embedding")
4
- os.system("pip install einops==0.3.2")
5
- sys.path.append("improved-diffusion")
6
-
7
  import streamlit as st
8
 
9
  import numpy as np
10
  from PIL import Image
11
 
12
- import improved_diffusion.pipeline
13
  from transformer_utils.util.tfm_utils import get_local_path_from_huggingface_cdn
14
 
15
  # constants
@@ -28,33 +21,46 @@ DIFFUSION_DEFAULTS = dict(
28
  yield_intermediates=True
29
  )
30
 
31
- if not os.path.exists(model_path_diffusion):
32
- model_tar_name = 'model.tar'
33
- model_tar_path = get_local_path_from_huggingface_cdn(
34
- HF_REPO_NAME_DIFFUSION, model_tar_name
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  )
36
- subprocess.run(f"tar -xf {model_tar_path} && rm {model_tar_path}", shell=True)
37
-
38
- checkpoint_path_sres1 = os.path.join(model_path_diffusion, "sres1.pt")
39
- config_path_sres1 = os.path.join(model_path_diffusion, "config_sres1.json")
40
-
41
- checkpoint_path_sres2 = os.path.join(model_path_diffusion, "sres2.pt")
42
- config_path_sres2 = os.path.join(model_path_diffusion, "config_sres2.json")
43
 
44
- # load
45
- sampling_model_sres1 = improved_diffusion.pipeline.SamplingModel.from_config(
46
- checkpoint_path=checkpoint_path_sres1,
47
- config_path=config_path_sres1,
48
- timestep_respacing=timestep_respacing_sres1
49
- )
50
 
51
- sampling_model_sres2 = improved_diffusion.pipeline.SamplingModel.from_config(
52
- checkpoint_path=checkpoint_path_sres2,
53
- config_path=config_path_sres2,
54
- timestep_respacing=timestep_respacing_sres2
55
- )
56
 
57
- pipeline = improved_diffusion.pipeline.SamplingPipeline(sampling_model_sres1, sampling_model_sres2)
58
 
59
  def handler(text, ts1, ts2, gs1):
60
  # a = np.random.randint(0, 255, (128, 128, 3)).astype(np.uint8)
 
 
 
 
 
 
 
1
  import streamlit as st
2
 
3
  import numpy as np
4
  from PIL import Image
5
 
 
6
  from transformer_utils.util.tfm_utils import get_local_path_from_huggingface_cdn
7
 
8
  # constants
 
21
  yield_intermediates=True
22
  )
23
 
24
+ @st.cache
25
+ def setup():
26
+ import os, subprocess, sys
27
+ os.system("git clone https://github.com/nostalgebraist/improved-diffusion.git && cd improved-diffusion && git fetch origin nbar-space && git checkout nbar-dev && pip install -e .")
28
+ os.system("pip install tokenizers x-transformers==0.22.0 axial-positional-embedding")
29
+ os.system("pip install einops==0.3.2")
30
+ sys.path.append("improved-diffusion")
31
+
32
+ import improved_diffusion.pipeline
33
+
34
+ if not os.path.exists(model_path_diffusion):
35
+ model_tar_name = 'model.tar'
36
+ model_tar_path = get_local_path_from_huggingface_cdn(
37
+ HF_REPO_NAME_DIFFUSION, model_tar_name
38
+ )
39
+ subprocess.run(f"tar -xf {model_tar_path} && rm {model_tar_path}", shell=True)
40
+
41
+ checkpoint_path_sres1 = os.path.join(model_path_diffusion, "sres1.pt")
42
+ config_path_sres1 = os.path.join(model_path_diffusion, "config_sres1.json")
43
+
44
+ checkpoint_path_sres2 = os.path.join(model_path_diffusion, "sres2.pt")
45
+ config_path_sres2 = os.path.join(model_path_diffusion, "config_sres2.json")
46
+
47
+ # load
48
+ sampling_model_sres1 = improved_diffusion.pipeline.SamplingModel.from_config(
49
+ checkpoint_path=checkpoint_path_sres1,
50
+ config_path=config_path_sres1,
51
+ timestep_respacing=timestep_respacing_sres1
52
  )
 
 
 
 
 
 
 
53
 
54
+ sampling_model_sres2 = improved_diffusion.pipeline.SamplingModel.from_config(
55
+ checkpoint_path=checkpoint_path_sres2,
56
+ config_path=config_path_sres2,
57
+ timestep_respacing=timestep_respacing_sres2
58
+ )
 
59
 
60
+ pipeline = improved_diffusion.pipeline.SamplingPipeline(sampling_model_sres1, sampling_model_sres2)
61
+ return pipeline
 
 
 
62
 
63
+ pipeline = setup()
64
 
65
  def handler(text, ts1, ts2, gs1):
66
  # a = np.random.randint(0, 255, (128, 128, 3)).astype(np.uint8)