# Hugging Face Spaces page banner ("Spaces: Running") captured with this file; not part of the program.
# from fastapi import FastAPI, HTTPException, UploadFile, File, Form | |
# from pydantic import BaseModel | |
# import numpy as np | |
# from PIL import Image | |
# import io, uuid, os, shutil, timeit | |
# from datetime import datetime | |
# from fastapi.staticfiles import StaticFiles | |
# from fastapi.middleware.cors import CORSMiddleware | |
# # import your three wrappers | |
# from app import predict_simple, predict_middle, predict_full | |
# app = FastAPI() | |
# # allow CORS if needed | |
# app.add_middleware( | |
# CORSMiddleware, | |
# allow_origins=["*"], | |
# allow_methods=["*"], | |
# allow_headers=["*"], | |
# ) | |
# BASE_URL = "https://snapanddtraceapp-988917236820.us-central1.run.app" | |
# OUTPUT_DIR = os.path.abspath("./outputs") | |
# os.makedirs(OUTPUT_DIR, exist_ok=True) | |
# app.mount("/outputs", StaticFiles(directory=OUTPUT_DIR), name="outputs") | |
# UPDATES_DIR = os.path.abspath("./updates") | |
# os.makedirs(UPDATES_DIR, exist_ok=True) | |
# app.mount("/updates", StaticFiles(directory=UPDATES_DIR), name="updates") | |
# def save_and_build_urls( | |
# session_id: str, | |
# output_image: np.ndarray, | |
# outlines: np.ndarray, | |
# dxf_path: str, | |
# mask: np.ndarray | |
# ): | |
# """Helper to save all four artifacts and return public URLs.""" | |
# request_dir = os.path.join(OUTPUT_DIR, session_id) | |
# os.makedirs(request_dir, exist_ok=True) | |
# # filenames | |
# out_fn = "overlay.jpg" | |
# outlines_fn = "outlines.jpg" | |
# mask_fn = "mask.jpg" | |
# current_date = datetime.now().strftime("%d-%m-%Y") | |
# dxf_fn = f"out_{current_date}_{session_id}.dxf" | |
# # full paths | |
# out_path = os.path.join(request_dir, out_fn) | |
# outlines_path = os.path.join(request_dir, outlines_fn) | |
# mask_path = os.path.join(request_dir, mask_fn) | |
# new_dxf_path = os.path.join(request_dir, dxf_fn) | |
# # save images | |
# Image.fromarray(output_image).save(out_path) | |
# Image.fromarray(outlines).save(outlines_path) | |
# Image.fromarray(mask).save(mask_path) | |
# # copy dx file | |
# if os.path.exists(dxf_path): | |
# shutil.copy(dxf_path, new_dxf_path) | |
# else: | |
# # fallback if your DXF generator returns bytes or string | |
# with open(new_dxf_path, "wb") as f: | |
# if isinstance(dxf_path, (bytes, bytearray)): | |
# f.write(dxf_path) | |
# else: | |
# f.write(str(dxf_path).encode("utf-8")) | |
# # build URLs | |
# return { | |
# "output_image_url": f"{BASE_URL}/outputs/{session_id}/{out_fn}", | |
# "outlines_url": f"{BASE_URL}/outputs/{session_id}/{outlines_fn}", | |
# "mask_url": f"{BASE_URL}/outputs/{session_id}/{mask_fn}", | |
# "dxf_url": f"{BASE_URL}/outputs/{session_id}/{dxf_fn}", | |
# } | |
# @app.post("/predict1") | |
# async def predict1_api( | |
# file: UploadFile = File(...) | |
# ): | |
# """ | |
# Simple predict: only image β overlay, outlines, mask, DXF | |
# """ | |
# session_id = str(uuid.uuid4()) | |
# try: | |
# img_bytes = await file.read() | |
# image = np.array(Image.open(io.BytesIO(img_bytes)).convert("RGB")) | |
# except Exception: | |
# raise HTTPException(400, "Invalid image upload") | |
# try: | |
# start = timeit.default_timer() | |
# out_img, outlines, dxf_path, mask = predict_simple(image) | |
# elapsed = timeit.default_timer() - start | |
# print(f"[{session_id}] predict1 in {elapsed:.2f}s") | |
# return save_and_build_urls(session_id, out_img, outlines, dxf_path, mask) | |
# except Exception as e: | |
# raise HTTPException(500, f"predict1 failed: {e}") | |
# except ReferenceBoxNotDetectedError: | |
# raise HTTPException(status_code=400, detail="Error detecting reference battery! Please try again with a clearer image.") | |
# except FingerCutOverlapError: | |
# raise HTTPException(status_code=400, detail="There was an overlap with fingercuts!s Please try again to generate dxf.") | |
# @app.post("/predict2") | |
# async def predict2_api( | |
# file: UploadFile = File(...), | |
# enable_fillet: str = Form(..., regex="^(On|Off)$"), | |
# fillet_value_mm: float = Form(...) | |
# ): | |
# """ | |
# Middle predict: image + fillet toggle + fillet value β overlay, outlines, mask, DXF | |
# """ | |
# session_id = str(uuid.uuid4()) | |
# try: | |
# img_bytes = await file.read() | |
# image = np.array(Image.open(io.BytesIO(img_bytes)).convert("RGB")) | |
# except Exception: | |
# raise HTTPException(400, "Invalid image upload") | |
# try: | |
# start = timeit.default_timer() | |
# out_img, outlines, dxf_path, mask = predict_middle( | |
# image, enable_fillet, fillet_value_mm | |
# ) | |
# elapsed = timeit.default_timer() - start | |
# print(f"[{session_id}] predict2 in {elapsed:.2f}s") | |
# return save_and_build_urls(session_id, out_img, outlines, dxf_path, mask) | |
# except Exception as e: | |
# raise HTTPException(500, f"predict2 failed: {e}") | |
# except ReferenceBoxNotDetectedError: | |
# raise HTTPException(status_code=400, detail="Error detecting reference battery! Please try again with a clearer image.") | |
# except FingerCutOverlapError: | |
# raise HTTPException(status_code=400, detail="There was an overlap with fingercuts!s Please try again to generate dxf.") | |
# @app.post("/predict3") | |
# async def predict3_api( | |
# file: UploadFile = File(...), | |
# enable_fillet: str = Form(..., regex="^(On|Off)$"), | |
# fillet_value_mm: float = Form(...), | |
# enable_finger_cut: str = Form(..., regex="^(On|Off)$") | |
# ): | |
# """ | |
# Full predict: image + fillet toggle/value + finger-cut toggle β overlay, outlines, mask, DXF | |
# """ | |
# session_id = str(uuid.uuid4()) | |
# try: | |
# img_bytes = await file.read() | |
# image = np.array(Image.open(io.BytesIO(img_bytes)).convert("RGB")) | |
# except Exception: | |
# raise HTTPException(400, "Invalid image upload") | |
# try: | |
# start = timeit.default_timer() | |
# out_img, outlines, dxf_path, mask = predict_full( | |
# image, enable_fillet, fillet_value_mm, enable_finger_cut | |
# ) | |
# elapsed = timeit.default_timer() - start | |
# print(f"[{session_id}] predict3 in {elapsed:.2f}s") | |
# return save_and_build_urls(session_id, out_img, outlines, dxf_path, mask) | |
# except Exception as e: | |
# raise HTTPException(500, f"predict3 failed: {e}") | |
# except ReferenceBoxNotDetectedError: | |
# raise HTTPException(status_code=400, detail="Error detecting reference battery! Please try again with a clearer image.") | |
# except FingerCutOverlapError: | |
# raise HTTPException(status_code=400, detail="There was an overlap with fingercuts!s Please try again to generate dxf.") | |
# @app.post("/update") | |
# async def update_files( | |
# output_image: UploadFile = File(...), | |
# outlines_image: UploadFile = File(...), | |
# mask_image: UploadFile = File(...), | |
# dxf_file: UploadFile = File(...) | |
# ): | |
# session_id = str(uuid.uuid4()) | |
# update_dir = os.path.join(UPDATES_DIR, session_id) | |
# os.makedirs(update_dir, exist_ok=True) | |
# try: | |
# upload_map = { | |
# "output_image": output_image, | |
# "outlines_image": outlines_image, | |
# "mask_image": mask_image, | |
# "dxf_file": dxf_file, | |
# } | |
# urls = {} | |
# for key, up in upload_map.items(): | |
# fn = up.filename | |
# path = os.path.join(update_dir, fn) | |
# with open(path, "wb") as f: | |
# shutil.copyfileobj(up.file, f) | |
# urls[key] = f"{BASE_URL}/updates/{session_id}/{fn}" | |
# return {"session_id": session_id, "uploaded": urls} | |
# except Exception as e: | |
# raise HTTPException(500, f"Update failed: {e}") | |
# if __name__ == "__main__": | |
# import uvicorn | |
# port = int(os.environ.get("PORT", 8082)) | |
# print(f"Starting FastAPI server on 0.0.0.0:{port}...") | |
# uvicorn.run(app, host="0.0.0.0", port=port) | |
import io
import logging
import os
import shutil
import timeit
import uuid
from datetime import datetime, timezone
from urllib.parse import quote

