Spaces:
Runtime error
Runtime error
Update gradio_utils/utils.py
Browse files- gradio_utils/utils.py +4 -4
gradio_utils/utils.py
CHANGED
|
@@ -158,10 +158,10 @@ def process(query_img, state,
|
|
| 158 |
torch.tensor(target_weight_s).float()[None])
|
| 159 |
|
| 160 |
data = {
|
| 161 |
-
'img_s': [support_img],
|
| 162 |
-
'img_q': q_img,
|
| 163 |
-
'target_s': [target_s],
|
| 164 |
-
'target_weight_s': [target_weight_s],
|
| 165 |
'target_q': None,
|
| 166 |
'target_weight_q': None,
|
| 167 |
'return_loss': False,
|
|
|
|
| 158 |
torch.tensor(target_weight_s).float()[None])
|
| 159 |
|
| 160 |
data = {
|
| 161 |
+
'img_s': [support_img.to(device)],
|
| 162 |
+
'img_q': q_img.to(device),
|
| 163 |
+
'target_s': [target_s.to(device)],
|
| 164 |
+
'target_weight_s': [target_weight_s.to(device)],
|
| 165 |
'target_q': None,
|
| 166 |
'target_weight_q': None,
|
| 167 |
'return_loss': False,
|