Adapter commited on
Commit
5d485d0
·
1 Parent(s): 334ae58
Files changed (1) hide show
  1. demo/model.py +4 -4
demo/model.py CHANGED
@@ -135,8 +135,8 @@ class Model_all:
135
 
136
  # sketch part
137
  self.model_sketch = Adapter(channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True,
138
- use_conv=False)#.to(device)
139
- # self.model_sketch.load_state_dict(torch.load("models/t2iadapter_sketch_sd14v1.pth", map_location=device))
140
  self.model_edge = pidinet().to(device)
141
  self.model_edge.load_state_dict({k.replace('module.', ''): v for k, v in torch.load('models/table5_pidinet.pth', map_location=device)['state_dict'].items()})
142
 
@@ -144,8 +144,8 @@ class Model_all:
144
  self.model_seger = seger().to(device)
145
  self.model_seger.eval()
146
  self.coler = Colorize(n=182)
147
- self.model_seg = Adapter(cin=int(3*64), channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True, use_conv=False)#.to(device)
148
- # self.model_seg.load_state_dict(torch.load("models/t2iadapter_seg_sd14v1.pth", map_location=device))
149
  self.depth_model = MiDaSInference(model_type='dpt_hybrid').to(device)
150
 
151
  # depth part
 
135
 
136
  # sketch part
137
  self.model_sketch = Adapter(channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True,
138
+ use_conv=False).to(device)
139
+ self.model_sketch.load_state_dict(torch.load("models/t2iadapter_sketch_sd14v1.pth", map_location=device))
140
  self.model_edge = pidinet().to(device)
141
  self.model_edge.load_state_dict({k.replace('module.', ''): v for k, v in torch.load('models/table5_pidinet.pth', map_location=device)['state_dict'].items()})
142
 
 
144
  self.model_seger = seger().to(device)
145
  self.model_seger.eval()
146
  self.coler = Colorize(n=182)
147
+ self.model_seg = Adapter(cin=int(3*64), channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True, use_conv=False).to(device)
148
+ self.model_seg.load_state_dict(torch.load("models/t2iadapter_seg_sd14v1.pth", map_location=device))
149
  self.depth_model = MiDaSInference(model_type='dpt_hybrid').to(device)
150
 
151
  # depth part