Include the pathlib WindowsPath = PosixPath workaround

#1
by phonghaitran - opened
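This change keeps the pathlib.WindowsPath = pathlib.PosixPath alias while trimming app.py down to a single YOLOv5 image-detection flow (the U-Net branch, its helper modules, and the plotly dependency are dropped, the weights are moved, and the yolov5 submodule is bumped). YOLOv5 checkpoints saved on Windows pickle WindowsPath objects inside the weights file, and unpickling them on the Linux host that serves this Space fails with an error like "cannot instantiate 'WindowsPath' on your system"; aliasing WindowsPath to PosixPath before the checkpoint is loaded avoids that. Below is a minimal sketch of the idea, assuming the repository layout used here (./yolov5 submodule, ./yolo/best.pt); the platform guard is added only for illustration, the patch itself applies the alias unconditionally.

import pathlib
import platform

import torch

# Checkpoints saved on Windows can contain pickled WindowsPath objects.
# On a non-Windows host those cannot be instantiated, so alias them to
# PosixPath before torch deserializes the checkpoint. (The guard is an
# illustrative addition; the PR applies the alias unconditionally.)
if platform.system() != "Windows":
    pathlib.WindowsPath = pathlib.PosixPath

# Load the local YOLOv5 checkpoint exactly as the new app.py does.
model = torch.hub.load('./yolov5', 'custom', path='./yolo/best.pt',
                       source='local', force_reload=True)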
.DS_Store DELETED
Binary file (6.15 kB)
 
.gitignore DELETED
File without changes
app.py CHANGED
@@ -1,433 +1,38 @@
- from PIL import Image, ImageDraw

- # Import the model components from unet directory
- from unet.unet_model import UNet
-
- import streamlit as st
- import plotly.express as px
- import pandas as pd
  import numpy as np
- import torchvision.transforms as T
-
  import torch
- import pathlib
- import io
  import cv2
- import tempfile

- # Adjust Path for Local Repository
  pathlib.WindowsPath = pathlib.PosixPath

- st.title("Smart city rubbish detection Web Application")
-
- def yolo():
-     st.markdown(
-         "<h1 style='text-align: center; font-size: 36px;'>Yolo object detection</h1>",
-         unsafe_allow_html=True
-     )
-     st.markdown(
-         "<h2 style='text-align: center; font-size: 30px;'>Using Yolov5</h2>",
-         unsafe_allow_html=True
-     )
-
-     # Define the available labels
-     default_sub_classes = [
-         "container",
-         "waste-paper",
-         "plant",
-         "transportation",
-         "kitchenware",
-         "rubbish bag",
-         "chair",
-         "wood",
-         "electronics good",
-         "sofa",
-         "scrap metal",
-         "carton",
-         "bag",
-         "tarpaulin",
-         "accessory",
-         "rubble",
-         "table",
-         "board",
-         "mattress",
-         "beverage",
-         "tyre",
-         "nylon",
-         "rack",
-         "styrofoam",
-         "clothes",
-         "toy",
-         "furniture",
-         "trolley",
-         "carpet",
-         "plastic cup"
-     ]
-
-     # Initialize session state for video processing
-     if 'video_processed' not in st.session_state:
-         st.session_state.video_processed = False
-         st.session_state.output_video_path = None
-         st.session_state.detections_summary = None
-
-     # Cache the model loading to prevent repeated loads
-     @st.cache_resource
-     def load_model():
-         model = torch.hub.load('./yolov5', 'custom', path='./model/yolo/best.pt', source='local', force_reload=False)
-         return model
-
-     model = load_model()
-
-     # Retrieve model class names
-     model_class_names = model.names # Dictionary {index: class_name}
-
-     # Function to map class names to indices (case-insensitive)
-     def get_class_indices(class_list):
-         indices = []
-         not_found = []
-         for cls in class_list:
-             found = False
-             for index, name in model_class_names.items():
-                 if name.lower() == cls.lower():
-                     indices.append(index)
-                     found = True
-                     break
-             if not found:
-                 not_found.append(cls)
-         return indices, not_found
-
-     # Function to annotate images
-     def annotate_image(frame, results):
-         results.render() # Updates results.ims with the annotated images
-         annotated_frame = results.ims[0] # Get the first (and only) image
-         return annotated_frame
-
-     # Inform the user about the available labels
-     st.markdown("### Available Classes:")
-     st.markdown("**" + ", ".join(default_sub_classes + ["rubbish"]) + "**")
-
-     # Inform the user about the default detection
-     st.info("By default, the application will detect **rubbish** only.")
-
-     # User input for classes, separated by commas (optional)
-     custom_classes_input = st.text_input(
-         "Enter classes (comma-separated) or type 'all' to detect everything:",
-         ""
-     )
-
-     # Retrieve all model classes
-     all_model_classes = list(model_class_names.values())
-
-     # Determine classes to use based on user input
-     if custom_classes_input.strip() == "":
-         # No input provided; use only 'rubbish'
-         selected_classes = ['rubbish']
-         st.info("No classes entered. Using default class: **rubbish**.")
-     elif custom_classes_input.strip().lower() == "all":
-         # User chose to detect all classes
-         selected_classes = all_model_classes
-         st.info("Detecting **all** available classes.")
-     else:
-         # User provided specific classes
-         # Split the input string into a list of classes and remove any extra whitespace
-         input_classes = [cls.strip() for cls in custom_classes_input.split(",") if cls.strip()]
-         # Ensure 'rubbish' is included
-         if 'rubbish' not in [cls.lower() for cls in input_classes]:
-             selected_classes = input_classes + ['rubbish']
-             st.info(f"Detecting the following classes: **{', '.join(selected_classes)}** (Including **rubbish**)")
-         else:
-             selected_classes = input_classes
-             st.info(f"Detecting the following classes: **{', '.join(selected_classes)}**")
-
-     # Map selected class names to their indices
-     selected_class_indices, not_found_classes = get_class_indices(selected_classes)
-
-     if not_found_classes:
-         st.warning(f"The following classes were not found in the model and will be ignored: **{', '.join(not_found_classes)}**")
-
-     # Proceed only if there are valid classes to detect
-     if selected_class_indices:
-         # Set the classes for the model
-         model.classes = selected_class_indices
-
-         # --------------------- Image Upload and Processing ---------------------
-         st.header("Image Object Detection")
-
-         uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"], key="image_upload")
-
-         if uploaded_file is not None:
-             try:
-                 # Convert the file to a PIL image
-                 image = Image.open(uploaded_file).convert('RGB')
-                 st.image(image, caption="Uploaded Image", use_column_width=True)
-                 st.write("Processing...")
-
-                 # Perform inference
-                 results = model(image)
-
-                 # Extract DataFrame from results
-                 results_df = results.pandas().xyxy[0]
-
-                 # Filter results to include only selected classes
-                 filtered_results = results_df[results_df['name'].str.lower().isin([cls.lower() for cls in selected_classes])]
-
-                 if filtered_results.empty:
-                     st.warning("No objects detected for the selected classes.")
-                 else:
-                     # Display filtered results
-                     st.write("### Detection Results")
-                     st.dataframe(filtered_results)
-
-                     # Annotate the image
-                     annotated_image = annotate_image(np.array(image), results)
-
-                     # Convert annotated image back to PIL format
-                     annotated_pil = Image.fromarray(annotated_image)
-
-                     # Display annotated image
-                     st.image(annotated_pil, caption="Annotated Image", use_column_width=True)
-
-                     # Convert annotated image to bytes
-                     img_byte_arr = io.BytesIO()
-                     annotated_pil.save(img_byte_arr, format='PNG')
-                     img_byte_arr = img_byte_arr.getvalue()
-
-                     # Add download button
-                     st.download_button(
-                         label="Download Annotated Image",
-                         data=img_byte_arr,
-                         file_name='annotated_image.png',
-                         mime='image/png'
-                     )
-             except Exception as e:
-                 st.error(f"An error occurred during image processing: {e}")
-
-         # --------------------- Video Upload and Processing ---------------------
-         st.header("Video Object Detection")
-
-         uploaded_video = st.file_uploader("Choose a video...", type=["mp4", "avi", "mov"], key="video_upload")
-
-         if uploaded_video is not None:
-             # Check if the uploaded video is different from the previously processed one
-             # Check if the uploaded video first time
-             if st.session_state.get("uploaded_video_name") is None:
-                 st.session_state.uploaded_video_name = uploaded_video.name
-                 print("First time uploaded video" +st.session_state.uploaded_video_name)
-             elif st.session_state.uploaded_video_name != uploaded_video.name:
-                 st.session_state.uploaded_video_name = uploaded_video.name
-                 print("Another time uploaded video" +st.session_state.uploaded_video_name)
-                 st.session_state.video_processed = False
-                 st.session_state.output_video_path = None
-                 st.session_state.detections_summary = None
-                 print("New uploaded video")
-
-         # Reset session state if video upload is removed
-         if uploaded_video is None and st.session_state.video_processed:
-             st.session_state.video_processed = False
-             st.session_state.output_video_path = None
-             st.session_state.detections_summary = None
-             st.warning("Video upload has been cleared. You can upload a new video for processing.")
-
-         if uploaded_video:
-             if not st.session_state.video_processed:
-                 try:
-                     with st.spinner("Processing video..."):
-                         # Save uploaded video to a temporary file
-                         tfile = tempfile.NamedTemporaryFile(delete=False)
-                         tfile.write(uploaded_video.read())
-                         tfile.close()
-
-                         # Open the video file
-                         video_cap = cv2.VideoCapture(tfile.name)
-                         stframe = st.empty() # Placeholder for displaying video frames
-
-                         # Initialize VideoWriter for saving the output video
-                         fourcc = cv2.VideoWriter_fourcc(*'mp4v')
-                         fps = video_cap.get(cv2.CAP_PROP_FPS)
-                         width = int(video_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
-                         height = int(video_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
-                         output_video_path = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4').name
-                         out = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height))
-
-                         frame_count = int(video_cap.get(cv2.CAP_PROP_FRAME_COUNT))
-                         progress_bar = st.progress(0)
-
-                         # Initialize list to collect all detections
-                         all_detections = []
-
-                         for frame_num in range(frame_count):
-                             ret, frame = video_cap.read() # Read a frame from the video
-                             if not ret:
-                                 break
-
-                             # Convert frame to RGB
-                             frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-
-                             # Perform inference
-                             results = model(frame_rgb)
-
-                             # Extract DataFrame from results
-                             results_df = results.pandas().xyxy[0]
-                             results_df['frame_num'] = frame_num # Optional: Add frame number for reference
-
-                             # Append detections to the list
-                             if not results_df.empty:
-                                 all_detections.append(results_df)
-
-                             # Annotate the frame with detections
-                             annotated_frame = annotate_image(frame_rgb, results)
-
-                             # Convert annotated frame back to BGR for VideoWriter
-                             annotated_bgr = cv2.cvtColor(annotated_frame, cv2.COLOR_RGB2BGR)
-
-                             # Write the annotated frame to the output video
-                             out.write(annotated_bgr)
-
-                             # Display the annotated frame in Streamlit
-                             stframe.image(annotated_frame, channels="RGB", use_column_width=True)
-
-                             # Update progress bar
-                             progress_percent = (frame_num + 1) / frame_count
-                             progress_bar.progress(progress_percent)
-
-                         video_cap.release() # Release the video capture object
-                         out.release() # Release the VideoWriter object
-
-                         # Save processed video path and detections summary to session state
-                         st.session_state.output_video_path = output_video_path
-
-                         if all_detections:
-                             # Concatenate all detections into a single DataFrame
-                             detections_df = pd.concat(all_detections, ignore_index=True)
-
-                             # Optional: Group by class name and count detections
-                             detections_summary = detections_df.groupby('name').size().reset_index(name='counts')
-                             st.session_state.detections_summary = detections_summary
-                         else:
-                             st.session_state.detections_summary = None
-
-                         # Mark video as processed
-                         st.session_state.video_processed = True
-
-                         # st.session_state.uploaded_video_name = uploaded_video.name
-
-                         st.success("Video processing complete!")
-
-                 except Exception as e:
-                     st.error(f"An error occurred during video processing: {e}")
-
-             # Display download button and detection summary if processed
-             if st.session_state.video_processed:
-                 try:
-                     # Create a download button for the annotated video
-                     with open(st.session_state.output_video_path, "rb") as video_file:
-                         st.download_button(
-                             label="Download Annotated Video",
-                             data=video_file,
-                             file_name="annotated_video.mp4",
-                             mime="video/mp4"
-                         )
-
-                     # Display detection table if there are detections
-                     if st.session_state.detections_summary is not None:
-                         detections_summary = st.session_state.detections_summary
-
-                         st.write("### Detection Summary")
-                         st.dataframe(detections_summary)
-                     else:
-                         st.warning("No objects detected in the video for the selected classes.")
-                 except Exception as e:
-                     st.error(f"An error occurred while preparing the download: {e}")
-
-     # Optionally, display all available classes when 'all' is selected
-     if custom_classes_input.strip().lower() == "all":
-         st.info(f"The model is set to detect **all** available classes: {', '.join(all_model_classes)}")
-
- # Unet model training configuration
-
- # Constants
- IMG_SIZE = 128 # Resize dimension for the input image
-
- # Load model function
- @st.cache_resource
- def load_model():
-     model = UNet(n_channels=3, n_classes=32) # Adjust according to your model setup
-     model.load_state_dict(torch.load("./model/unet/checkpoint_epoch5.pth", map_location="cpu", weights_only=True), strict=False)
-     model.eval()
-     return model
-
- # Function to preprocess the image
- def preprocess_image(image):
-     transform = T.Compose([
-         T.Resize((IMG_SIZE, IMG_SIZE)), # Resize to match model input size
-         T.ToTensor(), # Convert to tensor
-     ])
-     image_tensor = transform(image).unsqueeze(0) # Add batch dimension
-     return image_tensor
-
- # Function to postprocess the model output for display
- def postprocess_mask(mask):
-     # Convert mask to a numpy array and scale to 0-255
-     mask_np = mask.squeeze().cpu().numpy() # Remove batch and channel dimensions
-     mask_np = (mask_np > 0.5).astype(np.uint8) * 255 # Binarize and scale to 0-255
-     return mask_np
-
- def unet():
-     try:
-         # Load the model
-         model = load_model()
-
-         st.markdown(
-             "<h1 style='text-align: center; font-size: 36px;'>Unet object detection</h1>",
-             unsafe_allow_html=True
-         )
-         st.markdown(
-             "<h2 style='text-align: center; font-size: 30px;'>Using Unet - Pytorch</h2>",
-             unsafe_allow_html=True
-         )
-
-         # Display the file upload widget
-         uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
-         if uploaded_file is not None:
-             st.write("Processing...")
-             # Open and display the uploaded image
-             image = Image.open(uploaded_file).convert("RGB")
-             st.image(image, caption="Uploaded Image", use_column_width=True)
-
-             # Preprocess the image
-             input_tensor = preprocess_image(image)
-
-             # Perform inference
-             with torch.no_grad(): # Disable gradient calculation for inference
-                 output = model(input_tensor)
-                 prediction = torch.sigmoid(output) # Apply sigmoid to get probabilities
-
-             # Post-process the mask for display
-             mask = postprocess_mask(prediction[0, 0]) # Get the mask from the first batch item
-
-             # Display the segmentation mask
-             st.image(mask, caption="Segmentation Mask", use_column_width=True)
-     except Exception as e:
-         st.error(f"An error occurred in Unet: {e}")
-
- # Main page
- if 'model_selected' not in st.session_state:
-     st.session_state.model_selected = None
-
- def main():
-     # Radio button for model selection with consistent casing
-     option = st.radio("Select Model:", ("Unet", "YOLO"))
-
-     # Submit button to confirm selection
-     if st.button("Choose"):
-         st.session_state.model_selected = option
-         st.success(f"Selected Model: {st.session_state.model_selected}")
-
-     # Render the selected model's interface based on session state
-     if st.session_state.model_selected == "Unet":
-         unet()
-     elif st.session_state.model_selected == "YOLO":
-         yolo()
-
- if __name__ == "__main__":
-     main()
+ from pathlib import Path
+ from PIL import Image

+ import pathlib
  import numpy as np
  import torch
+ import streamlit as st
  import cv2

+ #If you have linux (or deploying for linux) use:
  pathlib.WindowsPath = pathlib.PosixPath

+ # Load YOLOv5 model
+ model = torch.hub.load('./yolov5', 'custom', path='./yolo/best.pt', source='local', force_reload=True)

+ st.title("YOLO Object Detection Web App")

+ # Upload image
+ uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

+ if uploaded_file is not None:
+     # Convert the file to an OpenCV image
+     image = Image.open(uploaded_file)
+     st.image(image, caption="Uploaded Image", use_column_width=True)
+     st.write("Processing...")

+     # Convert the image to a format compatible with YOLO
+     image_np = np.array(image)
+     image_cv = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)

+     # Perform YOLO detection
+     results = model(image_cv)

+     # Render the results
+     detected_image = np.squeeze(results.render())

+     # Display result
+     st.image(detected_image, caption="Detected Image", use_column_width=True)
model/.DS_Store DELETED
Binary file (6.15 kB)
 
requirements.txt CHANGED
@@ -26,7 +26,6 @@ ultralytics>=8.2.34 # https://ultralytics.com
  # Plotting --------------------------------------------------------------------
  pandas>=1.1.4
  seaborn>=0.11.0
- plotly>=4.14.3

  # Export ----------------------------------------------------------------------
  # coremltools>=6.0 # CoreML export
unet/__init__.py DELETED
@@ -1 +0,0 @@
- from .unet_model import UNet
unet/__pycache__/__init__.cpython-312.pyc DELETED
Binary file (220 Bytes)
 
unet/__pycache__/unet_model.cpython-312.pyc DELETED
Binary file (2.21 kB)
 
unet/__pycache__/unet_parts.cpython-312.pyc DELETED
Binary file (4.46 kB)
 
{model/unet → unet}/checkpoint_epoch5.pth RENAMED
File without changes
unet/unet_model.py DELETED
@@ -1,36 +0,0 @@
- """ Full assembly of the parts to form the complete network """
-
- from .unet_parts import *
-
-
- class UNet(nn.Module):
-     def __init__(self, n_channels, n_classes, bilinear=False):
-         super(UNet, self).__init__()
-         self.n_channels = n_channels
-         self.n_classes = n_classes
-         self.bilinear = bilinear
-
-         self.inc = DoubleConv(n_channels, 64)
-         self.down1 = Down(64, 128)
-         self.down2 = Down(128, 256)
-         self.down3 = Down(256, 512)
-         factor = 2 if bilinear else 1
-         self.down4 = Down(512, 1024 // factor)
-         self.up1 = Up(1024, 512 // factor, bilinear)
-         self.up2 = Up(512, 256 // factor, bilinear)
-         self.up3 = Up(256, 128 // factor, bilinear)
-         self.up4 = Up(128, 64, bilinear)
-         self.outc = OutConv(64, n_classes)
-
-     def forward(self, x):
-         x1 = self.inc(x)
-         x2 = self.down1(x1)
-         x3 = self.down2(x2)
-         x4 = self.down3(x3)
-         x5 = self.down4(x4)
-         x = self.up1(x5, x4)
-         x = self.up2(x, x3)
-         x = self.up3(x, x2)
-         x = self.up4(x, x1)
-         logits = self.outc(x)
-         return logits
unet/unet_parts.py DELETED
@@ -1,77 +0,0 @@
- """ Parts of the U-Net model """
-
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
-
-
- class DoubleConv(nn.Module):
-     """(convolution => [BN] => ReLU) * 2"""
-
-     def __init__(self, in_channels, out_channels, mid_channels=None):
-         super().__init__()
-         if not mid_channels:
-             mid_channels = out_channels
-         self.double_conv = nn.Sequential(
-             nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
-             nn.BatchNorm2d(mid_channels),
-             nn.ReLU(inplace=True),
-             nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
-             nn.BatchNorm2d(out_channels),
-             nn.ReLU(inplace=True)
-         )
-
-     def forward(self, x):
-         return self.double_conv(x)
-
-
- class Down(nn.Module):
-     """Downscaling with maxpool then double conv"""
-
-     def __init__(self, in_channels, out_channels):
-         super().__init__()
-         self.maxpool_conv = nn.Sequential(
-             nn.MaxPool2d(2),
-             DoubleConv(in_channels, out_channels)
-         )
-
-     def forward(self, x):
-         return self.maxpool_conv(x)
-
-
- class Up(nn.Module):
-     """Upscaling then double conv"""
-
-     def __init__(self, in_channels, out_channels, bilinear=True):
-         super().__init__()
-
-         # if bilinear, use the normal convolutions to reduce the number of channels
-         if bilinear:
-             self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
-             self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
-         else:
-             self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
-             self.conv = DoubleConv(in_channels, out_channels)
-
-     def forward(self, x1, x2):
-         x1 = self.up(x1)
-         # input is CHW
-         diffY = x2.size()[2] - x1.size()[2]
-         diffX = x2.size()[3] - x1.size()[3]
-
-         x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
-                         diffY // 2, diffY - diffY // 2])
-         # if you have padding issues, see
-         # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
-         # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
-         x = torch.cat([x2, x1], dim=1)
-         return self.conv(x)
-
-
- class OutConv(nn.Module):
-     def __init__(self, in_channels, out_channels):
-         super(OutConv, self).__init__()
-         self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
-
-     def forward(self, x):
-         return self.conv(x)
{model/yolo → yolo}/best.pt RENAMED
File without changes
yolov5 CHANGED
@@ -1 +1 @@
- Subproject commit 2f74455adc74a587c9e9d5a6e45df880fce8ea3e
+ Subproject commit 24ee28010fbf597ec796e6e471429cde21040f90