nostalgebraist commited on
Commit
f17fc25
·
1 Parent(s): a133cc3
Files changed (1) hide show
  1. app.py +42 -42
app.py CHANGED
@@ -26,53 +26,53 @@ def setup():
26
  os.system("pip install tokenizers x-transformers==0.22.0 axial-positional-embedding")
27
  os.system("pip install einops==0.3.2")
28
  sys.path.append("improved-diffusion")
29
-
30
- import improved_diffusion.pipeline
31
- from transformer_utils.util.tfm_utils import get_local_path_from_huggingface_cdn
32
-
33
- if not os.path.exists(model_path_diffusion):
34
- model_tar_name = 'model.tar'
35
- model_tar_path = get_local_path_from_huggingface_cdn(
36
- HF_REPO_NAME_DIFFUSION, model_tar_name
37
- )
38
- subprocess.run(f"tar -xf {model_tar_path} && rm {model_tar_path}", shell=True)
39
-
40
- checkpoint_path_sres1 = os.path.join(model_path_diffusion, "sres1.pt")
41
- config_path_sres1 = os.path.join(model_path_diffusion, "config_sres1.json")
42
-
43
- checkpoint_path_sres2 = os.path.join(model_path_diffusion, "sres2.pt")
44
- config_path_sres2 = os.path.join(model_path_diffusion, "config_sres2.json")
45
-
46
- # load
47
- sampling_model_sres1 = improved_diffusion.pipeline.SamplingModel.from_config(
48
- checkpoint_path=checkpoint_path_sres1,
49
- config_path=config_path_sres1,
50
- timestep_respacing=timestep_respacing_sres1
51
- )
52
-
53
- sampling_model_sres2 = improved_diffusion.pipeline.SamplingModel.from_config(
54
- checkpoint_path=checkpoint_path_sres2,
55
- config_path=config_path_sres2,
56
- timestep_respacing=timestep_respacing_sres2
57
- )
58
-
59
- pipeline = improved_diffusion.pipeline.SamplingPipeline(sampling_model_sres1, sampling_model_sres2)
60
  return pipeline
61
 
62
  pipeline = setup()
63
 
64
  def handler(text, ts1, ts2, gs1):
65
- # a = np.random.randint(0, 255, (128, 128, 3)).astype(np.uint8)
66
- data = {'text': text[:380], 'guidance_scale': gs1}
67
- args = {k: v for k, v in DIFFUSION_DEFAULTS.items()}
68
- args.update(data)
69
-
70
- print(f"running: {args}")
71
-
72
- pipeline.base_model.set_timestep_respacing(str(ts1))
73
- pipeline.super_res_model.set_timestep_respacing(str(ts2))
74
-
75
- return pipeline.sample(**args)
76
 
77
 
78
  text = st.text_area('asdf')
 
26
  os.system("pip install tokenizers x-transformers==0.22.0 axial-positional-embedding")
27
  os.system("pip install einops==0.3.2")
28
  sys.path.append("improved-diffusion")
29
+ from improved_diffusion import pipeline
30
+ # import improved_diffusion.pipeline
31
+ # from transformer_utils.util.tfm_utils import get_local_path_from_huggingface_cdn
32
+ #
33
+ # if not os.path.exists(model_path_diffusion):
34
+ # model_tar_name = 'model.tar'
35
+ # model_tar_path = get_local_path_from_huggingface_cdn(
36
+ # HF_REPO_NAME_DIFFUSION, model_tar_name
37
+ # )
38
+ # subprocess.run(f"tar -xf {model_tar_path} && rm {model_tar_path}", shell=True)
39
+ #
40
+ # checkpoint_path_sres1 = os.path.join(model_path_diffusion, "sres1.pt")
41
+ # config_path_sres1 = os.path.join(model_path_diffusion, "config_sres1.json")
42
+ #
43
+ # checkpoint_path_sres2 = os.path.join(model_path_diffusion, "sres2.pt")
44
+ # config_path_sres2 = os.path.join(model_path_diffusion, "config_sres2.json")
45
+ #
46
+ # # load
47
+ # sampling_model_sres1 = improved_diffusion.pipeline.SamplingModel.from_config(
48
+ # checkpoint_path=checkpoint_path_sres1,
49
+ # config_path=config_path_sres1,
50
+ # timestep_respacing=timestep_respacing_sres1
51
+ # )
52
+ #
53
+ # sampling_model_sres2 = improved_diffusion.pipeline.SamplingModel.from_config(
54
+ # checkpoint_path=checkpoint_path_sres2,
55
+ # config_path=config_path_sres2,
56
+ # timestep_respacing=timestep_respacing_sres2
57
+ # )
58
+ #
59
+ # pipeline = improved_diffusion.pipeline.SamplingPipeline(sampling_model_sres1, sampling_model_sres2)
60
  return pipeline
61
 
62
  pipeline = setup()
63
 
64
  def handler(text, ts1, ts2, gs1):
65
+ # # a = np.random.randint(0, 255, (128, 128, 3)).astype(np.uint8)
66
+ # data = {'text': text[:380], 'guidance_scale': gs1}
67
+ # args = {k: v for k, v in DIFFUSION_DEFAULTS.items()}
68
+ # args.update(data)
69
+ #
70
+ # print(f"running: {args}")
71
+ #
72
+ # pipeline.base_model.set_timestep_respacing(str(ts1))
73
+ # pipeline.super_res_model.set_timestep_respacing(str(ts2))
74
+ #
75
+ # return pipeline.sample(**args)
76
 
77
 
78
  text = st.text_area('asdf')