Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -26,6 +26,8 @@ try_cuda = not args.disable_cuda
|
|
26 |
torch.inference_mode()
|
27 |
torch.no_grad()
|
28 |
|
|
|
|
|
29 |
# Load segmentation models
|
30 |
def load_segmentation_models(model_name: str = 'facebook/detr-resnet-50-panoptic'):
|
31 |
feature_extractor = DetrFeatureExtractor.from_pretrained(model_name)
|
@@ -67,7 +69,6 @@ def clean_mask(mask, max_kernel: int = 23, min_kernel: int = 5):
|
|
67 |
feature_extractor, segmentation_model, segmentation_cfg = load_segmentation_models()
|
68 |
pipe = load_diffusion_pipeline()
|
69 |
|
70 |
-
device = get_device(try_cuda=try_cuda)
|
71 |
segmentation_model = segmentation_model.to(device)
|
72 |
pipe = pipe.to(device)
|
73 |
if args.attention_slicing:
|
|
|
26 |
torch.inference_mode()
|
27 |
torch.no_grad()
|
28 |
|
29 |
+
device = get_device(try_cuda=try_cuda)
|
30 |
+
|
31 |
# Load segmentation models
|
32 |
def load_segmentation_models(model_name: str = 'facebook/detr-resnet-50-panoptic'):
|
33 |
feature_extractor = DetrFeatureExtractor.from_pretrained(model_name)
|
|
|
69 |
feature_extractor, segmentation_model, segmentation_cfg = load_segmentation_models()
|
70 |
pipe = load_diffusion_pipeline()
|
71 |
|
|
|
72 |
segmentation_model = segmentation_model.to(device)
|
73 |
pipe = pipe.to(device)
|
74 |
if args.attention_slicing:
|