Ashrafb commited on
Commit
ebb8d7c
·
verified ·
1 Parent(s): 0b1eb16

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +31 -53
main.py CHANGED
@@ -1,9 +1,7 @@
1
  from __future__ import annotations
2
- from fastapi import FastAPI, File, UploadFile
3
- from fastapi.responses import FileResponse
4
- from fastapi.staticfiles import StaticFiles
5
  from fastapi import FastAPI, File, UploadFile, Form
6
- from fastapi.responses import FileResponse
 
7
  import torch
8
  import shutil
9
  import cv2
@@ -12,33 +10,18 @@ import dlib
12
  from torchvision import transforms
13
  import torch.nn.functional as F
14
  from vtoonify_model import Model # Importing the Model class from vtoonify_model.py
15
-
16
- import gradio as gr
17
- import pathlib
18
- import sys
19
- sys.path.insert(0, 'vtoonify')
20
-
21
- from util import load_psp_standalone, get_video_crop_parameter, tensor2cv2
22
- import torch
23
- import torch.nn as nn
24
- import numpy as np
25
- import dlib
26
- import cv2
27
  from model.vtoonify import VToonify
28
  from model.bisenet.model import BiSeNet
29
- import torch.nn.functional as F
30
- from torchvision import transforms
31
- from model.encoder.align_all_parallel import align_face
32
- import gc
33
  import huggingface_hub
34
  import os
 
35
 
36
  app = FastAPI()
37
- model = None
38
 
39
  MODEL_REPO = 'PKUWilliamYang/VToonify'
40
 
41
- class Model():
42
  def __init__(self, device):
43
  super().__init__()
44
 
@@ -53,19 +36,17 @@ class Model():
53
  self.pspencoder = self._load_encoder()
54
  self.transform = transforms.Compose([
55
  transforms.ToTensor(),
56
- transforms.Normalize(mean=[0.5, 0.5, 0.5],std=[0.5,0.5,0.5]),
57
- ])
58
 
59
  self.vtoonify, self.exstyle = self._load_default_model()
60
  self.color_transfer = False
61
  self.style_name = 'cartoon1'
62
  self.video_limit_cpu = 100
63
  self.video_limit_gpu = 300
64
-
65
- @staticmethod
66
- def _create_dlib_landmark_model():
67
- return dlib.shape_predictor(huggingface_hub.hf_hub_download(MODEL_REPO,
68
- 'models/shape_predictor_68_face_landmarks.dat'))
69
 
70
  def _create_parsing_model(self):
71
  parsingpredictor = BiSeNet(n_classes=19)
@@ -75,16 +56,16 @@ class Model():
75
  return parsingpredictor
76
 
77
  def _load_encoder(self) -> nn.Module:
78
- style_encoder_path = huggingface_hub.hf_hub_download(MODEL_REPO,'models/encoder.pt')
79
  return load_psp_standalone(style_encoder_path, self.device)
80
 
81
  def _load_default_model(self) -> tuple[torch.Tensor, str]:
82
- vtoonify = VToonify(backbone = 'dualstylegan')
83
  vtoonify.load_state_dict(torch.load(huggingface_hub.hf_hub_download(MODEL_REPO,
84
  'models/vtoonify_d_cartoon/vtoonify_s026_d0.5.pt'),
85
  map_location=lambda storage, loc: storage)['g_ema'])
86
  vtoonify.to(self.device)
87
- tmp = np.load(huggingface_hub.hf_hub_download(MODEL_REPO,'models/vtoonify_d_cartoon/exstyle_code.npy'), allow_pickle=True).item()
88
  exstyle = torch.tensor(tmp[list(tmp.keys())[26]]).to(self.device)
89
  with torch.no_grad():
90
  exstyle = vtoonify.zplus2wplus(exstyle)
@@ -99,14 +80,14 @@ class Model():
99
  return None, 'Oops, wrong Style Type. Please select a valid model.'
100
  self.style_name = style_type
101
  model_path, ind = self.style_types[style_type]
102
- style_path = os.path.join('models',os.path.dirname(model_path),'exstyle_code.npy')
103
- self.vtoonify.load_state_dict(torch.load(huggingface_hub.hf_hub_download(MODEL_REPO,'models/'+model_path),
104
  map_location=lambda storage, loc: storage)['g_ema'])
105
  tmp = np.load(huggingface_hub.hf_hub_download(MODEL_REPO, style_path), allow_pickle=True).item()
106
  exstyle = torch.tensor(tmp[list(tmp.keys())[ind]]).to(self.device)
