tstone87 committed
Commit 680b8e8 · verified · 1 Parent(s): f34aee4

Update app.py

Files changed (1):
  1. app.py +64 -21
app.py CHANGED
@@ -18,11 +18,14 @@ def process_input(uploaded_file, youtube_link, image_url, sensitivity):
     Priority: YouTube link > Image URL > Uploaded file.
     The sensitivity slider value is passed as the confidence threshold.
 
-    Returns a tuple of 4 items:
-      1. download_file_path (for gr.File)
-      2. image_result (for gr.Image) or None
-      3. video_result (for gr.Video) or None
-      4. status message
+    For video files (mp4, mov, avi, webm), we use streaming mode to obtain annotated frames and encode them into a video.
+    For images, we use the normal prediction and either use the built-in save_path or plot() method.
+
+    Returns a tuple:
+      - download_file_path (for gr.File)
+      - image_result (for gr.Image) or None
+      - video_result (for gr.Video) or None
+      - status message
     """
     input_path = None
 
@@ -55,28 +58,68 @@ def process_input(uploaded_file, youtube_link, image_url, sensitivity):
     else:
         return None, None, None, "Please provide an input using one of the methods."
 
-    try:
-        results = model.predict(source=input_path, save=True, conf=sensitivity)
-    except Exception as e:
-        return None, None, None, f"Error running prediction: {e}"
+    # Determine if input is a video (by extension).
+    ext_input = os.path.splitext(input_path)[1].lower()
+    video_exts = [".mp4", ".mov", ".avi", ".webm"]
 
     output_path = None
-    try:
-        if hasattr(results[0], "save_path"):
-            output_path = results[0].save_path
-        else:
-            annotated = results[0].plot()  # returns a numpy array
-            output_path = os.path.join(tempfile.gettempdir(), "annotated.jpg")
-            cv2.imwrite(output_path, annotated)
-    except Exception as e:
-        return None, None, None, f"Error processing the file: {e}"
 
+    if ext_input in video_exts:
+        # Process video using streaming mode.
+        try:
+            # Open video to get properties.
+            cap = cv2.VideoCapture(input_path)
+            if not cap.isOpened():
+                return None, None, None, "Error opening video file."
+            fps = cap.get(cv2.CAP_PROP_FPS)
+            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+            cap.release()
+
+            # Use streaming mode to process each frame.
+            frames = []
+            for result in model.predict(source=input_path, stream=True, conf=sensitivity):
+                # result.plot() returns an annotated frame (numpy array)
+                annotated_frame = result.plot()
+                frames.append(annotated_frame)
+            if not frames:
+                return None, None, None, "No detections were returned from video streaming."
+            # Write frames to a temporary video file.
+            temp_video_path = os.path.join(tempfile.gettempdir(), "annotated_video.mp4")
+            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+            out = cv2.VideoWriter(temp_video_path, fourcc, fps, (width, height))
+            for frame in frames:
+                out.write(frame)
+            out.release()
+            output_path = temp_video_path
+        except Exception as e:
+            return None, None, None, f"Error processing video: {e}"
+    else:
+        # Process as an image.
+        try:
+            results = model.predict(source=input_path, save=True, conf=sensitivity)
+        except Exception as e:
+            return None, None, None, f"Error running prediction: {e}"
+
+        try:
+            if not results or len(results) == 0:
+                return None, None, None, "No detections were returned."
+            if hasattr(results[0], "save_path"):
+                output_path = results[0].save_path
+            else:
+                annotated = results[0].plot()  # returns a numpy array
+                output_path = os.path.join(tempfile.gettempdir(), "annotated.jpg")
+                cv2.imwrite(output_path, annotated)
+        except Exception as e:
+            return None, None, None, f"Error processing the file: {e}"
+
+    # Clean up temporary input if downloaded.
     if ((youtube_link and youtube_link.strip()) or (image_url and image_url.strip())) and input_path and os.path.exists(input_path):
         os.remove(input_path)
 
-    ext = os.path.splitext(output_path)[1].lower()
-    video_exts = [".mp4", ".mov", ".avi", ".webm"]
-    if ext in video_exts:
+    # Set outputs based on output file extension.
+    ext_output = os.path.splitext(output_path)[1].lower()
+    if ext_output in video_exts:
         image_result = None
         video_result = output_path
     else:
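
For reference, the following is a minimal, self-contained sketch of the streaming annotate-and-encode approach used in the diff above, assuming an Ultralytics YOLO model and OpenCV. It differs from the commit in one design choice: each annotated frame is written to the output video as it is produced instead of being buffered in a list first. The function name, the placeholder weights, and the paths here are illustrative, not the app's actual configuration.

import os
import tempfile

import cv2
from ultralytics import YOLO


def annotate_video(input_path: str, conf: float = 0.25) -> str:
    """Run streaming prediction on a video and return the path of the annotated copy."""
    model = YOLO("yolov8n.pt")  # placeholder weights; the app loads its own model

    # Read the source video's properties so the output matches them.
    cap = cv2.VideoCapture(input_path)
    if not cap.isOpened():
        raise RuntimeError("Could not open video file")
    fps = cap.get(cv2.CAP_PROP_FPS) or 30.0  # fall back if FPS metadata is missing
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    cap.release()

    out_path = os.path.join(tempfile.gettempdir(), "annotated_video.mp4")
    writer = cv2.VideoWriter(out_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height))
    try:
        # stream=True yields one Results object per frame as it is processed,
        # so frames never need to be held in memory all at once.
        for result in model.predict(source=input_path, stream=True, conf=conf):
            writer.write(result.plot())  # plot() returns the annotated frame as a BGR numpy array
    finally:
        writer.release()
    return out_path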