Spaces:
Sleeping
Sleeping
Update my_model/object_detection.py
Browse files
my_model/object_detection.py
CHANGED
|
@@ -30,6 +30,7 @@ class ObjectDetector:
|
|
| 30 |
self.model = None
|
| 31 |
self.processor = None
|
| 32 |
self.model_name = None
|
|
|
|
| 33 |
|
| 34 |
|
| 35 |
def load_model(self, model_name='detic', pretrained=True, model_version='yolov5s'):
|
|
@@ -64,8 +65,8 @@ class ObjectDetector:
|
|
| 64 |
|
| 65 |
try:
|
| 66 |
model_path = get_model_path('deformable-detr-detic')
|
| 67 |
-
self.processor = AutoImageProcessor.from_pretrained(model_path)
|
| 68 |
-
self.model = AutoModelForObjectDetection.from_pretrained(model_path)
|
| 69 |
except Exception as e:
|
| 70 |
print(f"Error loading Detic model: {e}")
|
| 71 |
raise
|
|
@@ -83,9 +84,9 @@ class ObjectDetector:
|
|
| 83 |
try:
|
| 84 |
model_path = get_model_path ('yolov5')
|
| 85 |
if model_path and os.path.exists(model_path):
|
| 86 |
-
self.model = torch.hub.load(model_path, model_version, pretrained=pretrained, source='local')
|
| 87 |
else:
|
| 88 |
-
self.model = torch.hub.load('ultralytics/yolov5', model_version, pretrained=pretrained)
|
| 89 |
except Exception as e:
|
| 90 |
print(f"Error loading YOLOv5 model: {e}")
|
| 91 |
raise
|
|
|
|
| 30 |
self.model = None
|
| 31 |
self.processor = None
|
| 32 |
self.model_name = None
|
| 33 |
+
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 34 |
|
| 35 |
|
| 36 |
def load_model(self, model_name='detic', pretrained=True, model_version='yolov5s'):
|
|
|
|
| 65 |
|
| 66 |
try:
|
| 67 |
model_path = get_model_path('deformable-detr-detic')
|
| 68 |
+
self.processor = AutoImageProcessor.from_pretrained(model_path, device_map = self.device)
|
| 69 |
+
self.model = AutoModelForObjectDetection.from_pretrained(model_path, device_map = self.device)
|
| 70 |
except Exception as e:
|
| 71 |
print(f"Error loading Detic model: {e}")
|
| 72 |
raise
|
|
|
|
| 84 |
try:
|
| 85 |
model_path = get_model_path ('yolov5')
|
| 86 |
if model_path and os.path.exists(model_path):
|
| 87 |
+
self.model = torch.hub.load(model_path, model_version, pretrained=pretrained, source='local', device_map = self.device)
|
| 88 |
else:
|
| 89 |
+
self.model = torch.hub.load('ultralytics/yolov5', model_version, pretrained=pretrained, device_map = self.device)
|
| 90 |
except Exception as e:
|
| 91 |
print(f"Error loading YOLOv5 model: {e}")
|
| 92 |
raise
|