107
  with torch.no_grad():
108
  exstyle = self.vtoonify.zplus2wplus(exstyle)
109
- return exstyle, 'Model of %s loaded.'%(style_type)
110
 
111
  def detect_and_align(self, frame, top, bottom, left, right, return_para=False):
112
  message = 'Error: no face detected! Please retry or change the photo.'
@@ -114,7 +95,7 @@ class Model():
114
  instyle = None
115
  h, w, scale = 0, 0, 0
116
  if paras is not None:
117
- h,w,top,bottom,left,right,scale = paras
118
  H, W = int(bottom-top), int(right-left)
119
  # for HR image, we apply gaussian blur to it to avoid over-sharp stylization results
120
  kernel_1d = np.array([[0.125],[0.375],[0.375],[0.125]])
@@ -129,11 +110,11 @@ class Model():
129
  I = self.transform(I).unsqueeze(dim=0).to(self.device)
130
  instyle = self.pspencoder(I)
131
  instyle = self.vtoonify.zplus2wplus(instyle)
132
- message = 'Successfully rescale the frame to (%d, %d)'%(bottom-top, right-left)
133
  else:
134
- frame = np.zeros((256,256,3), np.uint8)
135
  else:
136
- frame = np.zeros((256,256,3), np.uint8)
137
  if return_para:
138
  return frame, instyle, message, w, h, top, bottom, left, right, scale
139
  return frame, instyle, message
@@ -142,21 +123,21 @@ class Model():
142
  def detect_and_align_image(self, image: str, top: int, bottom: int, left: int, right: int
143
  ) -> tuple[np.ndarray, torch.Tensor, str]:
144
  if image is None:
145
- return np.zeros((256,256,3), np.uint8), None, 'Error: fail to load empty file.'
146
  frame = cv2.imread(image)
147
  if frame is None:
148
- return np.zeros((256,256,3), np.uint8), None, 'Error: fail to load the image.'
149
  frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
150
  return self.detect_and_align(frame, top, bottom, left, right)
151
 
152
  def detect_and_align_video(self, video: str, top: int, bottom: int, left: int, right: int
153
  ) -> tuple[np.ndarray, torch.Tensor, str]:
154
  if video is None:
155
- return np.zeros((256,256,3), np.uint8), None, 'Error: fail to load empty file.'
156
  video_cap = cv2.VideoCapture(video)
157
  if video_cap.get(7) == 0:
158
  video_cap.release()
159
- return np.zeros((256,256,3), np.uint8), torch.zeros(1,18,512).to(self.device), 'Error: fail to load the video.'
160
  success, frame = video_cap.read()
161
  frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
162
  video_cap.release()
@@ -166,11 +147,11 @@ class Model():
166
  def image_toonify(self, aligned_face: np.ndarray, instyle: torch.Tensor, exstyle: torch.Tensor, style_degree: float, style_type: str) -> tuple[np.ndarray, str]:
167
  #print(style_type + ' ' + self.style_name)
168
  if instyle is None or aligned_face is None:
169
- return np.zeros((256,256,3), np.uint8), 'Opps, something wrong with the input. Please go to Step 2 and Rescale Image/First Frame again.'
170
  if self.style_name != style_type:
171
  exstyle, _ = self.load_model(style_type)
172
  if exstyle is None:
173
- return np.zeros((256,256,3), np.uint8), 'Opps, something wrong with the style type. Please go to Step 1 and load model again.'
174
  with torch.no_grad():
175
  if self.color_transfer:
176
  s_w = exstyle
@@ -182,17 +163,13 @@ class Model():
182
  x_p = F.interpolate(self.parsingpredictor(2*(F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)))[0],
183
  scale_factor=0.5, recompute_scale_factor=False).detach()
184
  inputs = torch.cat((x, x_p/16.), dim=1)
185
- y_tilde = self.vtoonify(inputs, s_w.repeat(inputs.size(0), 1, 1), d_s = style_degree)
186
  y_tilde = torch.clamp(y_tilde, -1, 1)
187
- print('*** Toonify %dx%d image with style of %s'%(y_tilde.shape[2], y_tilde.shape[3], style_type))
188
- return ((y_tilde[0].cpu().numpy().transpose(1, 2, 0) + 1.0) * 127.5).astype(np.uint8), 'Successfully toonify the image with style of %s'%(self.style_name)
189
-
190
 
 
191
 
192
- model = Model(device='cuda' if torch.cuda.is_available() else 'cpu')
193
-
194
- from fastapi.responses import StreamingResponse
195
- from io import BytesIO
196
 
