orhir commited on
Commit
2e96cb8
·
verified ·
1 Parent(s): cfd2d4a

Update gradio_utils/utils.py

Browse files
Files changed (1) hide show
  1. gradio_utils/utils.py +4 -4
gradio_utils/utils.py CHANGED
@@ -158,10 +158,10 @@ def process(query_img, state,
158
  torch.tensor(target_weight_s).float()[None])
159
 
160
  data = {
161
- 'img_s': [support_img],
162
- 'img_q': q_img,
163
- 'target_s': [target_s],
164
- 'target_weight_s': [target_weight_s],
165
  'target_q': None,
166
  'target_weight_q': None,
167
  'return_loss': False,
 
158
  torch.tensor(target_weight_s).float()[None])
159
 
160
  data = {
161
+ 'img_s': [support_img.to(device)],
162
+ 'img_q': q_img.to(device),
163
+ 'target_s': [target_s.to(device)],
164
+ 'target_weight_s': [target_weight_s.to(device)],
165
  'target_q': None,
166
  'target_weight_q': None,
167
  'return_loss': False,