V
File size: 3,320 Bytes
7109a9e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import os
import onnxruntime as rt
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from PIL import Image, ImageOps
import numpy as np
import io
import face_detection  # Ensure this is the adjusted face_detection.py

# FastAPI application object; the /process_image/ endpoint in this file is
# registered on it.
app = FastAPI()

# Cross-origin access so a browser frontend served from another origin can
# call this API.
# NOTE(review): a wildcard origin combined with allow_credentials=True is very
# permissive; restrict allow_origins to the real frontend origin(s) in production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Change this to your frontend's URL in production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Load the ONNX model once, at import time; the resulting `session` is a
# module-level global that face2vintage reuses for every request.
# MODEL_FILE is a relative path, so it is resolved against the process's
# current working directory, not against this file's directory.
MODEL_FILE = "ffhqu2vintage512_pix2pixHD_v1E11-inp2inst-simp.onnx"
so = rt.SessionOptions()
# Cap ONNX Runtime's thread pools. Per the ONNX Runtime docs, inter-op threads
# are only used in parallel execution mode (not enabled here); intra-op threads
# parallelise work inside individual operators.
so.inter_op_num_threads = 4
so.intra_op_num_threads = 4
session = rt.InferenceSession(MODEL_FILE, sess_options=so)
# Only the first input/output are used; presumably the model has exactly one
# of each — TODO confirm against the exported graph.
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name

def array_to_image(array_in):
    """Convert a channels-first tensor in [-1, 1] back into a PIL image.

    This is the inverse of ``image_as_array``:
    - it maps values from [-1, 1] to [0, 255];
    - it drops singleton dimensions, such as a leading batch axis of size 1;
    - it reorders the remaining (C, H, W) array to (H, W, C).

    Args:
        array_in: NumPy array holding one image in channels-first layout,
            optionally with extra size-1 axes, values nominally in [-1, 1].

    Returns:
        A ``PIL.Image.Image`` built from the resulting uint8 pixel array.
    """
    array_in = np.squeeze(255 * (array_in + 1) / 2)
    # Round rather than truncate, so float error (e.g. 254.99998) does not
    # shave one level off a pixel on the image_as_array round trip. Then clip:
    # model outputs can overshoot [-1, 1], and NumPy's float -> uint8 cast is
    # not saturating, so out-of-range values would otherwise turn into garbage
    # pixels (e.g. 256.0 wrapping to 0).
    array_in = np.clip(np.rint(array_in), 0, 255)
    array_in = np.transpose(array_in, (1, 2, 0))
    return Image.fromarray(array_in.astype(np.uint8))

def image_as_array(image_in):
    """Turn an (H, W, C) image into a (1, C, H, W) float32 tensor in [-1, 1].

    Args:
        image_in: A PIL image or any array-like with a trailing channel axis,
            holding 8-bit pixel values in [0, 255].

    Returns:
        A float32 NumPy array with a leading batch axis of size 1, ready to
        be fed to the ONNX session.
    """
    pixels = np.array(image_in, dtype=np.float32)
    normalised = (pixels / 255) * 2 - 1
    # HWC -> CHW, then prepend the batch dimension.
    channels_first = normalised.transpose(2, 0, 1)
    return channels_first[np.newaxis, ...]

def find_aligned_face(image_in, size=512):
    """Align the first face in *image_in* using ``face_detection.align``.

    Args:
        image_in: Input image passed straight through to the aligner.
        size: Requested output side length, forwarded as ``output_size``.

    Returns:
        The 3-tuple ``(aligned_image, n_faces, quad)`` unpacked from the
        aligner's result. ``quad`` is presumably the crop quadrilateral used
        for alignment — TODO confirm in face_detection.
    """
    result = face_detection.align(image_in, face_index=0, output_size=size)
    aligned, face_count, quad = result
    return aligned, face_count, quad