import numpy as np
from fastapi import FastAPI, HTTPException, UploadFile, File, Form
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from PIL import Image
from pydantic import BaseModel

# import your three wrappers
from app import predict_simple, predict_middle, predict_full
from app import (
    predict_simple, predict_middle, predict_full,
    ReferenceBoxNotDetectedError,
    FingerCutOverlapError
)
app = FastAPI()

# Allow cross-origin requests from any origin, method and header
# (allow_credentials is left at its default).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Public base URL used to build the absolute links returned to clients.
# Overridable via the BASE_URL environment variable so the same code can be
# deployed behind another host; the default is the original deployment URL.
# A trailing slash is stripped so joined URLs never contain "//".
BASE_URL = os.environ.get(
    "BASE_URL", "https://snapanddtraceapp-988917236820.us-central1.run.app"
).rstrip("/")

# Prediction artifacts (one sub-directory per request).
OUTPUT_DIR = os.path.abspath("./outputs")
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Client-uploaded replacement files from /update (one sub-directory per call).
UPDATES_DIR = os.path.abspath("./updates")
os.makedirs(UPDATES_DIR, exist_ok=True)

# Mount static directories with normal StaticFiles
app.mount("/outputs", StaticFiles(directory=OUTPUT_DIR), name="outputs")
app.mount("/updates", StaticFiles(directory=UPDATES_DIR), name="updates")
def _save_jpeg(array: np.ndarray, path: str) -> None:
    """Write an image array to ``path`` (JPEG), converting modes JPEG cannot store."""
    img = Image.fromarray(array)
    # JPEG has no alpha channel or palette; e.g. an RGBA array would make
    # Image.save raise instead of producing a file.
    if img.mode not in ("1", "L", "RGB", "CMYK"):
        img = img.convert("RGB")
    img.save(path)


