import os

import onnxruntime as rt
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from PIL import Image, ImageOps
import numpy as np
import io

import face_detection
|
|
app = FastAPI()
|
|
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
# Load the pix2pixHD ONNX model once at startup and cache its input/output tensor names.
MODEL_FILE = "ffhqu2vintage512_pix2pixHD_v1E11-inp2inst-simp.onnx"
so = rt.SessionOptions()
so.inter_op_num_threads = 4
so.intra_op_num_threads = 4
session = rt.InferenceSession(MODEL_FILE, sess_options=so)
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name
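# Based on the pre/post-processing helpers below, the model is assumed to take a
# 1x3x512x512 float32 tensor scaled to [-1, 1] and to return a tensor in the same layout.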
|
|
|
def array_to_image(array_in):
    # Map model output from [-1, 1] back to [0, 255] and from CHW to HWC.
    array_in = np.squeeze(255 * (array_in + 1) / 2)
    array_in = np.transpose(array_in, (1, 2, 0))
    im = Image.fromarray(array_in.astype(np.uint8))
    return im
|
|
def image_as_array(image_in):
    # Scale pixels to [-1, 1], reorder to CHW and add a batch dimension (NCHW).
    im_array = np.array(image_in, np.float32)
    im_array = (im_array / 255) * 2 - 1
    im_array = np.transpose(im_array, (2, 0, 1))
    im_array = np.expand_dims(im_array, 0)
    return im_array
|
|
def find_aligned_face(image_in, size=512):
    aligned_image, n_faces, quad = face_detection.align(image_in, face_index=0, output_size=size)
    return aligned_image, n_faces, quad
|
|
def align_first_face(image_in, size=512):
    # Use the first detected face; if none is found, fall back to a plain resize.
    aligned_image, n_faces, quad = find_aligned_face(image_in, size=size)
    if n_faces == 0:
        try:
            image_in = ImageOps.exif_transpose(image_in)
        except Exception:
            print("exif problem, not rotating")
        image_in = image_in.resize((size, size))
        im_array = image_as_array(image_in)
    else:
        im_array = image_as_array(aligned_image)

    return im_array
|
|
def img_concat_h(im1, im2):
    dst = Image.new('RGB', (im1.width + im2.width, im1.height))
    dst.paste(im1, (0, 0))
    dst.paste(im2, (im1.width, 0))
    return dst
|
|
def face2vintage(img: Image.Image, size: int) -> Image.Image:
    aligned_img = align_first_face(img, size=size)
    if aligned_img is None:
        return None

    # Run the model and return the aligned input and stylized output side by side.
    output = session.run([output_name], {input_name: aligned_img})[0]
    output = array_to_image(output)
    aligned_img = array_to_image(aligned_img).resize((output.width, output.height))
    output = img_concat_h(aligned_img, output)

    return output
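# A quick way to try face2vintage outside the API (a minimal sketch; the file
# names are hypothetical and any image with a visible face should work):
#
#   result = face2vintage(Image.open("selfie.jpg").convert("RGB"), 512)
#   if result is not None:
#       result.save("vintage.png")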
|
|
|
@app.post("/process_image/")
async def process_image(file: UploadFile = File(...)):
    try:
        image_bytes = await file.read()
        # Force RGB so alpha or grayscale uploads still have the 3 channels the model expects.
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")

        processed_image = face2vintage(image, 512)

        if processed_image is None:
            raise HTTPException(status_code=400, detail="Could not process image.")

        img_byte_arr = io.BytesIO()
        processed_image.save(img_byte_arr, format='PNG')
        img_byte_arr.seek(0)

        return StreamingResponse(img_byte_arr, media_type="image/png")

    except HTTPException:
        # Let the 400 above pass through instead of being re-wrapped as a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
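

# Minimal sketch for local runs, assuming uvicorn is installed; in production the
# app would more typically be served via `uvicorn <module>:app`.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)

# Example request once the server is running (hypothetical file names):
#   curl -F "file=@selfie.jpg" http://localhost:8000/process_image/ -o result.png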