Spaces:
Runtime error
Runtime error
import os | |
import cv2 | |
import glob | |
import time | |
import torch | |
import shutil | |
import argparse | |
import platform | |
import datetime | |
import subprocess | |
import insightface | |
import onnxruntime | |
import numpy as np | |
import gradio as gr | |
import threading | |
import queue | |
from tqdm import tqdm | |
import concurrent.futures | |
from moviepy.editor import VideoFileClip | |
from PIL import Image | |
import io | |
from face_swapper import Inswapper, paste_to_whole | |
from face_analyser import detect_conditions, get_analysed_data, swap_options_list | |
from face_parsing import init_parsing_model, get_parsed_mask, mask_regions, mask_regions_to_list | |
from face_enhancer import get_available_enhancer_names, load_face_enhancer_model, cv2_interpolations | |
from utils import trim_video, StreamerThread, ProcessBar, open_directory, split_list_by_lengths, merge_img_sequence_from_ref, create_image_grid | |
## ------------------------------ USER ARGS ------------------------------
# Command-line options. They are parsed once at import time and read by the
# DEFAULTS section through `user_args`.
parser = argparse.ArgumentParser(description="Swap-Mukham Face Swapper")
parser.add_argument("--out_dir", help="Default Output directory", default=os.getcwd())
# type=int makes argparse reject a non-numeric value with a usage error instead
# of letting it surface later as a ValueError when the value is converted.
parser.add_argument("--batch_size", help="Gpu batch size", type=int, default=32)
parser.add_argument("--cuda", action="store_true", help="Enable cuda", default=False)
parser.add_argument(
    "--colab", action="store_true", help="Enable colab mode", default=False
)
user_args = parser.parse_args()
## ------------------------------ DEFAULTS ------------------------------
# Values taken from the command line.
USE_COLAB = user_args.colab
USE_CUDA = user_args.cuda
DEF_OUTPUT_PATH = user_args.out_dir
BATCH_SIZE = int(user_args.batch_size)

# Mutable run state; starts empty.
WORKSPACE = None
OUTPUT_FILE = None
CURRENT_FRAME = None
STREAMER = None

# Face detection settings; DETECT_SIZE and DETECT_THRESH are handed to
# FaceAnalysis.prepare() when the analyser is loaded.
DETECT_CONDITION = "best detection"
DETECT_SIZE = 640
DETECT_THRESH = 0.6
NUM_OF_SRC_SPECIFIC = 10

# Default face-parsing mask settings.
MASK_INCLUDE = [
    "Skin",
    "R-Eyebrow",
    "L-Eyebrow",
    "L-Eye",
    "R-Eye",
    "Nose",
    "Mouth",
    "L-Lip",
    "U-Lip",
]
MASK_SOFT_KERNEL = 17
MASK_SOFT_ITERATIONS = 10
MASK_BLUR_AMOUNT = 0.1
MASK_ERODE_AMOUNT = 0.15

# Model handles; the load_* helpers fill these in on first use.
FACE_SWAPPER = None
FACE_ANALYSER = None
FACE_ENHANCER = None
FACE_PARSER = None

# Selectable enhancers: the "NONE" sentinel, then the enhancer names, then the
# cv2 interpolation names.
FACE_ENHANCER_LIST = ["NONE", *get_available_enhancer_names(), *cv2_interpolations]
## ------------------------------ SET EXECUTION PROVIDER ------------------------------
# Note: Non CUDA users may change settings here
# Start with the CPU-only provider list; it is replaced only when CUDA was
# requested and onnxruntime reports the CUDA provider as available.
PROVIDER = ["CPUExecutionProvider"]

if not USE_CUDA:
    USE_CUDA = False
    print("\n********** Running on CPU **********\n")
else:
    available_providers = onnxruntime.get_available_providers()
    if "CUDAExecutionProvider" not in available_providers:
        # CUDA was requested but cannot be used: fall back to CPU.
        USE_CUDA = False
        print("\n********** CUDA unavailable running on CPU **********\n")
    else:
        print("\n********** Running on CUDA **********\n")
        PROVIDER = ["CUDAExecutionProvider", "CPUExecutionProvider"]