def align_first_face(image_in, size=512):
    """Build the model input tensor for *image_in*, centred on its first face.

    If ``find_aligned_face`` reports at least one face, its aligned image is
    used. Otherwise the whole image is EXIF-rotated (best effort) and resized
    to ``size`` x ``size`` as a fallback. The fallback does not preserve the
    aspect ratio.

    Args:
        image_in: Input PIL image.
        size: Side length, in pixels, used for alignment and for the fallback
            resize.

    Returns:
        The array produced by ``image_as_array``: float32, shape
        (1, C, H, W), values in [-1, 1].
    """
    aligned_image, n_faces, _quad = find_aligned_face(image_in, size=size)

    if n_faces == 0:
        # No face found: fall back to the full frame. Apply the EXIF
        # orientation first so camera photos are not fed in sideways.
        try:
            image_in = ImageOps.exif_transpose(image_in)
        except Exception:
            # Best effort only: malformed EXIF must not fail the request, but
            # (unlike a bare except) KeyboardInterrupt/SystemExit still propagate.
            print("exif problem, not rotating")
        image_in = image_in.resize((size, size))
        return image_as_array(image_in)

    return image_as_array(aligned_image)

def img_concat_h(im1, im2):
    """Place two images side by side on a new RGB canvas.

    *im1* goes on the left and *im2* immediately to its right, both
    top-aligned. The canvas is as tall as the taller input, so neither image
    is cropped; any uncovered area keeps ``Image.new``'s default black fill.

    Args:
        im1: Left-hand image.
        im2: Right-hand image.

    Returns:
        A new ``PIL.Image.Image`` in RGB mode.
    """
    canvas_height = max(im1.height, im2.height)
    dst = Image.new('RGB', (im1.width + im2.width, canvas_height))
    dst.paste(im1, (0, 0))
    dst.paste(im2, (im1.width, 0))
    return dst

def face2vintage(img: Image.Image, size: int) -> Image.Image:
    """Run the vintage-style ONNX model on the first face found in *img*.

    Args:
        img: Input photo. Non-RGB images (RGBA PNGs, greyscale, palette) are
            converted to RGB first. ``image_as_array`` needs a 3-D (H, W, C)
            array, and the model presumably expects 3-channel input — TODO
            confirm against the graph.
        size: Side length, in pixels, of the square crop fed to the model.
            NOTE(review): the model file name suggests a 512 px network;
            other sizes may be rejected by the session — confirm.

    Returns:
        A side-by-side image with the model input on the left and the model
        output on the right, or None if alignment produced nothing.
    """
    if img.mode != "RGB":
        img = img.convert("RGB")

    # Forward `size`; previously it was silently ignored.
    aligned_img = align_first_face(img, size=size)
    if aligned_img is None:
        return None

    output = session.run([output_name], {input_name: aligned_img})[0]
    output = array_to_image(output)
    # Turn the model input back into an image, matched to the output's size,
    # for the before/after comparison.
    aligned_img = array_to_image(aligned_img).resize((output.width, output.height))
    return img_concat_h(aligned_img, output)

@app.post("/process_image/")
async def process_image(file: UploadFile = File(...)):
    """Apply the vintage filter to an uploaded image and return it as PNG.

    Args:
        file: Uploaded image file (multipart form field ``file``).

    Returns:
        A ``StreamingResponse`` with the side-by-side PNG result.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image or
            processing yields no result; 500 for any other failure.
    """
    # NOTE(review): inference below runs synchronously inside an async
    # endpoint and blocks the event loop while it runs.
    try:
        image_bytes = await file.read()

        # Decode eagerly: Image.open is lazy, so corrupt data would otherwise
        # surface later as a generic 500 instead of a client error.
        try:
            image = Image.open(io.BytesIO(image_bytes))
            image.load()
        except (OSError, Image.DecompressionBombError) as e:
            raise HTTPException(status_code=400, detail=f"Invalid image file: {e}") from e

        processed_image = face2vintage(image, 512)
        if processed_image is None:
            raise HTTPException(status_code=400, detail="Could not process image.")

        img_byte_arr = io.BytesIO()
        processed_image.save(img_byte_arr, format='PNG')
        img_byte_arr.seek(0)
    except HTTPException:
        # Deliberate HTTP errors (e.g. the 400s above) must not be rewrapped
        # as 500 by the generic handler below.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    return StreamingResponse(img_byte_arr, media_type="image/png")