def save_and_build_urls(
    session_id: str,
    output_image: np.ndarray,
    outlines: np.ndarray,
    dxf_path: str,
    mask: np.ndarray,
    endpoint_type: str,
    fillet_value: float = None,
    finger_cut: str = None,
):
    """Save the four prediction artifacts of one request and return their public URLs.

    Files are written to ``OUTPUT_DIR/<session_id>/``: ``overlay.jpg``,
    ``outlines.jpg``, ``mask.jpg`` and ``DXF_<DD-MM-YYYY>.dxf`` (date in UTC).

    Args:
        session_id: Per-request directory name (the endpoints pass a UUID4 string).
        output_image, outlines, mask: Arrays accepted by ``PIL.Image.fromarray``.
        dxf_path: Path of an existing DXF file to copy, or the DXF content itself
            (``bytes``/``bytearray`` written as-is, anything else as UTF-8 text).
        endpoint_type: One of ``"predict1"``, ``"predict2"``, ``"predict3"``.
        fillet_value, finger_cut: Accepted for API compatibility; they are not
            currently part of the DXF filename.

    Returns:
        Dict with ``output_image_url``, ``outlines_url`` and ``mask_url`` (served by
        the ``/outputs`` static mount) and ``dxf_url`` (served by ``/download``).

    Raises:
        ValueError: If ``endpoint_type`` is unknown or ``dxf_path`` is ``None``.
    """
    # Validate before touching the filesystem. Previously an unknown value left
    # the DXF name unbound (UnboundLocalError) after the directory was created.
    if endpoint_type not in ("predict1", "predict2", "predict3"):
        raise ValueError(f"Unknown endpoint_type: {endpoint_type!r}")
    if dxf_path is None:
        raise ValueError("Prediction returned no DXF output")

    request_dir = os.path.join(OUTPUT_DIR, session_id)
    os.makedirs(request_dir, exist_ok=True)

    out_fn = "overlay.jpg"
    outlines_fn = "outlines.jpg"
    mask_fn = "mask.jpg"
    # All endpoints share one DXF name; the per-session directory keeps it unique.
    current_date = datetime.now(timezone.utc).strftime("%d-%m-%Y")
    dxf_fn = f"DXF_{current_date}.dxf"

    _save_jpeg(output_image, os.path.join(request_dir, out_fn))
    _save_jpeg(outlines, os.path.join(request_dir, outlines_fn))
    _save_jpeg(mask, os.path.join(request_dir, mask_fn))

    new_dxf_path = os.path.join(request_dir, dxf_fn)
    # os.path.exists() raises TypeError for a bytearray, so such content must
    # skip the path check and go straight to the content fallback.
    if not isinstance(dxf_path, bytearray) and os.path.exists(dxf_path):
        shutil.copy(dxf_path, new_dxf_path)
    else:
        # Fallback for DXF generators that return the content rather than a path.
        with open(new_dxf_path, "wb") as f:
            if isinstance(dxf_path, (bytes, bytearray)):
                f.write(dxf_path)
            else:
                f.write(str(dxf_path).encode("utf-8"))

    outputs_base = f"{BASE_URL}/outputs/{session_id}"
    return {
        "output_image_url": f"{outputs_base}/{out_fn}",
        "outlines_url": f"{outputs_base}/{outlines_fn}",
        "mask_url": f"{outputs_base}/{mask_fn}",
        # DXF goes through the download endpoint so it is served as an attachment.
        "dxf_url": f"{BASE_URL}/download/{session_id}/{dxf_fn}",
    }
