HongFangzhou commited on
Commit
fe189ee
Β·
1 Parent(s): 0a91ae4

update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -1
app.py CHANGED
@@ -52,24 +52,31 @@ def add_text(rgb, caption):
52
  return rgb
53
 
54
  config = "3DTopia/configs/default.yaml"
55
- local_ckpt = "checkpoints/3dtopia_diffusion_state_dict.ckpt"
56
  if os.path.exists(local_ckpt):
57
  ckpt = local_ckpt
58
  else:
59
  ckpt = hf_hub_download(repo_id="hongfz16/3DTopia", filename="model.safetensors")
 
60
  configs = OmegaConf.load(config)
61
  os.makedirs("tmp", exist_ok=True)
 
62
 
63
  if ckpt.endswith(".ckpt"):
64
  model = get_obj_from_str(configs.model["target"]).load_from_checkpoint(ckpt, map_location='cpu', strict=False, **configs.model.params)
65
  elif ckpt.endswith(".safetensors"):
 
66
  model = get_obj_from_str(configs.model["target"])(**configs.model.params)
 
67
  model_ckpt = load_file(ckpt)
 
68
  model.load_state_dict(model_ckpt)
 
69
  else:
70
  raise NotImplementedError
71
  device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
72
  model = model.to(device)
 
73
  sampler = DDIMSampler(model)
74
 
75
  img_size = configs.model.params.unet_config.params.image_size
@@ -106,6 +113,7 @@ for p in poses_fname:
106
  batch_rays_list.append(batch_rays)
107
  batch_rays_list = torch.stack(batch_rays_list, 0)
108
 
 
109
  def marching_cube(b, text, global_info):
110
  # prepare volumn for marching cube
111
  res = 128
 
52
  return rgb
53
 
54
  config = "3DTopia/configs/default.yaml"
55
+ local_ckpt = "3DTopia/checkpoints/3dtopia_diffusion_state_dict.ckpt"
56
  if os.path.exists(local_ckpt):
57
  ckpt = local_ckpt
58
  else:
59
  ckpt = hf_hub_download(repo_id="hongfz16/3DTopia", filename="model.safetensors")
60
+ print("download finish")
61
  configs = OmegaConf.load(config)
62
  os.makedirs("tmp", exist_ok=True)
63
+ print("download finish")
64
 
65
  if ckpt.endswith(".ckpt"):
66
  model = get_obj_from_str(configs.model["target"]).load_from_checkpoint(ckpt, map_location='cpu', strict=False, **configs.model.params)
67
  elif ckpt.endswith(".safetensors"):
68
+ print("download finish")
69
  model = get_obj_from_str(configs.model["target"])(**configs.model.params)
70
+ print("download finish")
71
  model_ckpt = load_file(ckpt)
72
+ print("download finish")
73
  model.load_state_dict(model_ckpt)
74
+ print("download finish")
75
  else:
76
  raise NotImplementedError
77
  device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
78
  model = model.to(device)
79
+ print("download finish")
80
  sampler = DDIMSampler(model)
81
 
82
  img_size = configs.model.params.unet_config.params.image_size
 
113
  batch_rays_list.append(batch_rays)
114
  batch_rays_list = torch.stack(batch_rays_list, 0)
115
 
116
+ print("download finish")
117
  def marching_cube(b, text, global_info):
118
  # prepare volumn for marching cube
119
  res = 128