mlbench123 commited on
Commit
7fb14f8
·
verified ·
1 Parent(s): 4441cfe

Upload 4 files

Browse files
Files changed (4) hide show
  1. api_server.py +525 -0
  2. scalingtestupdated.py +184 -0
  3. u2netp.pth +3 -0
  4. u2netp.py +525 -0
api_server.py ADDED
@@ -0,0 +1,525 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # from fastapi import FastAPI, HTTPException, UploadFile, File, Form
2
+ # from pydantic import BaseModel
3
+ # import numpy as np
4
+ # from PIL import Image
5
+ # import io, uuid, os, shutil, timeit
6
+ # from datetime import datetime
7
+ # from fastapi.staticfiles import StaticFiles
8
+ # from fastapi.middleware.cors import CORSMiddleware
9
+
10
+ # # import your three wrappers
11
+ # from app import predict_simple, predict_middle, predict_full
12
+
13
+ # app = FastAPI()
14
+
15
+ # # allow CORS if needed
16
+ # app.add_middleware(
17
+ # CORSMiddleware,
18
+ # allow_origins=["*"],
19
+ # allow_methods=["*"],
20
+ # allow_headers=["*"],
21
+ # )
22
+
23
+ # BASE_URL = "https://snapanddtraceapp-988917236820.us-central1.run.app"
24
+ # OUTPUT_DIR = os.path.abspath("./outputs")
25
+ # os.makedirs(OUTPUT_DIR, exist_ok=True)
26
+ # app.mount("/outputs", StaticFiles(directory=OUTPUT_DIR), name="outputs")
27
+
28
+ # UPDATES_DIR = os.path.abspath("./updates")
29
+ # os.makedirs(UPDATES_DIR, exist_ok=True)
30
+ # app.mount("/updates", StaticFiles(directory=UPDATES_DIR), name="updates")
31
+
32
+
33
+ # def save_and_build_urls(
34
+ # session_id: str,
35
+ # output_image: np.ndarray,
36
+ # outlines: np.ndarray,
37
+ # dxf_path: str,
38
+ # mask: np.ndarray
39
+ # ):
40
+ # """Helper to save all four artifacts and return public URLs."""
41
+ # request_dir = os.path.join(OUTPUT_DIR, session_id)
42
+ # os.makedirs(request_dir, exist_ok=True)
43
+
44
+ # # filenames
45
+ # out_fn = "overlay.jpg"
46
+ # outlines_fn = "outlines.jpg"
47
+ # mask_fn = "mask.jpg"
48
+ # current_date = datetime.now().strftime("%d-%m-%Y")
49
+ # dxf_fn = f"out_{current_date}_{session_id}.dxf"
50
+
51
+ # # full paths
52
+ # out_path = os.path.join(request_dir, out_fn)
53
+ # outlines_path = os.path.join(request_dir, outlines_fn)
54
+ # mask_path = os.path.join(request_dir, mask_fn)
55
+ # new_dxf_path = os.path.join(request_dir, dxf_fn)
56
+
57
+ # # save images
58
+ # Image.fromarray(output_image).save(out_path)
59
+ # Image.fromarray(outlines).save(outlines_path)
60
+ # Image.fromarray(mask).save(mask_path)
61
+
62
+ # # copy dx file
63
+ # if os.path.exists(dxf_path):
64
+ # shutil.copy(dxf_path, new_dxf_path)
65
+ # else:
66
+ # # fallback if your DXF generator returns bytes or string
67
+ # with open(new_dxf_path, "wb") as f:
68
+ # if isinstance(dxf_path, (bytes, bytearray)):
69
+ # f.write(dxf_path)
70
+ # else:
71
+ # f.write(str(dxf_path).encode("utf-8"))
72
+
73
+ # # build URLs
74
+ # return {
75
+ # "output_image_url": f"{BASE_URL}/outputs/{session_id}/{out_fn}",
76
+ # "outlines_url": f"{BASE_URL}/outputs/{session_id}/{outlines_fn}",
77
+ # "mask_url": f"{BASE_URL}/outputs/{session_id}/{mask_fn}",
78
+ # "dxf_url": f"{BASE_URL}/outputs/{session_id}/{dxf_fn}",
79
+ # }
80
+
81
+
82
+ # @app.post("/predict1")
83
+ # async def predict1_api(
84
+ # file: UploadFile = File(...)
85
+ # ):
86
+ # """
87
+ # Simple predict: only image → overlay, outlines, mask, DXF
88
+ # """
89
+ # session_id = str(uuid.uuid4())
90
+ # try:
91
+ # img_bytes = await file.read()
92
+ # image = np.array(Image.open(io.BytesIO(img_bytes)).convert("RGB"))
93
+ # except Exception:
94
+ # raise HTTPException(400, "Invalid image upload")
95
+
96
+ # try:
97
+ # start = timeit.default_timer()
98
+ # out_img, outlines, dxf_path, mask = predict_simple(image)
99
+ # elapsed = timeit.default_timer() - start
100
+ # print(f"[{session_id}] predict1 in {elapsed:.2f}s")
101
+
102
+ # return save_and_build_urls(session_id, out_img, outlines, dxf_path, mask)
103
+
104
+ # except Exception as e:
105
+ # raise HTTPException(500, f"predict1 failed: {e}")
106
+ # except ReferenceBoxNotDetectedError:
107
+ # raise HTTPException(status_code=400, detail="Error detecting reference battery! Please try again with a clearer image.")
108
+ # except FingerCutOverlapError:
109
+ # raise HTTPException(status_code=400, detail="There was an overlap with fingercuts!s Please try again to generate dxf.")
110
+
111
+
112
+ # @app.post("/predict2")
113
+ # async def predict2_api(
114
+ # file: UploadFile = File(...),
115
+ # enable_fillet: str = Form(..., regex="^(On|Off)$"),
116
+ # fillet_value_mm: float = Form(...)
117
+ # ):
118
+ # """
119
+ # Middle predict: image + fillet toggle + fillet value → overlay, outlines, mask, DXF
120
+ # """
121
+ # session_id = str(uuid.uuid4())
122
+ # try:
123
+ # img_bytes = await file.read()
124
+ # image = np.array(Image.open(io.BytesIO(img_bytes)).convert("RGB"))
125
+ # except Exception:
126
+ # raise HTTPException(400, "Invalid image upload")
127
+
128
+ # try:
129
+ # start = timeit.default_timer()
130
+ # out_img, outlines, dxf_path, mask = predict_middle(
131
+ # image, enable_fillet, fillet_value_mm
132
+ # )
133
+ # elapsed = timeit.default_timer() - start
134
+ # print(f"[{session_id}] predict2 in {elapsed:.2f}s")
135
+
136
+ # return save_and_build_urls(session_id, out_img, outlines, dxf_path, mask)
137
+
138
+ # except Exception as e:
139
+ # raise HTTPException(500, f"predict2 failed: {e}")
140
+ # except ReferenceBoxNotDetectedError:
141
+ # raise HTTPException(status_code=400, detail="Error detecting reference battery! Please try again with a clearer image.")
142
+ # except FingerCutOverlapError:
143
+ # raise HTTPException(status_code=400, detail="There was an overlap with fingercuts!s Please try again to generate dxf.")
144
+
145
+ # @app.post("/predict3")
146
+ # async def predict3_api(
147
+ # file: UploadFile = File(...),
148
+ # enable_fillet: str = Form(..., regex="^(On|Off)$"),
149
+ # fillet_value_mm: float = Form(...),
150
+ # enable_finger_cut: str = Form(..., regex="^(On|Off)$")
151
+ # ):
152
+ # """
153
+ # Full predict: image + fillet toggle/value + finger-cut toggle → overlay, outlines, mask, DXF
154
+ # """
155
+ # session_id = str(uuid.uuid4())
156
+ # try:
157
+ # img_bytes = await file.read()
158
+ # image = np.array(Image.open(io.BytesIO(img_bytes)).convert("RGB"))
159
+ # except Exception:
160
+ # raise HTTPException(400, "Invalid image upload")
161
+
162
+ # try:
163
+ # start = timeit.default_timer()
164
+ # out_img, outlines, dxf_path, mask = predict_full(
165
+ # image, enable_fillet, fillet_value_mm, enable_finger_cut
166
+ # )
167
+ # elapsed = timeit.default_timer() - start
168
+ # print(f"[{session_id}] predict3 in {elapsed:.2f}s")
169
+
170
+ # return save_and_build_urls(session_id, out_img, outlines, dxf_path, mask)
171
+
172
+ # except Exception as e:
173
+ # raise HTTPException(500, f"predict3 failed: {e}")
174
+ # except ReferenceBoxNotDetectedError:
175
+ # raise HTTPException(status_code=400, detail="Error detecting reference battery! Please try again with a clearer image.")
176
+ # except FingerCutOverlapError:
177
+ # raise HTTPException(status_code=400, detail="There was an overlap with fingercuts!s Please try again to generate dxf.")
178
+
179
+ # @app.post("/update")
180
+ # async def update_files(
181
+ # output_image: UploadFile = File(...),
182
+ # outlines_image: UploadFile = File(...),
183
+ # mask_image: UploadFile = File(...),
184
+ # dxf_file: UploadFile = File(...)
185
+ # ):
186
+ # session_id = str(uuid.uuid4())
187
+ # update_dir = os.path.join(UPDATES_DIR, session_id)
188
+ # os.makedirs(update_dir, exist_ok=True)
189
+
190
+ # try:
191
+ # upload_map = {
192
+ # "output_image": output_image,
193
+ # "outlines_image": outlines_image,
194
+ # "mask_image": mask_image,
195
+ # "dxf_file": dxf_file,
196
+ # }
197
+ # urls = {}
198
+ # for key, up in upload_map.items():
199
+ # fn = up.filename
200
+ # path = os.path.join(update_dir, fn)
201
+ # with open(path, "wb") as f:
202
+ # shutil.copyfileobj(up.file, f)
203
+ # urls[key] = f"{BASE_URL}/updates/{session_id}/{fn}"
204
+
205
+ # return {"session_id": session_id, "uploaded": urls}
206
+
207
+ # except Exception as e:
208
+ # raise HTTPException(500, f"Update failed: {e}")
209
+
210
+
211
+ # if __name__ == "__main__":
212
+ # import uvicorn
213
+ # port = int(os.environ.get("PORT", 8082))
214
+ # print(f"Starting FastAPI server on 0.0.0.0:{port}...")
215
+ # uvicorn.run(app, host="0.0.0.0", port=port)
216
+
217
+
218
+
219
+
220
+
221
+
222
+
223
+
224
+
225
+
226
+ from fastapi import FastAPI, HTTPException, UploadFile, File, Form
227
+ from pydantic import BaseModel
228
+ import numpy as np
229
+ from PIL import Image
230
+ import io, uuid, os, shutil, timeit
231
+ from datetime import datetime
232
+ from fastapi.staticfiles import StaticFiles
233
+ from fastapi.middleware.cors import CORSMiddleware
234
+ from fastapi.responses import FileResponse
235
+
236
+ # import your three wrappers
237
+ from app import predict_simple, predict_middle, predict_full
238
+
239
+ from app import (
240
+ predict_simple, predict_middle, predict_full,
241
+ ReferenceBoxNotDetectedError,
242
+ FingerCutOverlapError
243
+ )
244
+
245
+
246
+ app = FastAPI()
247
+
248
+ # allow CORS if needed
249
+ app.add_middleware(
250
+ CORSMiddleware,
251
+ allow_origins=["*"],
252
+ allow_methods=["*"],
253
+ allow_headers=["*"],
254
+ )
255
+
256
+ BASE_URL = "https://snapanddtraceapp-988917236820.us-central1.run.app"
257
+
258
+ OUTPUT_DIR = os.path.abspath("./outputs")
259
+ os.makedirs(OUTPUT_DIR, exist_ok=True)
260
+
261
+ UPDATES_DIR = os.path.abspath("./updates")
262
+ os.makedirs(UPDATES_DIR, exist_ok=True)
263
+
264
+ # Mount static directories with normal StaticFiles
265
+ app.mount("/outputs", StaticFiles(directory=OUTPUT_DIR), name="outputs")
266
+ app.mount("/updates", StaticFiles(directory=UPDATES_DIR), name="updates")
267
+
268
+
269
+ def save_and_build_urls(
270
+ session_id: str,
271
+ output_image: np.ndarray,
272
+ outlines: np.ndarray,
273
+ dxf_path: str,
274
+ mask: np.ndarray,
275
+ endpoint_type: str,
276
+ fillet_value: float = None,
277
+ finger_cut: str = None
278
+ ):
279
+ """Helper to save all four artifacts and return public URLs."""
280
+ request_dir = os.path.join(OUTPUT_DIR, session_id)
281
+ os.makedirs(request_dir, exist_ok=True)
282
+
283
+ # filenames
284
+ out_fn = "overlay.jpg"
285
+ outlines_fn = "outlines.jpg"
286
+ mask_fn = "mask.jpg"
287
+
288
+ # Get current date
289
+ current_date = datetime.utcnow().strftime("%d-%m-%Y")
290
+
291
+
292
+ # Format fillet value with underscore instead of dot
293
+ fillet_str = f"{fillet_value:.2f}".replace(".", "_") if fillet_value is not None else None
294
+
295
+ # Determine DXF filename based on endpoint type
296
+ if endpoint_type == "predict1":
297
+ dxf_fn = f"DXF_{current_date}.dxf"
298
+ elif endpoint_type == "predict2":
299
+ dxf_fn = f"DXF_{current_date}.dxf"
300
+ elif endpoint_type == "predict3":
301
+ dxf_fn = f"DXF_{current_date}.dxf"
302
+
303
+ # full paths
304
+ out_path = os.path.join(request_dir, out_fn)
305
+ outlines_path = os.path.join(request_dir, outlines_fn)
306
+ mask_path = os.path.join(request_dir, mask_fn)
307
+ new_dxf_path = os.path.join(request_dir, dxf_fn)
308
+
309
+ # save images
310
+ Image.fromarray(output_image).save(out_path)
311
+ Image.fromarray(outlines).save(outlines_path)
312
+ Image.fromarray(mask).save(mask_path)
313
+
314
+ # copy dxf file
315
+ if os.path.exists(dxf_path):
316
+ shutil.copy(dxf_path, new_dxf_path)
317
+ else:
318
+ # fallback if your DXF generator returns bytes or string
319
+ with open(new_dxf_path, "wb") as f:
320
+ if isinstance(dxf_path, (bytes, bytearray)):
321
+ f.write(dxf_path)
322
+ else:
323
+ f.write(str(dxf_path).encode("utf-8"))
324
+
325
+ # build URLs with /download prefix for DXF
326
+ return {
327
+ "output_image_url": f"{BASE_URL}/outputs/{session_id}/{out_fn}",
328
+ "outlines_url": f"{BASE_URL}/outputs/{session_id}/{outlines_fn}",
329
+ "mask_url": f"{BASE_URL}/outputs/{session_id}/{mask_fn}",
330
+ "dxf_url": f"{BASE_URL}/download/{session_id}/{dxf_fn}", # Changed to use download endpoint
331
+ }
332
+
333
+ # Add new endpoint for downloading DXF files
334
+ @app.get("/download/{session_id}/{filename}")
335
+ async def download_file(session_id: str, filename: str):
336
+ file_path = os.path.join(OUTPUT_DIR, session_id, filename)
337
+ if not os.path.exists(file_path):
338
+ raise HTTPException(status_code=404, detail="File not found")
339
+
340
+ return FileResponse(
341
+ path=file_path,
342
+ filename=filename,
343
+ media_type="application/x-dxf",
344
+ headers={"Content-Disposition": f"attachment; filename={filename}"}
345
+ )
346
+
347
+
348
+ @app.post("/predict1")
349
+ async def predict1_api(
350
+ file: UploadFile = File(...)
351
+ ):
352
+ """
353
+ Simple predict: only image → overlay, outlines, mask, DXF
354
+ DXF naming format: DXF_DD-MM-YYYY.dxf
355
+ """
356
+ session_id = str(uuid.uuid4())
357
+ try:
358
+ img_bytes = await file.read()
359
+ image = np.array(Image.open(io.BytesIO(img_bytes)).convert("RGB"))
360
+ except Exception:
361
+ raise HTTPException(400, "Invalid image upload")
362
+
363
+ try:
364
+ start = timeit.default_timer()
365
+ out_img, outlines, dxf_path, mask = predict_simple(image)
366
+ elapsed = timeit.default_timer() - start
367
+ print(f"[{session_id}] predict1 in {elapsed:.2f}s")
368
+
369
+ return save_and_build_urls(
370
+ session_id=session_id,
371
+ output_image=out_img,
372
+ outlines=outlines,
373
+ dxf_path=dxf_path,
374
+ mask=mask,
375
+ endpoint_type="predict1"
376
+ )
377
+
378
+ except ReferenceBoxNotDetectedError:
379
+ raise HTTPException(status_code=400, detail="Error detecting reference battery! Please try again with a clearer image.")
380
+ except FingerCutOverlapError:
381
+ raise HTTPException(status_code=400, detail="There was an overlap with fingercuts! Please try again to generate dxf.")
382
+ except HTTPException as e:
383
+ raise e
384
+ except Exception as e:
385
+ raise HTTPException(status_code=500, detail="Error detecting reference battery! Please try again with a clearer image.")
386
+
387
+ @app.post("/predict2")
388
+ async def predict2_api(
389
+ file: UploadFile = File(...),
390
+ enable_fillet: str = Form(..., regex="^(On|Off)$"),
391
+ fillet_value_mm: float = Form(...)
392
+ ):
393
+ """
394
+ Middle predict: image + fillet toggle + fillet value → overlay, outlines, mask, DXF
395
+ DXF naming format: DXF_DD-MM-YYYY_fillet-value_mm.dxf
396
+ """
397
+ session_id = str(uuid.uuid4())
398
+ try:
399
+ img_bytes = await file.read()
400
+ image = np.array(Image.open(io.BytesIO(img_bytes)).convert("RGB"))
401
+ except Exception:
402
+ raise HTTPException(400, "Invalid image upload")
403
+
404
+ try:
405
+ start = timeit.default_timer()
406
+ out_img, outlines, dxf_path, mask = predict_middle(
407
+ image, enable_fillet, fillet_value_mm
408
+ )
409
+ elapsed = timeit.default_timer() - start
410
+ print(f"[{session_id}] predict2 in {elapsed:.2f}s")
411
+
412
+ return save_and_build_urls(
413
+ session_id=session_id,
414
+ output_image=out_img,
415
+ outlines=outlines,
416
+ dxf_path=dxf_path,
417
+ mask=mask,
418
+ endpoint_type="predict2",
419
+ fillet_value=fillet_value_mm
420
+ )
421
+
422
+ except ReferenceBoxNotDetectedError:
423
+ raise HTTPException(status_code=400, detail="Error detecting reference battery! Please try again with a clearer image.")
424
+ except FingerCutOverlapError:
425
+ raise HTTPException(status_code=400, detail="There was an overlap with fingercuts! Please try again to generate dxf.")
426
+ except HTTPException as e:
427
+ raise e
428
+ except Exception as e:
429
+ raise HTTPException(status_code=500, detail="Error detecting reference battery! Please try again with a clearer image.")
430
+
431
+
432
+ @app.post("/predict3")
433
+ async def predict3_api(
434
+ file: UploadFile = File(...),
435
+ enable_fillet: str = Form(..., regex="^(On|Off)$"),
436
+ fillet_value_mm: float = Form(...),
437
+ enable_finger_cut: str = Form(..., regex="^(On|Off)$")
438
+ ):
439
+ """
440
+ Full predict: image + fillet toggle/value + finger-cut toggle → overlay, outlines, mask, DXF
441
+ DXF naming format: DXF_DD-MM-YYYY_fillet-value_mm_fingercut-On|Off.dxf
442
+ """
443
+ session_id = str(uuid.uuid4())
444
+ try:
445
+ img_bytes = await file.read()
446
+ image = np.array(Image.open(io.BytesIO(img_bytes)).convert("RGB"))
447
+ except Exception:
448
+ raise HTTPException(400, "Invalid image upload")
449
+
450
+ try:
451
+ start = timeit.default_timer()
452
+ out_img, outlines, dxf_path, mask = predict_full(
453
+ image, enable_fillet, fillet_value_mm, enable_finger_cut
454
+ )
455
+ elapsed = timeit.default_timer() - start
456
+ print(f"[{session_id}] predict3 in {elapsed:.2f}s")
457
+
458
+ return save_and_build_urls(
459
+ session_id=session_id,
460
+ output_image=out_img,
461
+ outlines=outlines,
462
+ dxf_path=dxf_path,
463
+ mask=mask,
464
+ endpoint_type="predict3",
465
+ fillet_value=fillet_value_mm,
466
+ finger_cut=enable_finger_cut
467
+ )
468
+
469
+ except ReferenceBoxNotDetectedError:
470
+ raise HTTPException(status_code=400, detail="Error detecting reference battery! Please try again with a clearer image.")
471
+ except FingerCutOverlapError:
472
+ raise HTTPException(status_code=400, detail="There was an overlap with fingercuts! Please try again to generate dxf.")
473
+ except HTTPException as e:
474
+ raise e
475
+ except Exception as e:
476
+ raise HTTPException(status_code=500, detail="Error detecting reference battery! Please try again with a clearer image.")
477
+
478
+
479
+ @app.post("/update")
480
+ async def update_files(
481
+ output_image: UploadFile = File(...),
482
+ outlines_image: UploadFile = File(...),
483
+ mask_image: UploadFile = File(...),
484
+ dxf_file: UploadFile = File(...)
485
+ ):
486
+ session_id = str(uuid.uuid4())
487
+ update_dir = os.path.join(UPDATES_DIR, session_id)
488
+ os.makedirs(update_dir, exist_ok=True)
489
+
490
+ try:
491
+ upload_map = {
492
+ "output_image": output_image,
493
+ "outlines_image": outlines_image,
494
+ "mask_image": mask_image,
495
+ "dxf_file": dxf_file,
496
+ }
497
+ urls = {}
498
+ for key, up in upload_map.items():
499
+ fn = up.filename
500
+ path = os.path.join(update_dir, fn)
501
+ with open(path, "wb") as f:
502
+ shutil.copyfileobj(up.file, f)
503
+ urls[key] = f"{BASE_URL}/updates/{session_id}/{fn}"
504
+
505
+ return {"session_id": session_id, "uploaded": urls}
506
+
507
+ except Exception as e:
508
+ raise HTTPException(500, f"Update failed: {e}")
509
+
510
+
511
+ from fastapi import Response
512
+
513
+ @app.get("/health")
514
+ def health():
515
+ return Response(content="OK", status_code=200)
516
+
517
+
518
+ if __name__ == "__main__":
519
+ import uvicorn
520
+ port = int(os.environ.get("PORT", 8080))
521
+ print(f"Starting FastAPI server on 0.0.0.0:{port}...")
522
+ uvicorn.run(app, host="0.0.0.0", port=port)
523
+
524
+
525
+
scalingtestupdated.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ import os
4
+ import argparse
5
+ from typing import Union
6
+ from matplotlib import pyplot as plt
7
+
8
+ class ScalingSquareDetector:
9
+ def __init__(self, feature_detector="ORB", debug=False):
10
+ """
11
+ Initialize the detector with the desired feature matching algorithm.
12
+ :param feature_detector: "ORB" or "SIFT" (default is "ORB").
13
+ :param debug: If True, saves intermediate images for debugging.
14
+ """
15
+ self.feature_detector = feature_detector
16
+ self.debug = debug
17
+ self.detector = self._initialize_detector()
18
+
19
+ def _initialize_detector(self):
20
+ """
21
+ Initialize the chosen feature detector.
22
+ :return: OpenCV detector object.
23
+ """
24
+ if self.feature_detector.upper() == "SIFT":
25
+ return cv2.SIFT_create()
26
+ elif self.feature_detector.upper() == "ORB":
27
+ return cv2.ORB_create()
28
+ else:
29
+ raise ValueError("Invalid feature detector. Choose 'ORB' or 'SIFT'.")
30
+
31
+ def find_scaling_square(
32
+ self, target_image, known_size_mm, roi_margin=30
33
+ ):
34
+ """
35
+ Detect the scaling square in the target image based on the reference image.
36
+ :param reference_image_path: Path to the reference image of the square.
37
+ :param target_image_path: Path to the target image containing the square.
38
+ :param known_size_mm: Physical size of the square in millimeters.
39
+ :param roi_margin: Margin to expand the ROI around the detected square (in pixels).
40
+ :return: Scaling factor (mm per pixel).
41
+ """
42
+ contours, _ = cv2.findContours(
43
+ target_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE
44
+ )
45
+
46
+ if not contours:
47
+ raise ValueError("No contours found in the cropped ROI.")
48
+
49
+ # # Select the largest square-like contour
50
+ print(f"No of contours: {len(contours)}")
51
+ largest_square = None
52
+ # largest_square_area = 0
53
+ # for contour in contours:
54
+ # x_c, y_c, w_c, h_c = cv2.boundingRect(contour)
55
+ # aspect_ratio = w_c / float(h_c)
56
+ # if 0.9 <= aspect_ratio <= 1.1:
57
+ # peri = cv2.arcLength(contour, True)
58
+ # approx = cv2.approxPolyDP(contour, 0.02 * peri, True)
59
+ # if len(approx) == 4:
60
+ # area = cv2.contourArea(contour)
61
+ # if area > largest_square_area:
62
+ # largest_square = contour
63
+ # largest_square_area = area
64
+
65
+ for contour in contours:
66
+ largest_square = contour
67
+
68
+ # if largest_square is None:
69
+ # raise ValueError("No square-like contour found in the ROI.")
70
+
71
+ # Draw the largest contour on the original image
72
+ target_image_color = cv2.cvtColor(target_image, cv2.COLOR_GRAY2BGR)
73
+ cv2.drawContours(
74
+ target_image_color, largest_square, -1, (255, 0, 0), 3
75
+ )
76
+
77
+ # if self.debug:
78
+ cv2.imwrite("largest_contour.jpg", target_image_color)
79
+
80
+ # Calculate the bounding rectangle of the largest contour
81
+ x, y, w, h = cv2.boundingRect(largest_square)
82
+ square_width_px = w
83
+ square_height_px = h
84
+ print(f"Reference object size: {known_size_mm} mm")
85
+ print(f"width: {square_width_px} px")
86
+ print(f"height: {square_height_px} px")
87
+
88
+ # Calculate the scaling factor
89
+ avg_square_size_px = (square_width_px + square_height_px) / 2
90
+ print(f"avg square size: {avg_square_size_px} px")
91
+ scaling_factor = known_size_mm / avg_square_size_px # mm per pixel
92
+ print(f"scaling factor: {scaling_factor} mm per pixel")
93
+
94
+ return scaling_factor #, square_height_px, square_width_px, roi_binary
95
+
96
+ def draw_debug_images(self, output_folder):
97
+ """
98
+ Save debug images if enabled.
99
+ :param output_folder: Directory to save debug images.
100
+ """
101
+ if self.debug:
102
+ if not os.path.exists(output_folder):
103
+ os.makedirs(output_folder)
104
+ debug_images = ["largest_contour.jpg"]
105
+ for img_name in debug_images:
106
+ if os.path.exists(img_name):
107
+ os.rename(img_name, os.path.join(output_folder, img_name))
108
+
109
+
110
+ def calculate_scaling_factor(
111
+ target_image,
112
+ reference_obj_size_mm,
113
+ feature_detector="ORB",
114
+ debug=False,
115
+ roi_margin=30,
116
+ ):
117
+ # Initialize detector
118
+ detector = ScalingSquareDetector(feature_detector=feature_detector, debug=debug)
119
+
120
+ # Find scaling square and calculate scaling factor
121
+ scaling_factor = detector.find_scaling_square(
122
+ target_image=target_image,
123
+ known_size_mm=reference_obj_size_mm,
124
+ roi_margin=roi_margin,
125
+ )
126
+
127
+ # Save debug images
128
+ if debug:
129
+ detector.draw_debug_images("debug_outputs")
130
+
131
+ return scaling_factor
132
+
133
+
134
+ # Example usage:
135
+ if __name__ == "__main__":
136
+ import os
137
+ from PIL import Image
138
+ from ultralytics import YOLO
139
+ from app import yolo_detect, shrink_bbox
140
+ from ultralytics.utils.plotting import save_one_box
141
+
142
+ for idx, file in enumerate(os.listdir("./sample_images")):
143
+ img = np.array(Image.open(os.path.join("./sample_images", file)))
144
+ img = yolo_detect(img, ['box'])
145
+ model = YOLO("./best.pt")
146
+ res = model.predict(img, conf=0.6)
147
+
148
+ box_img = save_one_box(res[0].cpu().boxes.xyxy, im=res[0].orig_img, save=False)
149
+ # img = shrink_bbox(box_img, 1.20)
150
+ cv2.imwrite(f"./outputs/{idx}_{file}", box_img)
151
+
152
+ print("File: ",f"./outputs/{idx}_{file}")
153
+ try:
154
+
155
+ scaling_factor = calculate_scaling_factor(
156
+ target_image=box_img,
157
+ known_square_size_mm=20,
158
+ feature_detector="ORB",
159
+ debug=False,
160
+ roi_margin=90,
161
+ )
162
+ # cv2.imwrite(f"./outputs/{idx}_binary_{file}", roi_binary)
163
+
164
+ # Square size in mm
165
+ # square_size_mm = 12.7
166
+
167
+ # # Compute the calculated scaling factors and compare
168
+ # calculated_scaling_factor = square_size_mm / height_px
169
+ # discrepancy = abs(calculated_scaling_factor - scaling_factor)
170
+ # import pprint
171
+ # pprint.pprint({
172
+ # "height_px": height_px,
173
+ # "width_px": width_px,
174
+ # "given_scaling_factor": scaling_factor,
175
+ # "calculated_scaling_factor": calculated_scaling_factor,
176
+ # "discrepancy": discrepancy,
177
+ # })
178
+
179
+
180
+ print(f"Scaling Factor (mm per pixel): {scaling_factor:.6f}")
181
+ except Exception as e:
182
+ from traceback import print_exc
183
+ print(print_exc())
184
+ print(f"Error: {e}")
u2netp.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e7567cde013fb64813973ce6e1ecc25a80c05c3ca7adbc5a54f3c3d90991b854
3
+ size 4683258
u2netp.py ADDED
@@ -0,0 +1,525 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ class REBNCONV(nn.Module):
6
+ def __init__(self,in_ch=3,out_ch=3,dirate=1):
7
+ super(REBNCONV,self).__init__()
8
+
9
+ self.conv_s1 = nn.Conv2d(in_ch,out_ch,3,padding=1*dirate,dilation=1*dirate)
10
+ self.bn_s1 = nn.BatchNorm2d(out_ch)
11
+ self.relu_s1 = nn.ReLU(inplace=True)
12
+
13
+ def forward(self,x):
14
+
15
+ hx = x
16
+ xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
17
+
18
+ return xout
19
+
20
+ ## upsample tensor 'src' to have the same spatial size with tensor 'tar'
21
+ def _upsample_like(src,tar):
22
+
23
+ src = F.upsample(src,size=tar.shape[2:],mode='bilinear')
24
+
25
+ return src
26
+
27
+
28
+ ### RSU-7 ###
29
+ class RSU7(nn.Module):#UNet07DRES(nn.Module):
30
+
31
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
32
+ super(RSU7,self).__init__()
33
+
34
+ self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
35
+
36
+ self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
37
+ self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
38
+
39
+ self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
40
+ self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
41
+
42
+ self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
43
+ self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
44
+
45
+ self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
46
+ self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
47
+
48
+ self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)
49
+ self.pool5 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
50
+
51
+ self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=1)
52
+
53
+ self.rebnconv7 = REBNCONV(mid_ch,mid_ch,dirate=2)
54
+
55
+ self.rebnconv6d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
56
+ self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
57
+ self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
58
+ self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
59
+ self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
60
+ self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
61
+
62
+ def forward(self,x):
63
+
64
+ hx = x
65
+ hxin = self.rebnconvin(hx)
66
+
67
+ hx1 = self.rebnconv1(hxin)
68
+ hx = self.pool1(hx1)
69
+
70
+ hx2 = self.rebnconv2(hx)
71
+ hx = self.pool2(hx2)
72
+
73
+ hx3 = self.rebnconv3(hx)
74
+ hx = self.pool3(hx3)
75
+
76
+ hx4 = self.rebnconv4(hx)
77
+ hx = self.pool4(hx4)
78
+
79
+ hx5 = self.rebnconv5(hx)
80
+ hx = self.pool5(hx5)
81
+
82
+ hx6 = self.rebnconv6(hx)
83
+
84
+ hx7 = self.rebnconv7(hx6)
85
+
86
+ hx6d = self.rebnconv6d(torch.cat((hx7,hx6),1))
87
+ hx6dup = _upsample_like(hx6d,hx5)
88
+
89
+ hx5d = self.rebnconv5d(torch.cat((hx6dup,hx5),1))
90
+ hx5dup = _upsample_like(hx5d,hx4)
91
+
92
+ hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
93
+ hx4dup = _upsample_like(hx4d,hx3)
94
+
95
+ hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
96
+ hx3dup = _upsample_like(hx3d,hx2)
97
+
98
+ hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
99
+ hx2dup = _upsample_like(hx2d,hx1)
100
+
101
+ hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
102
+
103
+ return hx1d + hxin
104
+
105
+ ### RSU-6 ###
106
+ class RSU6(nn.Module):#UNet06DRES(nn.Module):
107
+
108
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
109
+ super(RSU6,self).__init__()
110
+
111
+ self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
112
+
113
+ self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
114
+ self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
115
+
116
+ self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
117
+ self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
118
+
119
+ self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
120
+ self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
121
+
122
+ self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
123
+ self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
124
+
125
+ self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)
126
+
127
+ self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=2)
128
+
129
+ self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
130
+ self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
131
+ self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
132
+ self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
133
+ self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
134
+
135
+ def forward(self,x):
136
+
137
+ hx = x
138
+
139
+ hxin = self.rebnconvin(hx)
140
+
141
+ hx1 = self.rebnconv1(hxin)
142
+ hx = self.pool1(hx1)
143
+
144
+ hx2 = self.rebnconv2(hx)
145
+ hx = self.pool2(hx2)
146
+
147
+ hx3 = self.rebnconv3(hx)
148
+ hx = self.pool3(hx3)
149
+
150
+ hx4 = self.rebnconv4(hx)
151
+ hx = self.pool4(hx4)
152
+
153
+ hx5 = self.rebnconv5(hx)
154
+
155
+ hx6 = self.rebnconv6(hx5)
156
+
157
+
158
+ hx5d = self.rebnconv5d(torch.cat((hx6,hx5),1))
159
+ hx5dup = _upsample_like(hx5d,hx4)
160
+
161
+ hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))
162
+ hx4dup = _upsample_like(hx4d,hx3)
163
+
164
+ hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
165
+ hx3dup = _upsample_like(hx3d,hx2)
166
+
167
+ hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
168
+ hx2dup = _upsample_like(hx2d,hx1)
169
+
170
+ hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
171
+
172
+ return hx1d + hxin
173
+
174
+ ### RSU-5 ###
175
+ class RSU5(nn.Module):#UNet05DRES(nn.Module):
176
+
177
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
178
+ super(RSU5,self).__init__()
179
+
180
+ self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
181
+
182
+ self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
183
+ self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
184
+
185
+ self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
186
+ self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
187
+
188
+ self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
189
+ self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
190
+
191
+ self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)
192
+
193
+ self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=2)
194
+
195
+ self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
196
+ self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
197
+ self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
198
+ self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
199
+
200
+ def forward(self,x):
201
+
202
+ hx = x
203
+
204
+ hxin = self.rebnconvin(hx)
205
+
206
+ hx1 = self.rebnconv1(hxin)
207
+ hx = self.pool1(hx1)
208
+
209
+ hx2 = self.rebnconv2(hx)
210
+ hx = self.pool2(hx2)
211
+
212
+ hx3 = self.rebnconv3(hx)
213
+ hx = self.pool3(hx3)
214
+
215
+ hx4 = self.rebnconv4(hx)
216
+
217
+ hx5 = self.rebnconv5(hx4)
218
+
219
+ hx4d = self.rebnconv4d(torch.cat((hx5,hx4),1))
220
+ hx4dup = _upsample_like(hx4d,hx3)
221
+
222
+ hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))
223
+ hx3dup = _upsample_like(hx3d,hx2)
224
+
225
+ hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
226
+ hx2dup = _upsample_like(hx2d,hx1)
227
+
228
+ hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
229
+
230
+ return hx1d + hxin
231
+
232
+ ### RSU-4 ###
233
+ class RSU4(nn.Module):#UNet04DRES(nn.Module):
234
+
235
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
236
+ super(RSU4,self).__init__()
237
+
238
+ self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
239
+
240
+ self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
241
+ self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
242
+
243
+ self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)
244
+ self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
245
+
246
+ self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)
247
+
248
+ self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=2)
249
+
250
+ self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
251
+ self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)
252
+ self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
253
+
254
+ def forward(self,x):
255
+
256
+ hx = x
257
+
258
+ hxin = self.rebnconvin(hx)
259
+
260
+ hx1 = self.rebnconv1(hxin)
261
+ hx = self.pool1(hx1)
262
+
263
+ hx2 = self.rebnconv2(hx)
264
+ hx = self.pool2(hx2)
265
+
266
+ hx3 = self.rebnconv3(hx)
267
+
268
+ hx4 = self.rebnconv4(hx3)
269
+
270
+ hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))
271
+ hx3dup = _upsample_like(hx3d,hx2)
272
+
273
+ hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))
274
+ hx2dup = _upsample_like(hx2d,hx1)
275
+
276
+ hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))
277
+
278
+ return hx1d + hxin
279
+
280
+ ### RSU-4F ###
281
+ class RSU4F(nn.Module):#UNet04FRES(nn.Module):
282
+
283
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
284
+ super(RSU4F,self).__init__()
285
+
286
+ self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)
287
+
288
+ self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)
289
+ self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=2)
290
+ self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=4)
291
+
292
+ self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=8)
293
+
294
+ self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=4)
295
+ self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=2)
296
+ self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)
297
+
298
+ def forward(self,x):
299
+
300
+ hx = x
301
+
302
+ hxin = self.rebnconvin(hx)
303
+
304
+ hx1 = self.rebnconv1(hxin)
305
+ hx2 = self.rebnconv2(hx1)
306
+ hx3 = self.rebnconv3(hx2)
307
+
308
+ hx4 = self.rebnconv4(hx3)
309
+
310
+ hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))
311
+ hx2d = self.rebnconv2d(torch.cat((hx3d,hx2),1))
312
+ hx1d = self.rebnconv1d(torch.cat((hx2d,hx1),1))
313
+
314
+ return hx1d + hxin
315
+
316
+
317
+ ##### U^2-Net ####
318
+ class U2NET(nn.Module):
319
+
320
+ def __init__(self,in_ch=3,out_ch=1):
321
+ super(U2NET,self).__init__()
322
+
323
+ self.stage1 = RSU7(in_ch,32,64)
324
+ self.pool12 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
325
+
326
+ self.stage2 = RSU6(64,32,128)
327
+ self.pool23 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
328
+
329
+ self.stage3 = RSU5(128,64,256)
330
+ self.pool34 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
331
+
332
+ self.stage4 = RSU4(256,128,512)
333
+ self.pool45 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
334
+
335
+ self.stage5 = RSU4F(512,256,512)
336
+ self.pool56 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
337
+
338
+ self.stage6 = RSU4F(512,256,512)
339
+
340
+ # decoder
341
+ self.stage5d = RSU4F(1024,256,512)
342
+ self.stage4d = RSU4(1024,128,256)
343
+ self.stage3d = RSU5(512,64,128)
344
+ self.stage2d = RSU6(256,32,64)
345
+ self.stage1d = RSU7(128,16,64)
346
+
347
+ self.side1 = nn.Conv2d(64,out_ch,3,padding=1)
348
+ self.side2 = nn.Conv2d(64,out_ch,3,padding=1)
349
+ self.side3 = nn.Conv2d(128,out_ch,3,padding=1)
350
+ self.side4 = nn.Conv2d(256,out_ch,3,padding=1)
351
+ self.side5 = nn.Conv2d(512,out_ch,3,padding=1)
352
+ self.side6 = nn.Conv2d(512,out_ch,3,padding=1)
353
+
354
+ self.outconv = nn.Conv2d(6*out_ch,out_ch,1)
355
+
356
+ def forward(self,x):
357
+
358
+ hx = x
359
+
360
+ #stage 1
361
+ hx1 = self.stage1(hx)
362
+ hx = self.pool12(hx1)
363
+
364
+ #stage 2
365
+ hx2 = self.stage2(hx)
366
+ hx = self.pool23(hx2)
367
+
368
+ #stage 3
369
+ hx3 = self.stage3(hx)
370
+ hx = self.pool34(hx3)
371
+
372
+ #stage 4
373
+ hx4 = self.stage4(hx)
374
+ hx = self.pool45(hx4)
375
+
376
+ #stage 5
377
+ hx5 = self.stage5(hx)
378
+ hx = self.pool56(hx5)
379
+
380
+ #stage 6
381
+ hx6 = self.stage6(hx)
382
+ hx6up = _upsample_like(hx6,hx5)
383
+
384
+ #-------------------- decoder --------------------
385
+ hx5d = self.stage5d(torch.cat((hx6up,hx5),1))
386
+ hx5dup = _upsample_like(hx5d,hx4)
387
+
388
+ hx4d = self.stage4d(torch.cat((hx5dup,hx4),1))
389
+ hx4dup = _upsample_like(hx4d,hx3)
390
+
391
+ hx3d = self.stage3d(torch.cat((hx4dup,hx3),1))
392
+ hx3dup = _upsample_like(hx3d,hx2)
393
+
394
+ hx2d = self.stage2d(torch.cat((hx3dup,hx2),1))
395
+ hx2dup = _upsample_like(hx2d,hx1)
396
+
397
+ hx1d = self.stage1d(torch.cat((hx2dup,hx1),1))
398
+
399
+
400
+ #side output
401
+ d1 = self.side1(hx1d)
402
+
403
+ d2 = self.side2(hx2d)
404
+ d2 = _upsample_like(d2,d1)
405
+
406
+ d3 = self.side3(hx3d)
407
+ d3 = _upsample_like(d3,d1)
408
+
409
+ d4 = self.side4(hx4d)
410
+ d4 = _upsample_like(d4,d1)
411
+
412
+ d5 = self.side5(hx5d)
413
+ d5 = _upsample_like(d5,d1)
414
+
415
+ d6 = self.side6(hx6)
416
+ d6 = _upsample_like(d6,d1)
417
+
418
+ d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1))
419
+
420
+ return F.sigmoid(d0), F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5), F.sigmoid(d6)
421
+
422
+ ### U^2-Net small ###
423
+ class U2NETP(nn.Module):
424
+
425
+ def __init__(self,in_ch=3,out_ch=1):
426
+ super(U2NETP,self).__init__()
427
+
428
+ self.stage1 = RSU7(in_ch,16,64)
429
+ self.pool12 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
430
+
431
+ self.stage2 = RSU6(64,16,64)
432
+ self.pool23 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
433
+
434
+ self.stage3 = RSU5(64,16,64)
435
+ self.pool34 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
436
+
437
+ self.stage4 = RSU4(64,16,64)
438
+ self.pool45 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
439
+
440
+ self.stage5 = RSU4F(64,16,64)
441
+ self.pool56 = nn.MaxPool2d(2,stride=2,ceil_mode=True)
442
+
443
+ self.stage6 = RSU4F(64,16,64)
444
+
445
+ # decoder
446
+ self.stage5d = RSU4F(128,16,64)
447
+ self.stage4d = RSU4(128,16,64)
448
+ self.stage3d = RSU5(128,16,64)
449
+ self.stage2d = RSU6(128,16,64)
450
+ self.stage1d = RSU7(128,16,64)
451
+
452
+ self.side1 = nn.Conv2d(64,out_ch,3,padding=1)
453
+ self.side2 = nn.Conv2d(64,out_ch,3,padding=1)
454
+ self.side3 = nn.Conv2d(64,out_ch,3,padding=1)
455
+ self.side4 = nn.Conv2d(64,out_ch,3,padding=1)
456
+ self.side5 = nn.Conv2d(64,out_ch,3,padding=1)
457
+ self.side6 = nn.Conv2d(64,out_ch,3,padding=1)
458
+
459
+ self.outconv = nn.Conv2d(6*out_ch,out_ch,1)
460
+
461
+ def forward(self,x):
462
+
463
+ hx = x
464
+
465
+ #stage 1
466
+ hx1 = self.stage1(hx)
467
+ hx = self.pool12(hx1)
468
+
469
+ #stage 2
470
+ hx2 = self.stage2(hx)
471
+ hx = self.pool23(hx2)
472
+
473
+ #stage 3
474
+ hx3 = self.stage3(hx)
475
+ hx = self.pool34(hx3)
476
+
477
+ #stage 4
478
+ hx4 = self.stage4(hx)
479
+ hx = self.pool45(hx4)
480
+
481
+ #stage 5
482
+ hx5 = self.stage5(hx)
483
+ hx = self.pool56(hx5)
484
+
485
+ #stage 6
486
+ hx6 = self.stage6(hx)
487
+ hx6up = _upsample_like(hx6,hx5)
488
+
489
+ #decoder
490
+ hx5d = self.stage5d(torch.cat((hx6up,hx5),1))
491
+ hx5dup = _upsample_like(hx5d,hx4)
492
+
493
+ hx4d = self.stage4d(torch.cat((hx5dup,hx4),1))
494
+ hx4dup = _upsample_like(hx4d,hx3)
495
+
496
+ hx3d = self.stage3d(torch.cat((hx4dup,hx3),1))
497
+ hx3dup = _upsample_like(hx3d,hx2)
498
+
499
+ hx2d = self.stage2d(torch.cat((hx3dup,hx2),1))
500
+ hx2dup = _upsample_like(hx2d,hx1)
501
+
502
+ hx1d = self.stage1d(torch.cat((hx2dup,hx1),1))
503
+
504
+
505
+ #side output
506
+ d1 = self.side1(hx1d)
507
+
508
+ d2 = self.side2(hx2d)
509
+ d2 = _upsample_like(d2,d1)
510
+
511
+ d3 = self.side3(hx3d)
512
+ d3 = _upsample_like(d3,d1)
513
+
514
+ d4 = self.side4(hx4d)
515
+ d4 = _upsample_like(d4,d1)
516
+
517
+ d5 = self.side5(hx5d)
518
+ d5 = _upsample_like(d5,d1)
519
+
520
+ d6 = self.side6(hx6)
521
+ d6 = _upsample_like(d6,d1)
522
+
523
+ d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1))
524
+
525
+ return F.sigmoid(d0), F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5), F.sigmoid(d6)