# Endpoint for downloading generated files (target of the "dxf_url" links).
@app.get("/download/{session_id}/{filename}")
async def download_file(session_id: str, filename: str):
    """Serve ``OUTPUT_DIR/<session_id>/<filename>`` as a file attachment.

    Raises:
        HTTPException: 404 if the file does not exist, is not a regular file, or
            the requested path resolves outside ``OUTPUT_DIR`` (e.g. via "..").
    """
    root = os.path.realpath(OUTPUT_DIR)
    file_path = os.path.realpath(os.path.join(root, session_id, filename))
    # Both path parameters are client-controlled: refuse anything that escapes
    # the outputs directory (".." segments, symlinks) instead of serving it.
    if os.path.commonpath([root, file_path]) != root or not os.path.isfile(file_path):
        raise HTTPException(status_code=404, detail="File not found")
    # FileResponse derives a correctly quoted "attachment" Content-Disposition
    # from `filename`; the former hand-built header left the name unquoted.
    # Non-DXF files get a media type guessed from their extension.
    media_type = "application/x-dxf" if filename.lower().endswith(".dxf") else None
    return FileResponse(path=file_path, filename=filename, media_type=media_type)
@app.post("/predict1")
async def predict1_api(
    file: UploadFile = File(...)
):
    """
    Simple predict: image only → overlay, outlines, mask, DXF.
    DXF naming format: DXF_DD-MM-YYYY.dxf

    Returns the URL dict built by save_and_build_urls.
    Raises HTTPException: 400 for an unreadable image, a missing reference box
    (ReferenceBoxNotDetectedError) or a finger-cut overlap (FingerCutOverlapError);
    500 for any other failure (the real cause is logged server-side).
    """
    session_id = str(uuid.uuid4())
    try:
        img_bytes = await file.read()
        image = np.array(Image.open(io.BytesIO(img_bytes)).convert("RGB"))
    except Exception:
        raise HTTPException(400, "Invalid image upload")
    try:
        start = timeit.default_timer()
        # NOTE(review): predict_simple is called synchronously inside an async
        # handler, so it blocks the event loop while it runs; consider
        # run_in_threadpool once the model is confirmed thread-safe.
        out_img, outlines, dxf_path, mask = predict_simple(image)
        elapsed = timeit.default_timer() - start
        print(f"[{session_id}] predict1 in {elapsed:.2f}s")
        return save_and_build_urls(
            session_id=session_id,
            output_image=out_img,
            outlines=outlines,
            dxf_path=dxf_path,
            mask=mask,
            endpoint_type="predict1"
        )
    except ReferenceBoxNotDetectedError:
        raise HTTPException(status_code=400, detail="Error detecting reference battery! Please try again with a clearer image.")
    except FingerCutOverlapError:
        raise HTTPException(status_code=400, detail="There was an overlap with fingercuts! Please try again to generate dxf.")
    except HTTPException:
        raise
    except Exception as e:
        # The client only sees the generic message, so record the real cause.
        logging.getLogger(__name__).exception("[%s] predict1 failed", session_id)
        raise HTTPException(status_code=500, detail="Error detecting reference battery! Please try again with a clearer image.") from e