device = "cuda" if USE_CUDA else "cpu"


def EMPTY_CACHE():
    """Call torch.cuda.empty_cache() when running on CUDA; otherwise do nothing."""
    if device == "cuda":
        return torch.cuda.empty_cache()
    return None
## ------------------------------ LOAD MODELS ------------------------------
def load_face_analyser_model(name="buffalo_l"):
    """Create and prepare the global face analyser the first time it is needed.

    Does nothing when ``FACE_ANALYSER`` is already set, so repeated calls keep
    whatever analyser is currently installed.

    Args:
        name: Model name passed to ``insightface.app.FaceAnalysis``.
    """
    global FACE_ANALYSER
    if FACE_ANALYSER is not None:
        return

    FACE_ANALYSER = insightface.app.FaceAnalysis(name=name, providers=PROVIDER)
    detection_size = (DETECT_SIZE, DETECT_SIZE)
    FACE_ANALYSER.prepare(ctx_id=0, det_size=detection_size, det_thresh=DETECT_THRESH)
def load_face_swapper_model(path="./assets/pretrained_models/inswapper_128.onnx"):
    """Create the global face swapper the first time it is needed.

    Does nothing when ``FACE_SWAPPER`` is already set.

    Args:
        path: Location of the swapper model file passed to ``Inswapper``.
    """
    global FACE_SWAPPER
    if FACE_SWAPPER is not None:
        return

    # The configured batch size is only used on CUDA; on CPU the batch is 1.
    if device == "cuda":
        batch = int(BATCH_SIZE)
    else:
        batch = 1
    FACE_SWAPPER = Inswapper(model_file=path, batch_size=batch, providers=PROVIDER)
def load_face_parser_model(path="./assets/pretrained_models/79999_iter.pth"):
    """Create the global face parser the first time it is needed.

    Does nothing when ``FACE_PARSER`` is already set.

    Args:
        path: Location of the parsing model weights passed to
            ``init_parsing_model``.
    """
    global FACE_PARSER
    if FACE_PARSER is not None:
        return
    FACE_PARSER = init_parsing_model(path, device=device)
