from __future__ import annotations

import os
import shutil

import cv2
import huggingface_hub
import numpy as np
import torch
import torch.nn.functional as F
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles


# Reference excerpt of the Model class from vtoonify_model.py, showing only the
# methods the endpoints below rely on. __init__, style_types, MODEL_REPO,
# get_video_crop_parameter, and align_face are all defined in vtoonify_model.py;
# the complete class is imported further down and shadows this excerpt.
class Model:
    def load_model(self, style_type: str) -> tuple[torch.Tensor | None, str]:
        # Color transfer is only used for the illustration styles.
        self.color_transfer = 'illustration' in style_type
        if style_type not in self.style_types:
            return None, 'Oops, wrong style type. Please select a valid model.'
        self.style_name = style_type
        model_path, ind = self.style_types[style_type]
        style_path = os.path.join('models', os.path.dirname(model_path), 'exstyle_code.npy')
        # Download the generator weights and the external style codes from the Hub.
        self.vtoonify.load_state_dict(torch.load(
            huggingface_hub.hf_hub_download(MODEL_REPO, 'models/' + model_path),
            map_location=lambda storage, loc: storage)['g_ema'])
        tmp = np.load(huggingface_hub.hf_hub_download(MODEL_REPO, style_path), allow_pickle=True).item()
        exstyle = torch.tensor(tmp[list(tmp.keys())[ind]]).to(self.device)
        with torch.no_grad():
            # Map the style code from Z+ to W+ space.
            exstyle = self.vtoonify.zplus2wplus(exstyle)
        return exstyle, 'Model of %s loaded.' % style_type
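    # Hypothetical usage (valid keys come from Model.style_types; 'cartoon1' is
    # the one the /upload/ endpoint below uses):
    #   exstyle, msg = model.load_model('cartoon1')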
    
    def detect_and_align(self, frame, top, bottom, left, right, return_para=False):
        message = 'Error: no face detected! Please retry or change the photo.'
        paras = get_video_crop_parameter(frame, self.landmarkpredictor, [left, right, top, bottom])
        instyle = None
        h, w, scale = 0, 0, 0
        if paras is not None:
            h,w,top,bottom,left,right,scale = paras
            H, W = int(bottom-top), int(right-left)
            # For high-resolution images, apply a light Gaussian blur (a 4-tap
            # binomial kernel) to avoid over-sharp stylization results.
            kernel_1d = np.array([[0.125], [0.375], [0.375], [0.125]])
            if scale <= 0.75:
                frame = cv2.sepFilter2D(frame, -1, kernel_1d, kernel_1d)
            if scale <= 0.375:
                frame = cv2.sepFilter2D(frame, -1, kernel_1d, kernel_1d)
            frame = cv2.resize(frame, (w, h))[top:bottom, left:right]
            with torch.no_grad():
                I = align_face(frame, self.landmarkpredictor)
                if I is not None:
                    I = self.transform(I).unsqueeze(dim=0).to(self.device)
                    instyle = self.pspencoder(I)
                    instyle = self.vtoonify.zplus2wplus(instyle)
                    message = 'Successfully rescaled the frame to (%d, %d)' % (bottom-top, right-left)
                else:
                    frame = np.zeros((256,256,3), np.uint8)
        else:
            frame = np.zeros((256,256,3), np.uint8)
        if return_para:
            return frame, instyle, message, w, h, top, bottom, left, right, scale
        return frame, instyle, message
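    # Note: detect_and_align returns (aligned RGB crop, W+ instyle code or None,
    # status message); with return_para=True the crop parameters are appended.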
    
    #@torch.inference_mode()
    def detect_and_align_image(self, image: str, top: int, bottom: int, left: int, right: int
                              ) -> tuple[np.ndarray, torch.Tensor | None, str]:
        if image is None:
            return np.zeros((256,256,3), np.uint8), None, 'Error: failed to load empty file.'
        frame = cv2.imread(image)
        if frame is None:
            return np.zeros((256,256,3), np.uint8), None, 'Error: failed to load the image.'
        # cv2.imread returns BGR; the pipeline works in RGB.
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        return self.detect_and_align(frame, top, bottom, left, right)
    
    def detect_and_align_video(self, video: str, top: int, bottom: int, left: int, right: int
                              ) -> tuple[np.ndarray, torch.Tensor | None, str]:
        if video is None:
            return np.zeros((256,256,3), np.uint8), None, 'Error: failed to load empty file.'
        video_cap = cv2.VideoCapture(video)
        if video_cap.get(cv2.CAP_PROP_FRAME_COUNT) == 0:
            video_cap.release()
            return np.zeros((256,256,3), np.uint8), torch.zeros(1,18,512).to(self.device), 'Error: failed to load the video.'
        success, frame = video_cap.read()
        if not success:
            video_cap.release()
            return np.zeros((256,256,3), np.uint8), None, 'Error: failed to read the first frame.'
        # OpenCV decodes frames as BGR; convert to RGB.
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        video_cap.release()
        return self.detect_and_align(frame, top, bottom, left, right)
    
  
    def image_toonify(self, aligned_face: np.ndarray, instyle: torch.Tensor, exstyle: torch.Tensor, style_degree: float, style_type: str) -> tuple[np.ndarray, str]:
        if instyle is None or aligned_face is None:
            return np.zeros((256,256,3), np.uint8), 'Oops, something went wrong with the input. Please go to Step 2 and rescale the image/first frame again.'
        # Reload the style model if the requested style differs from the loaded one.
        if self.style_name != style_type:
            exstyle, _ = self.load_model(style_type)
        if exstyle is None:
            return np.zeros((256,256,3), np.uint8), 'Oops, something went wrong with the style type. Please go to Step 1 and load the model again.'
        with torch.no_grad():
            if self.color_transfer:
                # Full color transfer: use the external style code everywhere.
                s_w = exstyle
            else:
                # Structure-only transfer: take the first 7 layers from exstyle
                # and keep the remaining (color-controlling) layers from instyle.
                s_w = instyle.clone()
                s_w[:,:7] = exstyle[:,:7]

            x = self.transform(aligned_face).unsqueeze(dim=0).to(self.device)
            # Run face parsing at 2x resolution, then downscale the maps to match x.
            x_p = F.interpolate(self.parsingpredictor(2*(F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)))[0],
                                scale_factor=0.5, recompute_scale_factor=False).detach()
            inputs = torch.cat((x, x_p/16.), dim=1)
            y_tilde = self.vtoonify(inputs, s_w.repeat(inputs.size(0), 1, 1), d_s=style_degree)
            y_tilde = torch.clamp(y_tilde, -1, 1)
        print('*** Toonify %dx%d image with style of %s' % (y_tilde.shape[2], y_tilde.shape[3], style_type))
        # Map the output from [-1, 1] to uint8 RGB.
        return ((y_tilde[0].cpu().numpy().transpose(1, 2, 0) + 1.0) * 127.5).astype(np.uint8), 'Successfully toonified the image with style %s.' % self.style_name
    
# Import the complete implementation (shadows the reference excerpt above).
from vtoonify_model import Model

app = FastAPI()
model = Model(device='cuda' if torch.cuda.is_available() else 'cpu')


@app.post("/upload/")
async def process_image(file: UploadFile = File(...)):
    # Save the uploaded image locally.
    with open("uploaded_image.jpg", "wb") as buffer:
        shutil.copyfileobj(file.file, buffer)

    # Load the style model (this endpoint always uses 'cartoon1').
    exstyle, load_info = model.load_model('cartoon1')

    # Detect, crop, and align the face, keeping the returned instyle code.
    top, bottom, left, right = 200, 200, 200, 200
    aligned_face, instyle, input_info = model.detect_and_align_image("uploaded_image.jpg", top, bottom, left, right)
    processed_image, message = model.image_toonify(aligned_face, instyle=instyle, exstyle=exstyle, style_degree=0.5, style_type='cartoon1')

    # image_toonify returns an RGB uint8 array; convert to BGR for cv2.imwrite.
    cv2.imwrite("result_image.jpg", cv2.cvtColor(processed_image, cv2.COLOR_RGB2BGR))

    # Return the processed image
    return FileResponse("result_image.jpg", media_type="image/jpeg", headers={"Content-Disposition": "attachment; filename=result_image.jpg"})
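# Example client call (host and port are assumptions for local testing):
#   curl -X POST -F "file=@face.jpg" http://localhost:8000/upload/ -o result_image.jpg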

@app.get("/")
def index() -> FileResponse:
    return FileResponse(path="/app/AB/index.html", media_type="text/html")

# Mount the static frontend last so it does not shadow the routes above.
app.mount("/", StaticFiles(directory="AB", html=True), name="static")
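
# Minimal local entry point: a sketch for development only. Hugging Face Spaces
# and most deployments launch the app with their own uvicorn/gunicorn command,
# so host and port here are assumptions.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)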