@app.post("/predict2")
async def predict2_api(
    file: UploadFile = File(...),
    enable_fillet: str = Form(..., regex="^(On|Off)$"),
    fillet_value_mm: float = Form(...)
):
    """
    Middle predict: image + fillet toggle + fillet value → overlay, outlines, mask, DXF.
    DXF naming format: DXF_DD-MM-YYYY.dxf (the fillet value is passed on to
    save_and_build_urls but is not currently part of the filename).

    Returns the URL dict built by save_and_build_urls.
    Raises HTTPException: 400 for an unreadable image, a missing reference box
    (ReferenceBoxNotDetectedError) or a finger-cut overlap (FingerCutOverlapError);
    500 for any other failure (the real cause is logged server-side).
    """
    session_id = str(uuid.uuid4())
    try:
        img_bytes = await file.read()
        image = np.array(Image.open(io.BytesIO(img_bytes)).convert("RGB"))
    except Exception:
        raise HTTPException(400, "Invalid image upload")
    try:
        start = timeit.default_timer()
        # NOTE(review): synchronous call inside an async handler; it blocks the
        # event loop while it runs.
        out_img, outlines, dxf_path, mask = predict_middle(
            image, enable_fillet, fillet_value_mm
        )
        elapsed = timeit.default_timer() - start
        print(f"[{session_id}] predict2 in {elapsed:.2f}s")
        return save_and_build_urls(
            session_id=session_id,
            output_image=out_img,
            outlines=outlines,
            dxf_path=dxf_path,
            mask=mask,
            endpoint_type="predict2",
            fillet_value=fillet_value_mm
        )
    except ReferenceBoxNotDetectedError:
        raise HTTPException(status_code=400, detail="Error detecting reference battery! Please try again with a clearer image.")
    except FingerCutOverlapError:
        raise HTTPException(status_code=400, detail="There was an overlap with fingercuts! Please try again to generate dxf.")
    except HTTPException:
        raise
    except Exception as e:
        # The client only sees the generic message, so record the real cause.
        logging.getLogger(__name__).exception("[%s] predict2 failed", session_id)
        raise HTTPException(status_code=500, detail="Error detecting reference battery! Please try again with a clearer image.") from e