# Load the analyser and swapper once at import time. Both loaders return early
# when their global is already set, so later calls are cheap.
load_face_analyser_model()
load_face_swapper_model()
## ------------------------------ MAIN PROCESS ------------------------------
def process(
    input_type,
    image_path,
    video_path,
    directory_path,
    source_path,
    output_path,
    output_name,
    keep_output_sequence,
    condition,
    age,
    distance,
    face_enhancer_name,
    enable_face_parser,
    mask_includes,
    mask_soft_kernel,
    mask_soft_iterations,
    blur_amount,
    erode_amount,
    face_scale,
    enable_laplacian_blend,
    crop_top,
    crop_bott,
    crop_left,
    crop_right,
    *specifics,
):
    """Swap faces in an image, a video or a directory of images.

    This is a generator and does no work until it is iterated. Every yielded
    item is a 5-tuple: a markdown status string followed by the four
    ``gr.update`` objects built by ``ui_before`` / ``ui_after`` /
    ``ui_after_vid``. Results are written to disk and published through the
    module globals ``WORKSPACE``, ``OUTPUT_FILE`` and ``PREVIEW``.

    Args:
        input_type: "Image", "Video", "Directory" or "Stream" ("Stream" is a
            no-op).
        image_path: Target image file, read when ``input_type == "Image"``.
        video_path: Target video file, read when ``input_type == "Video"``.
        directory_path: Folder of target images for ``"Directory"``.
        source_path: Source face image; forwarded to ``get_analysed_data``
            unless ``condition == "Specific Face"``.
        output_path: Directory that receives the output.
        output_name: Base name of the output file or folder.
        keep_output_sequence: Keep the extracted video frames when true.
        condition: Swap condition forwarded to ``get_analysed_data``.
        age: Forwarded with ``source_path`` for non-specific conditions.
        distance: Forwarded with the source/specific pairs for
            "Specific Face".
        face_enhancer_name: "NONE" or a name accepted by
            ``load_face_enhancer_model``.
        enable_face_parser: Build parsing masks for the swapped faces.
        mask_includes: Regions converted with ``mask_regions_to_list``.
        mask_soft_kernel: Accepted for interface compatibility; not used here.
        mask_soft_iterations: Passed as ``softness`` to ``get_parsed_mask``.
        blur_amount, erode_amount: Forwarded to ``paste_to_whole``.
        face_scale: Forwarded as ``scale`` to ``get_analysed_data``.
        enable_laplacian_blend: Selects 'laplacian' instead of 'linear'
            blending.
        crop_top, crop_bott, crop_left, crop_right: Crop bounds in the 0..511
            range; swapped pairwise when given in the wrong order.
        *specifics: Source entries followed by the same number of specific
            target entries; split in half below.
    """
    global WORKSPACE
    global OUTPUT_FILE
    global PREVIEW
    WORKSPACE, OUTPUT_FILE, PREVIEW = None, None, None

    ## ------------------------------ GUI UPDATE FUNC ------------------------------
    def ui_before():
        # While working: show the preview, lock both buttons, hide the video.
        return (
            gr.update(visible=True, value=PREVIEW),
            gr.update(interactive=False),
            gr.update(interactive=False),
            gr.update(visible=False),
        )

    def ui_after():
        # Finished with an image result: show the preview, unlock the buttons.
        return (
            gr.update(visible=True, value=PREVIEW),
            gr.update(interactive=True),
            gr.update(interactive=True),
            gr.update(visible=False),
        )

    def ui_after_vid():
        # Finished with a video result: hide the preview, show the output file.
        return (
            gr.update(visible=False),
            gr.update(interactive=True),
            gr.update(interactive=True),
            gr.update(value=OUTPUT_FILE, visible=True),
        )

    start_time = time.time()

    def get_finsh_text(start_time):
        # The branches below call this helper, but the module never defined
        # it; it reports the elapsed time since `start_time`.
        minutes, seconds = divmod(time.time() - start_time, 60)
        return f"### \n ✔️ Completed in {int(minutes)} min {int(seconds)} sec."

    ## ------------------------------ PREPARE INPUTS & LOAD MODELS ------------------------------
    load_face_analyser_model()
    load_face_swapper_model()

    # Kept local to this call. Every name other than "NONE" is loaded,
    # including the cv2 interpolation names, because the enhancement loop
    # unpacks a (model, runner) pair for all of them.
    # NOTE(review): assumes load_face_enhancer_model accepts the
    # cv2_interpolations names as well, as FACE_ENHANCER_LIST offers them.
    face_enhancer = None
    if face_enhancer_name != "NONE":
        face_enhancer = load_face_enhancer_model(name=face_enhancer_name, device=device)

    includes = None
    if enable_face_parser:
        load_face_parser_model()
        includes = mask_regions_to_list(mask_includes)

    # First half of *specifics are the sources, second half the targets.
    specifics = list(specifics)
    half = len(specifics) // 2
    sources = specifics[:half]
    specifics = specifics[half:]

    if crop_top > crop_bott:
        crop_top, crop_bott = crop_bott, crop_top
    if crop_left > crop_right:
        crop_left, crop_right = crop_right, crop_left
    crop_mask = (crop_top, 511 - crop_bott, crop_left, 511 - crop_right)

    def swap_process(image_sequence):
        """Swap faces in every file of `image_sequence`, overwriting each file.

        Generator yielding progress tuples in the same shape as `process`.
        It has to be a generator because every caller iterates it.
        """
        global PREVIEW

        ## ------------------------------ CONTENT CHECK ------------------------------
        yield "### \n ⌛ Analysing face data...", *ui_before()
        if condition != "Specific Face":
            source_data = source_path, age
        else:
            source_data = ((sources, specifics), distance)
        analysed_targets, analysed_sources, whole_frame_list, num_faces_per_frame = get_analysed_data(
            FACE_ANALYSER,
            image_sequence,
            source_data,
            swap_condition=condition,
            detect_condition=DETECT_CONDITION,
            scale=face_scale,
        )

        ## ------------------------------ SWAP FUNC ------------------------------
        preds = []
        matrs = []
        count = 0
        for batch_pred, batch_matr in FACE_SWAPPER.batch_forward(whole_frame_list, analysed_targets, analysed_sources):
            preds.extend(batch_pred)
            matrs.extend(batch_matr)
            EMPTY_CACHE()
            count += 1
            if USE_CUDA:
                image_grid = create_image_grid(batch_pred, size=128)
                # NOTE(review): assumes create_image_grid returns a BGR image
                # array; the channel flip mirrors the Image branch below.
                PREVIEW = image_grid[:, :, ::-1]
                yield f"### \n ⌛ Generating face Batch {count}", *ui_before()

        ## ------------------------------ FACE ENHANCEMENT ------------------------------
        generated_len = len(preds)
        if face_enhancer_name != "NONE":
            yield f"### \n ⌛ Upscaling faces with {face_enhancer_name}...", *ui_before()
            enhancer_model, enhancer_model_runner = face_enhancer
            for idx, pred in tqdm(enumerate(preds), total=generated_len, desc=f"Upscaling with {face_enhancer_name}"):
                pred = enhancer_model_runner(pred, enhancer_model)
                preds[idx] = cv2.resize(pred, (512, 512))
        EMPTY_CACHE()

        ## ------------------------------ FACE PARSING ------------------------------
        if enable_face_parser:
            yield "### \n ⌛ Face-parsing mask...", *ui_before()
            masks = []
            count = 0
            for batch_mask in get_parsed_mask(FACE_PARSER, preds, classes=includes, device=device, batch_size=BATCH_SIZE, softness=int(mask_soft_iterations)):
                masks.append(batch_mask)
                EMPTY_CACHE()
                count += 1
                if len(batch_mask) > 1:
                    image_grid = create_image_grid(batch_mask, size=128)
                    PREVIEW = image_grid[:, :, ::-1]
                    yield f"### \n ⌛ Face parsing Batch {count}", *ui_before()
            masks = np.concatenate(masks, axis=0) if len(masks) >= 1 else masks
        else:
            masks = [None] * generated_len

        ## ------------------------------ SPLIT LIST ------------------------------
        # Regroup the flat per-face lists into one sub-list per frame.
        split_preds = split_list_by_lengths(preds, num_faces_per_frame)
        del preds
        split_matrs = split_list_by_lengths(matrs, num_faces_per_frame)
        del matrs
        split_masks = split_list_by_lengths(masks, num_faces_per_frame)
        del masks

        ## ------------------------------ PASTE-BACK ------------------------------
        yield "### \n ⌛ Pasting back...", *ui_before()

        def post_process(frame_idx, frame_img, split_preds, split_matrs, split_masks, enable_laplacian_blend, crop_mask, blur_amount, erode_amount):
            # Paste every swapped face of one frame back and overwrite the file.
            whole_img_path = frame_img
            whole_img = cv2.imread(whole_img_path)
            blend_method = 'laplacian' if enable_laplacian_blend else 'linear'
            for p, m, mask in zip(split_preds[frame_idx], split_matrs[frame_idx], split_masks[frame_idx]):
                p = cv2.resize(p, (512, 512))
                mask = cv2.resize(mask, (512, 512)) if mask is not None else None
                # Same value as the former in-place `m /= 0.25`, but without
                # mutating the matrix held in split_matrs.
                m = m / 0.25
                whole_img = paste_to_whole(p, whole_img, m, mask=mask, crop_mask=crop_mask, blend_method=blend_method, blur_amount=blur_amount, erode_amount=erode_amount)
            cv2.imwrite(whole_img_path, whole_img)

        def concurrent_post_process(image_sequence, *args):
            with concurrent.futures.ThreadPoolExecutor() as executor:
                futures = [
                    executor.submit(post_process, idx, frame_img, *args)
                    for idx, frame_img in enumerate(image_sequence)
                ]
                for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures), desc="Pasting back"):
                    # result() re-raises any exception from the worker thread.
                    future.result()

        concurrent_post_process(
            image_sequence,
            split_preds,
            split_matrs,
            split_masks,
            enable_laplacian_blend,
            crop_mask,
            blur_amount,
            erode_amount,
        )

    ## ------------------------------ IMAGE ------------------------------
    if input_type == "Image":
        # cv2.imwrite does not create missing directories.
        os.makedirs(output_path, exist_ok=True)
        target = cv2.imread(image_path)
        output_file = os.path.join(output_path, output_name + ".png")
        cv2.imwrite(output_file, target)

        for info_update in swap_process([output_file]):
            yield info_update

        OUTPUT_FILE = output_file
        WORKSPACE = output_path
        PREVIEW = cv2.imread(output_file)[:, :, ::-1]

        yield get_finsh_text(start_time), *ui_after()

    ## ------------------------------ VIDEO ------------------------------
    elif input_type == "Video":
        temp_path = os.path.join(output_path, output_name, "sequence")
        os.makedirs(temp_path, exist_ok=True)

        yield "### \n 💽 Extracting video frames...", *ui_before()
        image_sequence = []
        cap = cv2.VideoCapture(video_path)
        curr_idx = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            frame_path = os.path.join(temp_path, f"frame_{curr_idx}.jpg")
            cv2.imwrite(frame_path, frame)
            image_sequence.append(frame_path)
            curr_idx += 1
        cap.release()
        cv2.destroyAllWindows()

        for info_update in swap_process(image_sequence):
            yield info_update

        yield "### \n 🔗 Merging sequence...", *ui_before()
        output_video_path = os.path.join(output_path, output_name + ".mp4")
        merge_img_sequence_from_ref(video_path, image_sequence, output_video_path)

        if os.path.exists(temp_path) and not keep_output_sequence:
            yield "### \n 🚽 Removing temporary files...", *ui_before()
            shutil.rmtree(temp_path)

        WORKSPACE = output_path
        OUTPUT_FILE = output_video_path

        yield get_finsh_text(start_time), *ui_after_vid()

    ## ------------------------------ DIRECTORY ------------------------------
    elif input_type == "Directory":
        # Leading dots keep names that merely end in e.g. "jpg" from matching.
        extensions = (".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".ico", ".webp")
        temp_path = os.path.join(output_path, output_name)
        # NOTE(review): an existing folder at temp_path is deleted; make sure
        # it can never be the same folder as directory_path.
        if os.path.exists(temp_path):
            shutil.rmtree(temp_path)
        os.makedirs(temp_path)

        file_paths = []
        for file_path in glob.glob(os.path.join(directory_path, "*")):
            if not file_path.lower().endswith(extensions):
                continue
            img = cv2.imread(file_path)
            if img is None:
                # Unreadable file: skip it instead of failing in cv2.imwrite.
                continue
            new_file_path = os.path.join(temp_path, os.path.basename(file_path))
            cv2.imwrite(new_file_path, img)
            file_paths.append(new_file_path)

        if not file_paths:
            # Avoids an IndexError on file_paths[-1] below.
            WORKSPACE = temp_path
            yield "### \n ❌ No readable images found in the directory.", *ui_after()
            return

        for info_update in swap_process(file_paths):
            yield info_update

        WORKSPACE = temp_path
        OUTPUT_FILE = file_paths[-1]
        PREVIEW = cv2.imread(file_paths[-1])[:, :, ::-1]

        # Without a final update the buttons locked by ui_before() stay locked.
        yield get_finsh_text(start_time), *ui_after()

    ## ------------------------------ STREAM ------------------------------
    elif input_type == "Stream":
        pass


## ------------------------------ Gradio API ------------------------------
def _call_process_api(source_base64, target_base64):
    """Forward to process_api, which is defined further down in this module.

    Looking the name up at call time lets the interface be built here, at
    module level, before process_api exists.
    """
    return process_api(source_base64, target_base64)


# Module-level so the launch code at the bottom of the file can reach it.
iface = gr.Interface(
    fn=_call_process_api,
    inputs=[
        gr.Textbox(label="Source Image (base64)"),
        gr.Textbox(label="Target Image (base64)"),
    ],
    outputs=gr.Textbox(label="Result Image (base64)"),
    title="Face Swap API",
    description="Submit two base64 encoded images to swap faces.",
)
## ------------------------------ GRADIO FUNC ------------------------------
def analyse_settings_changed(detect_condition, detection_size, detection_threshold):
    """Rebuild the global face analyser with new detection settings.

    Args:
        detect_condition: Stored in the global ``DETECT_CONDITION``.
        detection_size: Converted to ``int`` and used for both sides of
            ``det_size``.
        detection_threshold: Converted to ``float`` and used as ``det_thresh``.
    """
    global FACE_ANALYSER, DETECT_CONDITION

    DETECT_CONDITION = detect_condition
    # The analyser is always replaced, even if one is already loaded.
    FACE_ANALYSER = insightface.app.FaceAnalysis(name="buffalo_l", providers=PROVIDER)
    side = int(detection_size)
    threshold = float(detection_threshold)
    FACE_ANALYSER.prepare(ctx_id=0, det_size=(side, side), det_thresh=threshold)
def decode_base64_image(base64_string):
    """Decode a base64 encoded image into a BGR array for OpenCV.

    Args:
        base64_string: Base64 text of an encoded image file. A leading
            ``data:<mime>;base64,`` prefix is accepted and ignored.

    Returns:
        The image as a 3-channel array in BGR channel order.

    Raises:
        binascii.Error: If the text is not valid base64.
        PIL.UnidentifiedImageError: If the bytes are not a readable image.
    """
    # Imported here because the module's import block does not import base64.
    import base64

    payload = base64_string.strip()
    if payload.startswith("data:") and "," in payload:
        # Data URL: keep only the part after the first comma.
        payload = payload.split(",", 1)[1]

    img_data = base64.b64decode(payload)
    img = Image.open(io.BytesIO(img_data))
    # COLOR_RGB2BGR needs exactly three channels, so normalise RGBA, palette
    # and grayscale images first.
    img = img.convert("RGB")
    return cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
def process_api(source_base64, target_base64):
    """Swap the face from one base64 image onto another and return base64.

    Both inputs are decoded, written to a private temporary directory and run
    through ``process`` in "Image" mode with the module's default mask
    settings, no enhancer and no face parser.

    Args:
        source_base64: Base64 encoded image providing the face.
        target_base64: Base64 encoded image receiving the face.

    Returns:
        The swapped image, JPEG encoded, as a base64 string.

    Raises:
        RuntimeError: If a temporary image cannot be written, or the result
            cannot be read back or encoded.

    Note:
        ``process`` stores its output paths in the ``WORKSPACE`` and
        ``OUTPUT_FILE`` globals; after this call they refer to the temporary
        directory, which no longer exists.
    """
    # Imported here because the module's import block has neither of them.
    import base64
    import tempfile

    source_image = decode_base64_image(source_base64)
    target_image = decode_base64_image(target_base64)

    # One directory per request: the interface is queued with concurrency, so
    # fixed file names in the working directory would collide. The context
    # manager also removes the files when the swap fails.
    with tempfile.TemporaryDirectory(prefix="face_swap_api_") as work_dir:
        temp_source_path = os.path.join(work_dir, "temp_source.jpg")
        temp_target_path = os.path.join(work_dir, "temp_target.jpg")
        if not cv2.imwrite(temp_source_path, source_image):
            raise RuntimeError("Could not write the source image to disk.")
        if not cv2.imwrite(temp_target_path, target_image):
            raise RuntimeError("Could not write the target image to disk.")

        # NOTE(review): confirm "First found face" is a condition that
        # get_analysed_data recognises (see swap_options_list).
        updates = process(
            input_type="Image",
            image_path=temp_target_path,
            video_path=None,
            directory_path=None,
            source_path=temp_source_path,
            output_path=work_dir,
            output_name="result",
            keep_output_sequence=False,
            condition="First found face",
            age=None,
            distance=None,
            face_enhancer_name="NONE",
            enable_face_parser=False,
            mask_includes=MASK_INCLUDE,
            mask_soft_kernel=MASK_SOFT_KERNEL,
            mask_soft_iterations=MASK_SOFT_ITERATIONS,
            blur_amount=MASK_BLUR_AMOUNT,
            erode_amount=MASK_ERODE_AMOUNT,
            face_scale=1.0,
            enable_laplacian_blend=True,
            # 0 / 511 are the ends of the crop range, so the crop mask built
            # inside process() is (0, 0, 0, 0), i.e. nothing is cropped.
            crop_top=0,
            crop_bott=511,
            crop_left=0,
            crop_right=511,
        )
        # process() is a generator: it does nothing until it is iterated. The
        # yielded UI updates are not needed here.
        for _ in updates:
            pass

        result_image = cv2.imread(os.path.join(work_dir, "result.png"))

    if result_image is None:
        raise RuntimeError("Face swap did not produce a readable result image.")

    encoded_ok, buffer = cv2.imencode('.jpg', result_image)
    if not encoded_ok:
        raise RuntimeError("Could not encode the result image.")
    return base64.b64encode(buffer).decode('utf-8')
def stop_running():
    """Stop the active streamer, if there is one, and report cancellation.

    Returns:
        The string "Cancelled", whether or not anything was stopped.
    """
    global STREAMER
    status = "Cancelled"
    if not hasattr(STREAMER, "stop"):
        # Nothing stoppable is registered; leave STREAMER as it is.
        return status
    STREAMER.stop()
    STREAMER = None
    return status
def slider_changed(show_frame, video_path, frame_index):
    """Return UI updates that show one frame of a video as an image.

    Args:
        show_frame: When false nothing is extracted.
        video_path: Path of the video, or None when no video is loaded.
        frame_index: Index of the frame to show; divided by the clip's fps to
            get the timestamp passed to ``get_frame``.

    Returns:
        ``(None, None)`` when there is nothing to show; otherwise an image
        update carrying the frame and a video update that hides the video.
    """
    if not show_frame or video_path is None:
        return None, None

    clip = VideoFileClip(video_path)
    try:
        frame = clip.get_frame(frame_index / clip.fps)
        frame_array = np.array(frame)
    finally:
        # Close the clip even when the frame cannot be read.
        clip.close()

    return gr.Image.update(value=frame_array, visible=True), gr.Video.update(
        visible=False
    )
def trim_and_reload(video_path, output_path, output_name, start_frame, stop_frame):
    """Trim a video to a frame range and return what ``trim_video`` produced.

    Args:
        video_path: Video to trim.
        output_path: Directory for the trimmed video.
        output_name: Name joined onto ``output_path`` to form the destination.
        start_frame: First frame passed to ``trim_video``.
        stop_frame: Last frame passed to ``trim_video``.

    Returns:
        The value returned by ``trim_video``, or None when trimming failed.
        Failures are printed rather than raised (best effort).
    """
    try:
        destination = os.path.join(output_path, output_name)
        # The result used to be assigned and dropped; return it so callers
        # can use the trimmed video.
        return trim_video(video_path, destination, start_frame, stop_frame)
    except Exception as e:
        print(e)
        return None
# Script entry point: start the Gradio interface.
if __name__ == "__main__":
    if USE_COLAB:
        print("Running in colab mode")

    # Queue requests first, then launch; a public share link is only created
    # in colab mode.
    queued_interface = iface.queue(concurrency_count=2, max_size=20)
    queued_interface.launch(share=USE_COLAB)