import io
import logging
import os
from typing import Optional

import numpy as np
import onnxruntime as rt
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from PIL import Image, ImageOps, UnidentifiedImageError

import face_detection  # Ensure this is the adjusted face_detection.py

logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI()

# Allow CORS for your frontend application.
# NOTE(review): a wildcard origin combined with allow_credentials=True is very
# permissive; restrict allow_origins before deploying.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Change this to your frontend's URL in production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Load the ONNX model once at import time; every request shares this session.
# The relative path is resolved against the current working directory.
MODEL_FILE = "ffhqu2vintage512_pix2pixHD_v1E11-inp2inst-simp.onnx"
so = rt.SessionOptions()
so.inter_op_num_threads = 4
so.intra_op_num_threads = 4
session = rt.InferenceSession(MODEL_FILE, sess_options=so)
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name


def array_to_image(array_in: np.ndarray) -> Image.Image:
    """Convert a channels-first array with values in [-1, 1] to a PIL image.

    All size-1 axes are squeezed away, after which the array must be
    ``(C, H, W)``; it is transposed to ``(H, W, C)`` for PIL.

    Args:
        array_in: Array such as the ``(1, C, H, W)`` output of the model or of
            :func:`image_as_array`.

    Returns:
        An 8-bit PIL image.
    """
    scaled = np.squeeze(255 * (array_in + 1) / 2)
    hwc = np.transpose(scaled, (1, 2, 0))
    # Saturate before the uint8 cast: values outside [-1, 1] would otherwise
    # overflow/wrap, turning clipped highlights or shadows into speckles.
    return Image.fromarray(np.clip(hwc, 0, 255).astype(np.uint8))


def image_as_array(image_in) -> np.ndarray:
    """Convert an ``(H, W, C)`` 8-bit image to a ``(1, C, H, W)`` float32 array.

    Pixel values 0..255 are mapped linearly to [-1, 1]. This is the inverse of
    :func:`array_to_image` and produces the layout passed to the ONNX session.
    """
    pixels = np.array(image_in, np.float32)
    pixels = (pixels / 255) * 2 - 1
    pixels = np.transpose(pixels, (2, 0, 1))  # HWC -> CHW
    return np.expand_dims(pixels, 0)  # add the leading batch axis


def find_aligned_face(image_in, size=512):
    """Detect and align the first face (``face_index=0``) in ``image_in``.

    Thin wrapper over ``face_detection.align`` that returns its result
    unchanged as ``(aligned_image, n_faces, quad)``. Presumably
    ``aligned_image`` is a ``size`` x ``size`` crop and ``n_faces`` the number
    of detected faces -- confirm against face_detection.py.
    """
    aligned_image, n_faces, quad = face_detection.align(
        image_in, face_index=0, output_size=size
    )
    return aligned_image, n_faces, quad


def align_first_face(image_in, size=512):
    """Build the model input array for ``image_in``.

    If a face is found, the aligned crop of the first face is used. Otherwise
    the whole image is EXIF-rotated (best effort), converted to RGB and
    resized to ``size`` x ``size``.

    Returns:
        A ``(1, C, H, W)`` float32 array in [-1, 1] (see :func:`image_as_array`).
    """
    aligned_image, n_faces, _quad = find_aligned_face(image_in, size=size)
    if n_faces != 0:
        return image_as_array(aligned_image)

    try:
        image_in = ImageOps.exif_transpose(image_in)
    except Exception:
        # Best effort: malformed EXIF data must not fail the whole request.
        logger.warning("exif problem, not rotating", exc_info=True)
    # Force 3-channel RGB so greyscale/palette images (no channel axis) and
    # RGBA images (4 channels) yield the same layout as a normal RGB photo.
    image_in = image_in.convert("RGB").resize((size, size))
    return image_as_array(image_in)


def img_concat_h(im1, im2):
    """Paste ``im2`` to the right of ``im1`` on a new RGB canvas.

    The canvas height is ``im1.height``, so callers should pass images of
    equal height (a taller ``im2`` is cropped; a shorter one leaves black
    padding).
    """
    dst = Image.new("RGB", (im1.width + im2.width, im1.height))
    dst.paste(im1, (0, 0))
    dst.paste(im2, (im1.width, 0))
    return dst


def face2vintage(img: Image.Image, size: int = 512) -> Optional[Image.Image]:
    """Run the vintage-style model on the first face found in ``img``.

    Args:
        img: Input photo.
        size: Side length of the square crop fed to the model. The exported
            model presumably only accepts the resolution it was exported with
            (512 per its file name) -- confirm before passing other values.

    Returns:
        A side-by-side image (model input on the left, model output on the
        right), or None if no model input could be built.
    """
    model_input = align_first_face(img, size=size)
    if model_input is None:
        # Defensive: align_first_face currently always returns an array.
        return None
    output = array_to_image(session.run([output_name], {input_name: model_input})[0])
    # Match the input preview to the output's dimensions so the halves line up.
    preview = array_to_image(model_input).resize((output.width, output.height))
    return img_concat_h(preview, output)


@app.post("/process_image/")
async def process_image(file: UploadFile = File(...)):
    """Apply :func:`face2vintage` to an uploaded image and stream back a PNG.

    Raises:
        HTTPException: 400 if the upload is not a readable image or could not
            be processed; 500 for any other failure.
    """
    # NOTE(review): face detection and ONNX inference run synchronously here
    # and block the event loop; consider run_in_threadpool once
    # face_detection's thread-safety is confirmed.
    try:
        image_bytes = await file.read()
        image = Image.open(io.BytesIO(image_bytes))

        processed_image = face2vintage(image, 512)
        if processed_image is None:
            raise HTTPException(status_code=400, detail="Could not process image.")

        # Encode the processed image as PNG bytes.
        img_byte_arr = io.BytesIO()
        processed_image.save(img_byte_arr, format="PNG")
        img_byte_arr.seek(0)
    except HTTPException:
        # Deliberate client errors keep their status code instead of being
        # rewrapped as a 500 by the generic handler below.
        raise
    except UnidentifiedImageError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        logger.exception("Failed to process uploaded image")
        # NOTE(review): str(e) exposes internal error text to clients.
        raise HTTPException(status_code=500, detail=str(e)) from e
    return StreamingResponse(img_byte_arr, media_type="image/png")