import os
import onnxruntime as rt
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from PIL import Image, ImageOps
import numpy as np
import io
import face_detection # Ensure this is the adjusted face_detection.py
# Initialize FastAPI app
app = FastAPI()

# Allow CORS for your frontend application.
# NOTE: "*" accepts requests from any origin; restrict it in production.
# Credentials are deliberately not enabled: combining them with a wildcard
# origin would let any site make authenticated requests.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Change this to your frontend's URL in production
    allow_methods=["*"],
    allow_headers=["*"],
)
# Load the ONNX model once, at import time; the resulting module-level
# `session` is what the request handlers run inference against.
MODEL_FILE = "ffhqu2vintage512_pix2pixHD_v1E11-inp2inst-simp.onnx"

so = rt.SessionOptions()
# Thread-pool sizes: inter-op parallelizes independent graph nodes (only used
# when the session runs in parallel execution mode), intra-op parallelizes
# work inside a single operator.
so.inter_op_num_threads = so.intra_op_num_threads = 4
session = rt.InferenceSession(MODEL_FILE, sess_options=so)

# The model is fed and read through its first declared input/output tensors.
input_name, output_name = (
    session.get_inputs()[0].name,
    session.get_outputs()[0].name,
)
def array_to_image(array_in):
array_in = np.squeeze(255 * (array_in + 1) / 2)
array_in = np.transpose(array_in, (1, 2, 0))
im = Image.fromarray(array_in.astype(np.uint8))
return im
def image_as_array(image_in):
    """Convert an image into a batched float32 NCHW array scaled to [-1, 1].

    Args:
        image_in: A PIL image (converted to RGB first, so grayscale, palette
            and RGBA inputs all yield exactly 3 channels), or an H x W x 3
            array-like of 0-255 values.

    Returns:
        ndarray of shape (1, 3, H, W) and dtype float32.
    """
    if hasattr(image_in, "convert"):
        # The transpose below and the model both need exactly 3 channels;
        # PIL maps L/LA/P/RGBA -> RGB correctly (palette lookup, alpha drop).
        image_in = image_in.convert("RGB")
    im_array = np.array(image_in, np.float32)
    im_array = (im_array / 255) * 2 - 1
    im_array = np.transpose(im_array, (2, 0, 1))  # HWC -> CHW
    return np.expand_dims(im_array, 0)  # add batch axis
def find_aligned_face(image_in, size=512):
    """Detect and align the first face (index 0) in ``image_in``.

    Thin wrapper over ``face_detection.align``. Its three results are
    unpacked and returned as the tuple ``(aligned_image, n_faces, quad)``.
    """
    face, face_count, face_quad = face_detection.align(
        image_in,
        face_index=0,
        output_size=size,
    )
    return face, face_count, face_quad
def align_first_face(image_in, size=512):
    """Build the model input from the first detected face, or from the whole image.

    Args:
        image_in: Source image. It must support ``ImageOps.exif_transpose``
            and ``resize`` for the no-face fallback.
        size: Side length, in pixels, of the square model input.

    Returns:
        The array produced by ``image_as_array``. The source is either the
        aligned face, or, when no face is detected, the whole image resized
        to ``size`` x ``size``.
    """
    aligned_image, n_faces, _ = find_aligned_face(image_in, size=size)
    if n_faces == 0:
        # No face: fall back to the full image, upright per its EXIF
        # orientation when that metadata is usable.
        try:
            image_in = ImageOps.exif_transpose(image_in)
        except Exception:
            # Best effort only: keep the stored orientation if EXIF is broken.
            print("exif problem, not rotating")
        image_in = image_in.resize((size, size))
        return image_as_array(image_in)
    return image_as_array(aligned_image)
def img_concat_h(im1, im2):
    """Place ``im2`` to the right of ``im1`` on a new RGB canvas.

    The canvas is as tall as the taller input, so neither image is cropped.
    Any uncovered area keeps ``Image.new``'s default (black) fill.
    """
    dst = Image.new('RGB', (im1.width + im2.width, max(im1.height, im2.height)))
    dst.paste(im1, (0, 0))
    dst.paste(im2, (im1.width, 0))
    return dst
def face2vintage(img: Image.Image, size: int) -> Image.Image:
    """Run the vintage model on the first face in ``img``.

    Args:
        img: Input photo.
        size: Side length of the square crop fed to the model. It is
            forwarded to ``align_first_face``.
            NOTE(review): the model file name suggests a fixed 512 input;
            confirm before passing other sizes.

    Returns:
        A side-by-side image: the model input on the left (resized to match)
        and the model output on the right. Returns ``None`` if no model input
        could be prepared.
    """
    model_input = align_first_face(img, size=size)
    if model_input is None:
        return None
    output = session.run([output_name], {input_name: model_input})[0]
    output = array_to_image(output)
    # Render the normalized input back to an image to show next to the result.
    aligned_img = array_to_image(model_input).resize((output.width, output.height))
    return img_concat_h(aligned_img, output)
async def process_image(file: UploadFile = File(...)):
    """Apply the vintage face filter to an uploaded image and stream back a PNG.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image, or
            if processing yields no result; 500 for any other failure.
    """
    # NOTE(review): no route decorator is visible on this coroutine, so FastAPI
    # never exposes it; register it (e.g. ``@app.post(...)``) with the path the
    # frontend expects.
    try:
        # Read the image file
        image_bytes = await file.read()
        try:
            image = Image.open(io.BytesIO(image_bytes))
            # Image.open is lazy; force decoding so corrupt uploads fail here.
            image.load()
        except OSError as e:  # includes PIL's UnidentifiedImageError
            raise HTTPException(status_code=400, detail="Could not read image.") from e
        # Process the image
        processed_image = face2vintage(image, 512)
        if processed_image is None:
            raise HTTPException(status_code=400, detail="Could not process image.")
        # Convert the processed image to bytes
        img_byte_arr = io.BytesIO()
        processed_image.save(img_byte_arr, format='PNG')
        # save() leaves the cursor at the end; rewind or the response body is empty.
        img_byte_arr.seek(0)
        return StreamingResponse(img_byte_arr, media_type="image/png")
    except HTTPException:
        # Deliberate client errors must not be rewrapped as 500s below.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e