nostalgebraist commited on
Commit
7943576
·
1 Parent(s): e4b5feb
Files changed (1) hide show
  1. app.py +42 -43
app.py CHANGED
@@ -26,54 +26,53 @@ def setup():
26
  os.system("pip install tokenizers x-transformers==0.22.0 axial-positional-embedding")
27
  os.system("pip install einops==0.3.2")
28
  sys.path.append("improved-diffusion")
29
- from improved_diffusion import pipeline
30
- # import improved_diffusion.pipeline
31
- # from transformer_utils.util.tfm_utils import get_local_path_from_huggingface_cdn
32
- #
33
- # if not os.path.exists(model_path_diffusion):
34
- # model_tar_name = 'model.tar'
35
- # model_tar_path = get_local_path_from_huggingface_cdn(
36
- # HF_REPO_NAME_DIFFUSION, model_tar_name
37
- # )
38
- # subprocess.run(f"tar -xf {model_tar_path} && rm {model_tar_path}", shell=True)
39
- #
40
- # checkpoint_path_sres1 = os.path.join(model_path_diffusion, "sres1.pt")
41
- # config_path_sres1 = os.path.join(model_path_diffusion, "config_sres1.json")
42
- #
43
- # checkpoint_path_sres2 = os.path.join(model_path_diffusion, "sres2.pt")
44
- # config_path_sres2 = os.path.join(model_path_diffusion, "config_sres2.json")
45
- #
46
- # # load
47
- # sampling_model_sres1 = improved_diffusion.pipeline.SamplingModel.from_config(
48
- # checkpoint_path=checkpoint_path_sres1,
49
- # config_path=config_path_sres1,
50
- # timestep_respacing=timestep_respacing_sres1
51
- # )
52
- #
53
- # sampling_model_sres2 = improved_diffusion.pipeline.SamplingModel.from_config(
54
- # checkpoint_path=checkpoint_path_sres2,
55
- # config_path=config_path_sres2,
56
- # timestep_respacing=timestep_respacing_sres2
57
- # )
58
- #
59
- # pipeline = improved_diffusion.pipeline.SamplingPipeline(sampling_model_sres1, sampling_model_sres2)
60
  return pipeline
61
 
62
  pipeline = setup()
63
 
64
  def handler(text, ts1, ts2, gs1):
65
- pass
66
- # # a = np.random.randint(0, 255, (128, 128, 3)).astype(np.uint8)
67
- # data = {'text': text[:380], 'guidance_scale': gs1}
68
- # args = {k: v for k, v in DIFFUSION_DEFAULTS.items()}
69
- # args.update(data)
70
- #
71
- # print(f"running: {args}")
72
- #
73
- # pipeline.base_model.set_timestep_respacing(str(ts1))
74
- # pipeline.super_res_model.set_timestep_respacing(str(ts2))
75
- #
76
- # return pipeline.sample(**args)
77
 
78
 
79
  text = st.text_area('asdf')
 
26
  os.system("pip install tokenizers x-transformers==0.22.0 axial-positional-embedding")
27
  os.system("pip install einops==0.3.2")
28
  sys.path.append("improved-diffusion")
29
+
30
+ import improved_diffusion.pipeline
31
+ from transformer_utils.util.tfm_utils import get_local_path_from_huggingface_cdn
32
+
33
+ if not os.path.exists(model_path_diffusion):
34
+ model_tar_name = 'model.tar'
35
+ model_tar_path = get_local_path_from_huggingface_cdn(
36
+ HF_REPO_NAME_DIFFUSION, model_tar_name
37
+ )
38
+ subprocess.run(f"tar -xf {model_tar_path} && rm {model_tar_path}", shell=True)
39
+
40
+ checkpoint_path_sres1 = os.path.join(model_path_diffusion, "sres1.pt")
41
+ config_path_sres1 = os.path.join(model_path_diffusion, "config_sres1.json")
42
+
43
+ checkpoint_path_sres2 = os.path.join(model_path_diffusion, "sres2.pt")
44
+ config_path_sres2 = os.path.join(model_path_diffusion, "config_sres2.json")
45
+
46
+ # load
47
+ sampling_model_sres1 = improved_diffusion.pipeline.SamplingModel.from_config(
48
+ checkpoint_path=checkpoint_path_sres1,
49
+ config_path=config_path_sres1,
50
+ timestep_respacing=timestep_respacing_sres1
51
+ )
52
+
53
+ sampling_model_sres2 = improved_diffusion.pipeline.SamplingModel.from_config(
54
+ checkpoint_path=checkpoint_path_sres2,
55
+ config_path=config_path_sres2,
56
+ timestep_respacing=timestep_respacing_sres2
57
+ )
58
+
59
+ pipeline = improved_diffusion.pipeline.SamplingPipeline(sampling_model_sres1, sampling_model_sres2)
60
  return pipeline
61
 
62
  pipeline = setup()
63
 
64
  def handler(text, ts1, ts2, gs1):
65
+ # a = np.random.randint(0, 255, (128, 128, 3)).astype(np.uint8)
66
+ data = {'text': text[:380], 'guidance_scale': gs1}
67
+ args = {k: v for k, v in DIFFUSION_DEFAULTS.items()}
68
+ args.update(data)
69
+
70
+ print(f"running: {args}")
71
+
72
+ pipeline.base_model.set_timestep_respacing(str(ts1))
73
+ pipeline.super_res_model.set_timestep_respacing(str(ts2))
74
+
75
+ return pipeline.sample(**args)
 
76
 
77
 
78
  text = st.text_area('asdf')