@app.post("/predict3")
async def predict3_api(
    file: UploadFile = File(...),
    enable_fillet: str = Form(..., regex="^(On|Off)$"),
    fillet_value_mm: float = Form(...),
    enable_finger_cut: str = Form(..., regex="^(On|Off)$")
):
    """
    Full predict: image + fillet toggle/value + finger-cut toggle → overlay, outlines, mask, DXF.
    DXF naming format: DXF_DD-MM-YYYY.dxf (the fillet value and finger-cut flag
    are passed on to save_and_build_urls but are not currently part of the filename).

    Returns the URL dict built by save_and_build_urls.
    Raises HTTPException: 400 for an unreadable image, a missing reference box
    (ReferenceBoxNotDetectedError) or a finger-cut overlap (FingerCutOverlapError);
    500 for any other failure (the real cause is logged server-side).
    """
    session_id = str(uuid.uuid4())
    try:
        img_bytes = await file.read()
        image = np.array(Image.open(io.BytesIO(img_bytes)).convert("RGB"))
    except Exception:
        raise HTTPException(400, "Invalid image upload")
    try:
        start = timeit.default_timer()
        # NOTE(review): synchronous call inside an async handler; it blocks the
        # event loop while it runs.
        out_img, outlines, dxf_path, mask = predict_full(
            image, enable_fillet, fillet_value_mm, enable_finger_cut
        )
        elapsed = timeit.default_timer() - start
        print(f"[{session_id}] predict3 in {elapsed:.2f}s")
        return save_and_build_urls(
            session_id=session_id,
            output_image=out_img,
            outlines=outlines,
            dxf_path=dxf_path,
            mask=mask,
            endpoint_type="predict3",
            fillet_value=fillet_value_mm,
            finger_cut=enable_finger_cut
        )
    except ReferenceBoxNotDetectedError:
        raise HTTPException(status_code=400, detail="Error detecting reference battery! Please try again with a clearer image.")
    except FingerCutOverlapError:
        raise HTTPException(status_code=400, detail="There was an overlap with fingercuts! Please try again to generate dxf.")
    except HTTPException:
        raise
    except Exception as e:
        # The client only sees the generic message, so record the real cause.
        logging.getLogger(__name__).exception("[%s] predict3 failed", session_id)
        raise HTTPException(status_code=500, detail="Error detecting reference battery! Please try again with a clearer image.") from e
def _safe_upload_name(filename, fallback: str) -> str:
    """Reduce a client-supplied filename to a bare name that cannot leave its directory."""
    # Drop any directory component (POSIX or Windows separators) so names like
    # "../../app.py" or "/etc/passwd" cannot escape or replace the target directory.
    name = os.path.basename((filename or "").replace("\\", "/"))
    if name in ("", ".", ".."):
        return fallback
    return name


@app.post("/update")
async def update_files(
    output_image: UploadFile = File(...),
    outlines_image: UploadFile = File(...),
    mask_image: UploadFile = File(...),
    dxf_file: UploadFile = File(...)
):
    """Store four client-supplied files under ``UPDATES_DIR/<new uuid4>/``.

    Each file keeps its uploaded name with any directory part removed. An upload
    without a usable name is stored under its form-field name. A name already
    used in this call is prefixed with the field name.

    Returns:
        ``{"session_id": ..., "uploaded": {field_name: url}}`` with URLs under the
        ``/updates`` static mount.

    Raises:
        HTTPException: 500 if any file cannot be written.
    """
    session_id = str(uuid.uuid4())
    update_dir = os.path.join(UPDATES_DIR, session_id)
    os.makedirs(update_dir, exist_ok=True)
    try:
        upload_map = {
            "output_image": output_image,
            "outlines_image": outlines_image,
            "mask_image": mask_image,
            "dxf_file": dxf_file,
        }
        urls = {}
        used_names = set()
        for key, up in upload_map.items():
            fn = _safe_upload_name(up.filename, key)
            # Two uploads with the same name would otherwise overwrite each other.
            if fn in used_names:
                fn = f"{key}_{fn}"
            used_names.add(fn)
            path = os.path.join(update_dir, fn)
            with open(path, "wb") as f:
                shutil.copyfileobj(up.file, f)
            # Percent-encode so names with spaces or '#', '?' yield valid URLs.
            urls[key] = f"{BASE_URL}/updates/{session_id}/{quote(fn)}"
        return {"session_id": session_id, "uploaded": urls}
    except Exception as e:
        raise HTTPException(500, f"Update failed: {e}") from e
from fastapi import Response


# NOTE(review): route path assumed to be "/health"; confirm against the
# deployment's health-check configuration.
@app.get("/health")
def health():
    """Health check: respond with plain-text ``OK`` and status 200."""
    return Response(content="OK", status_code=200)
if __name__ == "__main__":
    # uvicorn is only needed when this module is executed directly.
    import uvicorn

    # Listen on $PORT when the environment provides it, otherwise on 8080.
    port = int(os.getenv("PORT", 8080))
    print(f"Starting FastAPI server on 0.0.0.0:{port}...")
    server_options = {"host": "0.0.0.0", "port": port}
    uvicorn.run(app, **server_options)