197
  @app.post("/upload/")
198
  async def process_image(file: UploadFile = File(...), top: int = Form(...), bottom: int = Form(...), left: int = Form(...), right: int = Form(...)):
@@ -216,6 +193,7 @@ async def process_image(file: UploadFile = File(...), top: int = Form(...), bott
216
 
217
  app.mount("/", StaticFiles(directory="AB", html=True), name="static")
218
 
 
219
  @app.get("/")
220
  def index() -> FileResponse:
221
  return FileResponse(path="/app/AB/index.html", media_type="text/html")
 
1
  from __future__ import annotations
 
 
 
2
  from fastapi import FastAPI, File, UploadFile, Form
3
+ from fastapi.responses import StreamingResponse
4
+ from fastapi.staticfiles import StaticFiles
5
  import torch
6
  import shutil
7
  import cv2
 
10
  from torchvision import transforms
11
  import torch.nn.functional as F
12
  from vtoonify_model import Model # Importing the Model class from vtoonify_model.py
13
+ from util import load_psp_standalone, get_video_crop_parameter, tensor2cv2, align_face
 
 
 
 
 
 
 
 
 
 
 
14
  from model.vtoonify import VToonify
15
  from model.bisenet.model import BiSeNet
 
 
 
 
16
  import huggingface_hub
17
  import os
18
+ from io import BytesIO
19
 
20
  app = FastAPI()
 
21
 
22
  MODEL_REPO = 'PKUWilliamYang/VToonify'
23
 
24
+ class Model:
25
  def __init__(self, device):
26
  super().__init__()
27
 
 
36
  self.pspencoder = self._load_encoder()
37
  self.transform = transforms.Compose([
38
  transforms.ToTensor(),
39
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
40
+ ])
41
 
42
  self.vtoonify, self.exstyle = self._load_default_model()
43
  self.color_transfer = False
44
  self.style_name = 'cartoon1'
45
  self.video_limit_cpu = 100
46
  self.video_limit_gpu = 300
47
+
48
+ def _create_dlib_landmark_model(self):
49
+ return dlib.shape_predictor(huggingface_hub.hf_hub_download(MODEL_REPO, 'models/shape_predictor_68_face_landmarks.dat'))
 
 
50
 
51
  def _create_parsing_model(self):
52
  parsingpredictor = BiSeNet(n_classes=19)
 
56
  return parsingpredictor
57
 
58
  def _load_encoder(self) -> nn.Module:
59
+ style_encoder_path = huggingface_hub.hf_hub_download(MODEL_REPO, 'models/encoder.pt')
60
  return load_psp_standalone(style_encoder_path, self.device)
61
 
62
  def _load_default_model(self) -> tuple[torch.Tensor, str]:
63
+ vtoonify = VToonify(backbone='dualstylegan')
64
  vtoonify.load_state_dict(torch.load(huggingface_hub.hf_hub_download(MODEL_REPO,
65
  'models/vtoonify_d_cartoon/vtoonify_s026_d0.5.pt'),
66
  map_location=lambda storage, loc: storage)['g_ema'])
67
  vtoonify.to(self.device)
68
+ tmp = np.load(huggingface_hub.hf_hub_download(MODEL_REPO, 'models/vtoonify_d_cartoon/exstyle_code.npy'), allow_pickle=True).item()
69
  exstyle = torch.tensor(tmp[list(tmp.keys())[26]]).to(self.device)
70
  with torch.no_grad():
71
  exstyle = vtoonify.zplus2wplus(exstyle)
 
80
  return None, 'Oops, wrong Style Type. Please select a valid model.'
81
  self.style_name = style_type
82
  model_path, ind = self.style_types[style_type]
83
+ style_path = os.path.join('models', os.path.dirname(model_path), 'exstyle_code.npy')
84
+ self.vtoonify.load_state_dict(torch.load(huggingface_hub.hf_hub_download(MODEL_REPO, 'models/' + model_path),
85
  map_location=lambda storage, loc: storage)['g_ema'])
86
  tmp = np.load(huggingface_hub.hf_hub_download(MODEL_REPO, style_path), allow_pickle=True).item()
87
  exstyle = torch.tensor(tmp[list(tmp.keys())[ind]]).to(self.device)
88
  with torch.no_grad():
89
  exstyle = self.vtoonify.zplus2wplus(exstyle)
90
+ return exstyle, 'Model of %s loaded.' % (style_type)
91
 
92
  def detect_and_align(self, frame, top, bottom, left, right, return_para=False):
93
  message = 'Error: no face detected! Please retry or change the photo.'
 
95
  instyle = None
96
  h, w, scale = 0, 0, 0
97
  if paras is not None:
98
+ h, w, top, bottom, left, right, scale = paras
99
  H, W = int(bottom-top), int(right-left)
100
  # for HR image, we apply gaussian blur to it to avoid over-sharp stylization results
101
  kernel_1d = np.array([[0.125],[0.375],[0.375],[0.125]])
 
110
  I = self.transform(I).unsqueeze(dim=0).to(self.device)
111
  instyle = self.pspencoder(I)
112
  instyle = self.vtoonify.zplus2wplus(instyle)
113
+ message = 'Successfully rescale the frame to (%d, %d)' % (bottom-top, right-left)
114
  else:
115
+ frame = np.zeros((256, 256, 3), np.uint8)
116
  else:
117
+ frame = np.zeros((256, 256, 3), np.uint8)
118
  if return_para:
119
  return frame, instyle, message, w, h, top, bottom, left, right, scale
120
  return frame, instyle, message
 
123
  def detect_and_align_image(self, image: str, top: int, bottom: int, left: int, right: int
124
  ) -> tuple[np.ndarray, torch.Tensor, str]:
125
  if image is None:
126
+ return np.zeros((256, 256, 3), np.uint8), None, 'Error: fail to load empty file.'
127
  frame = cv2.imread(image)
128
  if frame is None:
129
+ return np.zeros((256, 256, 3), np.uint8), None, 'Error: fail to load the image.'
130
  frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
131
  return self.detect_and_align(frame, top, bottom, left, right)
132
 
133
  def detect_and_align_video(self, video: str, top: int, bottom: int, left: int, right: int
134
  ) -> tuple[np.ndarray, torch.Tensor, str]:
135
  if video is None:
136
+ return np.zeros((256, 256, 3), np.uint8), None, 'Error: fail to load empty file.'
137
  video_cap = cv2.VideoCapture(video)
138
  if video_cap.get(7) == 0:
139
  video_cap.release()
140
+ return np.zeros((256, 256, 3), np.uint8), torch.zeros(1, 18, 512).to(self.device), 'Error: fail to load the video.'
141
  success, frame = video_cap.read()
142
  frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
143
  video_cap.release()
 
147
  def image_toonify(self, aligned_face: np.ndarray, instyle: torch.Tensor, exstyle: torch.Tensor, style_degree: float, style_type: str) -> tuple[np.ndarray, str]:
148
  #print(style_type + ' ' + self.style_name)
149
  if instyle is None or aligned_face is None:
150
+ return np.zeros((256, 256, 3), np.uint8), 'Opps, something wrong with the input. Please go to Step 2 and Rescale Image/First Frame again.'
151
  if self.style_name != style_type:
152
  exstyle, _ = self.load_model(style_type)
153
  if exstyle is None:
154
+ return np.zeros((256, 256, 3), np.uint8), 'Opps, something wrong with the style type. Please go to Step 1 and load model again.'
155
  with torch.no_grad():
156
  if self.color_transfer:
157
  s_w = exstyle
 
163
  x_p = F.interpolate(self.parsingpredictor(2*(F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)))[0],
164
  scale_factor=0.5, recompute_scale_factor=False).detach()
165
  inputs = torch.cat((x, x_p/16.), dim=1)
166
+ y_tilde = self.vtoonify(inputs, s_w.repeat(inputs.size(0), 1, 1), d_s=style_degree)
167
  y_tilde = torch.clamp(y_tilde, -1, 1)
168
+ print('*** Toonify %dx%d image with style of %s' % (y_tilde.shape[2], y_tilde.shape[3], style_type))
169
+ return ((y_tilde[0].cpu().numpy().transpose(1, 2, 0) + 1.0) * 127.5).astype(np.uint8), 'Successfully toonify the image with style of %s' % (self.style_name)
 
170
 
171
+ model = Model(device='cuda' if torch.cuda.is_available() else 'cpu')
172
 
 
 
 
 
173
 
174
  @app.post("/upload/")
175
  async def process_image(file: UploadFile = File(...), top: int = Form(...), bottom: int = Form(...), left: int = Form(...), right: int = Form(...)):
 
193
 
194
  app.mount("/", StaticFiles(directory="AB", html=True), name="static")
195
 
196
+
197
  @app.get("/")
198
  def index() -> FileResponse:
199
  return FileResponse(path="/app/AB/index.html", media_type="text/html")