diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 0000000000000000000000000000000000000000..b88a39dcf36b90aae0763caaee5e3afe0cc4159f --- /dev/null +++ b/.editorconfig @@ -0,0 +1,8 @@ +root = true + +[*] +end_of_line = lf +insert_final_newline = true +indent_size = 4 +indent_style = tab +trim_trailing_whitespace = true diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000000000000000000000000000000000000..5419fcc6e770a1e6a48abfb9711919c2705c831b --- /dev/null +++ b/.flake8 @@ -0,0 +1,6 @@ +[flake8] +select = E22, E23, E24, E27, E3, E4, E7, F, I1, I2 +per-file-ignores = facefusion.py:E402, install.py:E402 +plugins = flake8-import-order +application_import_names = facefusion +import-order-style = pycharm diff --git a/LICENSE.md b/LICENSE.md new file mode 100644 index 0000000000000000000000000000000000000000..0ea4213bb3db395c9cc8bd586e0f268727ed3c0b --- /dev/null +++ b/LICENSE.md @@ -0,0 +1,3 @@ +OpenRAIL-AS license + +Copyright (c) 2025 Henry Ruhs diff --git a/ORIGINAL_README.md b/ORIGINAL_README.md new file mode 100644 index 0000000000000000000000000000000000000000..5eb32f4f2edbe7f4bfa5b0be0d8c5f5ad4b5da3e --- /dev/null +++ b/ORIGINAL_README.md @@ -0,0 +1,61 @@ +FaceFusion +========== + +> Industry leading face manipulation platform. + +[![Build Status](https://img.shields.io/github/actions/workflow/status/facefusion/facefusion/ci.yml.svg?branch=master)](https://github.com/facefusion/facefusion/actions?query=workflow:ci) +[![Coverage Status](https://img.shields.io/coveralls/facefusion/facefusion.svg)](https://coveralls.io/r/facefusion/facefusion) +![License](https://img.shields.io/badge/license-OpenRAIL--AS-green) + + +Preview +------- + +![Preview](https://raw.githubusercontent.com/facefusion/facefusion/master/.github/preview.png?sanitize=true) + + +Installation +------------ + +Be aware, the [installation](https://docs.facefusion.io/installation) needs technical skills and is not recommended for beginners. In case you are not comfortable using a terminal, our [Windows Installer](http://windows-installer.facefusion.io) and [macOS Installer](http://macos-installer.facefusion.io) get you started. + + +Usage +----- + +Run the command: + +``` +python facefusion.py [commands] [options] + +options: + -h, --help show this help message and exit + -v, --version show program's version number and exit + +commands: + run run the program + headless-run run the program in headless mode + batch-run run the program in batch mode + force-download force automate downloads and exit + benchmark benchmark the program + job-list list jobs by status + job-create create a drafted job + job-submit submit a drafted job to become a queued job + job-submit-all submit all drafted jobs to become a queued jobs + job-delete delete a drafted, queued, failed or completed job + job-delete-all delete all drafted, queued, failed and completed jobs + job-add-step add a step to a drafted job + job-remix-step remix a previous step from a drafted job + job-insert-step insert a step to a drafted job + job-remove-step remove a step from a drafted job + job-run run a queued job + job-run-all run all queued jobs + job-retry retry a failed job + job-retry-all retry all failed jobs +``` + + +Documentation +------------- + +Read the [documentation](https://docs.facefusion.io) for a deep dive. 
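A hedged usage sketch to complement the Usage section above: the command names (`run`, `headless-run`, the `job-*` family) are taken verbatim from the help text, but the long-option spellings (`--source-paths`, `--target-path`, `--output-path`, `--processors`) are assumed from the matching keys in `facefusion.ini`, and the file names are placeholders rather than values from this patch.

```
# launch the interactive UI
python facefusion.py run

# run a single swap without the UI (option names assumed from the config keys)
python facefusion.py headless-run --source-paths source.jpg --target-path target.mp4 --output-path output.mp4 --processors face_swapper
```

The same step arguments can also be queued through the `job-create`, `job-add-step`, `job-submit` and `job-run` commands listed above when batching work.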
diff --git a/facefusion.ico b/facefusion.ico new file mode 100644 index 0000000000000000000000000000000000000000..7dd4da75931ebe108efbb84b28d1585d04495533 Binary files /dev/null and b/facefusion.ico differ diff --git a/facefusion.ini b/facefusion.ini new file mode 100644 index 0000000000000000000000000000000000000000..37b772af31b5c318b9031a995be127d39ee1fe19 --- /dev/null +++ b/facefusion.ini @@ -0,0 +1,123 @@ +[paths] +temp_path = +jobs_path = +source_paths = +target_path = +output_path = + +[patterns] +source_pattern = +target_pattern = +output_pattern = + +[face_detector] +face_detector_model = +face_detector_size = +face_detector_angles = +face_detector_score = + +[face_landmarker] +face_landmarker_model = +face_landmarker_score = + +[face_selector] +face_selector_mode = +face_selector_order = +face_selector_age_start = +face_selector_age_end = +face_selector_gender = +face_selector_race = +reference_face_position = +reference_face_distance = +reference_frame_number = + +[face_masker] +face_occluder_model = +face_parser_model = +face_mask_types = +face_mask_areas = +face_mask_regions = +face_mask_blur = +face_mask_padding = + +[frame_extraction] +trim_frame_start = +trim_frame_end = +temp_frame_format = +keep_temp = + +[output_creation] +output_image_quality = +output_image_resolution = +output_audio_encoder = +output_audio_quality = +output_audio_volume = +output_video_encoder = +output_video_preset = +output_video_quality = +output_video_resolution = +output_video_fps = + +[processors] +processors = +age_modifier_model = +age_modifier_direction = +deep_swapper_model = +deep_swapper_morph = +expression_restorer_model = +expression_restorer_factor = +face_debugger_items = +face_editor_model = +face_editor_eyebrow_direction = +face_editor_eye_gaze_horizontal = +face_editor_eye_gaze_vertical = +face_editor_eye_open_ratio = +face_editor_lip_open_ratio = +face_editor_mouth_grim = +face_editor_mouth_pout = +face_editor_mouth_purse = +face_editor_mouth_smile = +face_editor_mouth_position_horizontal = +face_editor_mouth_position_vertical = +face_editor_head_pitch = +face_editor_head_yaw = +face_editor_head_roll = +face_enhancer_model = +face_enhancer_blend = +face_enhancer_weight = +face_swapper_model = +face_swapper_pixel_boost = +frame_colorizer_model = +frame_colorizer_size = +frame_colorizer_blend = +frame_enhancer_model = +frame_enhancer_blend = +lip_syncer_model = +lip_syncer_weight = + +[uis] +open_browser = +ui_layouts = +ui_workflow = + +[download] +download_providers = +download_scope = + +[benchmark] +benchmark_resolutions = +benchmark_cycle_count = + +[execution] +execution_device_id = +execution_providers = +execution_thread_count = +execution_queue_count = + +[memory] +video_memory_strategy = +system_memory_limit = + +[misc] +log_level = +halt_on_error = diff --git a/facefusion.py b/facefusion.py new file mode 100644 index 0000000000000000000000000000000000000000..98a865c718cf12234377c4c662743bea1b90c9c5 --- /dev/null +++ b/facefusion.py @@ -0,0 +1,10 @@ +#!/usr/bin/env python3 + +import os + +os.environ['OMP_NUM_THREADS'] = '1' + +from facefusion import core + +if __name__ == '__main__': + core.cli() diff --git a/facefusion/__init__.py b/facefusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/facefusion/app_context.py b/facefusion/app_context.py new file mode 100644 index 0000000000000000000000000000000000000000..d54f961ef51167edd8a4c02e6fa8a625cabb588f --- /dev/null +++ 
b/facefusion/app_context.py @@ -0,0 +1,16 @@ +import os +import sys + +from facefusion.types import AppContext + + +def detect_app_context() -> AppContext: + frame = sys._getframe(1) + + while frame: + if os.path.join('facefusion', 'jobs') in frame.f_code.co_filename: + return 'cli' + if os.path.join('facefusion', 'uis') in frame.f_code.co_filename: + return 'ui' + frame = frame.f_back + return 'cli' diff --git a/facefusion/args.py b/facefusion/args.py new file mode 100644 index 0000000000000000000000000000000000000000..ba5c511ae32cd5eb4c12c8df5d1d99e3b2fb7a7f --- /dev/null +++ b/facefusion/args.py @@ -0,0 +1,140 @@ +from facefusion import state_manager +from facefusion.filesystem import get_file_name, is_image, is_video, resolve_file_paths +from facefusion.jobs import job_store +from facefusion.normalizer import normalize_fps, normalize_padding +from facefusion.processors.core import get_processors_modules +from facefusion.types import ApplyStateItem, Args +from facefusion.vision import create_image_resolutions, create_video_resolutions, detect_image_resolution, detect_video_fps, detect_video_resolution, pack_resolution + + +def reduce_step_args(args : Args) -> Args: + step_args =\ + { + key: args[key] for key in args if key in job_store.get_step_keys() + } + return step_args + + +def reduce_job_args(args : Args) -> Args: + job_args =\ + { + key: args[key] for key in args if key in job_store.get_job_keys() + } + return job_args + + +def collect_step_args() -> Args: + step_args =\ + { + key: state_manager.get_item(key) for key in job_store.get_step_keys() #type:ignore[arg-type] + } + return step_args + + +def collect_job_args() -> Args: + job_args =\ + { + key: state_manager.get_item(key) for key in job_store.get_job_keys() #type:ignore[arg-type] + } + return job_args + + +def apply_args(args : Args, apply_state_item : ApplyStateItem) -> None: + # general + apply_state_item('command', args.get('command')) + # paths + apply_state_item('temp_path', args.get('temp_path')) + apply_state_item('jobs_path', args.get('jobs_path')) + apply_state_item('source_paths', args.get('source_paths')) + apply_state_item('target_path', args.get('target_path')) + apply_state_item('output_path', args.get('output_path')) + # patterns + apply_state_item('source_pattern', args.get('source_pattern')) + apply_state_item('target_pattern', args.get('target_pattern')) + apply_state_item('output_pattern', args.get('output_pattern')) + # face detector + apply_state_item('face_detector_model', args.get('face_detector_model')) + apply_state_item('face_detector_size', args.get('face_detector_size')) + apply_state_item('face_detector_angles', args.get('face_detector_angles')) + apply_state_item('face_detector_score', args.get('face_detector_score')) + # face landmarker + apply_state_item('face_landmarker_model', args.get('face_landmarker_model')) + apply_state_item('face_landmarker_score', args.get('face_landmarker_score')) + # face selector + apply_state_item('face_selector_mode', args.get('face_selector_mode')) + apply_state_item('face_selector_order', args.get('face_selector_order')) + apply_state_item('face_selector_age_start', args.get('face_selector_age_start')) + apply_state_item('face_selector_age_end', args.get('face_selector_age_end')) + apply_state_item('face_selector_gender', args.get('face_selector_gender')) + apply_state_item('face_selector_race', args.get('face_selector_race')) + apply_state_item('reference_face_position', args.get('reference_face_position')) + apply_state_item('reference_face_distance', 
args.get('reference_face_distance')) + apply_state_item('reference_frame_number', args.get('reference_frame_number')) + # face masker + apply_state_item('face_occluder_model', args.get('face_occluder_model')) + apply_state_item('face_parser_model', args.get('face_parser_model')) + apply_state_item('face_mask_types', args.get('face_mask_types')) + apply_state_item('face_mask_areas', args.get('face_mask_areas')) + apply_state_item('face_mask_regions', args.get('face_mask_regions')) + apply_state_item('face_mask_blur', args.get('face_mask_blur')) + apply_state_item('face_mask_padding', normalize_padding(args.get('face_mask_padding'))) + # frame extraction + apply_state_item('trim_frame_start', args.get('trim_frame_start')) + apply_state_item('trim_frame_end', args.get('trim_frame_end')) + apply_state_item('temp_frame_format', args.get('temp_frame_format')) + apply_state_item('keep_temp', args.get('keep_temp')) + # output creation + apply_state_item('output_image_quality', args.get('output_image_quality')) + if is_image(args.get('target_path')): + output_image_resolution = detect_image_resolution(args.get('target_path')) + output_image_resolutions = create_image_resolutions(output_image_resolution) + if args.get('output_image_resolution') in output_image_resolutions: + apply_state_item('output_image_resolution', args.get('output_image_resolution')) + else: + apply_state_item('output_image_resolution', pack_resolution(output_image_resolution)) + apply_state_item('output_audio_encoder', args.get('output_audio_encoder')) + apply_state_item('output_audio_quality', args.get('output_audio_quality')) + apply_state_item('output_audio_volume', args.get('output_audio_volume')) + apply_state_item('output_video_encoder', args.get('output_video_encoder')) + apply_state_item('output_video_preset', args.get('output_video_preset')) + apply_state_item('output_video_quality', args.get('output_video_quality')) + if is_video(args.get('target_path')): + output_video_resolution = detect_video_resolution(args.get('target_path')) + output_video_resolutions = create_video_resolutions(output_video_resolution) + if args.get('output_video_resolution') in output_video_resolutions: + apply_state_item('output_video_resolution', args.get('output_video_resolution')) + else: + apply_state_item('output_video_resolution', pack_resolution(output_video_resolution)) + if args.get('output_video_fps') or is_video(args.get('target_path')): + output_video_fps = normalize_fps(args.get('output_video_fps')) or detect_video_fps(args.get('target_path')) + apply_state_item('output_video_fps', output_video_fps) + # processors + available_processors = [ get_file_name(file_path) for file_path in resolve_file_paths('facefusion/processors/modules') ] + apply_state_item('processors', args.get('processors')) + for processor_module in get_processors_modules(available_processors): + processor_module.apply_args(args, apply_state_item) + # uis + apply_state_item('open_browser', args.get('open_browser')) + apply_state_item('ui_layouts', args.get('ui_layouts')) + apply_state_item('ui_workflow', args.get('ui_workflow')) + # execution + apply_state_item('execution_device_id', args.get('execution_device_id')) + apply_state_item('execution_providers', args.get('execution_providers')) + apply_state_item('execution_thread_count', args.get('execution_thread_count')) + apply_state_item('execution_queue_count', args.get('execution_queue_count')) + # download + apply_state_item('download_providers', args.get('download_providers')) + 
apply_state_item('download_scope', args.get('download_scope')) + # benchmark + apply_state_item('benchmark_resolutions', args.get('benchmark_resolutions')) + apply_state_item('benchmark_cycle_count', args.get('benchmark_cycle_count')) + # memory + apply_state_item('video_memory_strategy', args.get('video_memory_strategy')) + apply_state_item('system_memory_limit', args.get('system_memory_limit')) + # misc + apply_state_item('log_level', args.get('log_level')) + apply_state_item('halt_on_error', args.get('halt_on_error')) + # jobs + apply_state_item('job_id', args.get('job_id')) + apply_state_item('job_status', args.get('job_status')) + apply_state_item('step_index', args.get('step_index')) diff --git a/facefusion/audio.py b/facefusion/audio.py new file mode 100644 index 0000000000000000000000000000000000000000..43b72b1c12cb97d18a64528ab570d2f83e83d773 --- /dev/null +++ b/facefusion/audio.py @@ -0,0 +1,143 @@ +from functools import lru_cache +from typing import Any, List, Optional + +import numpy +import scipy +from numpy.typing import NDArray + +from facefusion.ffmpeg import read_audio_buffer +from facefusion.filesystem import is_audio +from facefusion.types import Audio, AudioFrame, Fps, Mel, MelFilterBank, Spectrogram +from facefusion.voice_extractor import batch_extract_voice + + +@lru_cache() +def read_static_audio(audio_path : str, fps : Fps) -> Optional[List[AudioFrame]]: + return read_audio(audio_path, fps) + + +def read_audio(audio_path : str, fps : Fps) -> Optional[List[AudioFrame]]: + audio_sample_rate = 48000 + audio_sample_size = 16 + audio_channel_total = 2 + + if is_audio(audio_path): + audio_buffer = read_audio_buffer(audio_path, audio_sample_rate, audio_sample_size, audio_channel_total) + audio = numpy.frombuffer(audio_buffer, dtype = numpy.int16).reshape(-1, 2) + audio = prepare_audio(audio) + spectrogram = create_spectrogram(audio) + audio_frames = extract_audio_frames(spectrogram, fps) + return audio_frames + return None + + +@lru_cache() +def read_static_voice(audio_path : str, fps : Fps) -> Optional[List[AudioFrame]]: + return read_voice(audio_path, fps) + + +def read_voice(audio_path : str, fps : Fps) -> Optional[List[AudioFrame]]: + voice_sample_rate = 48000 + voice_sample_size = 16 + voice_channel_total = 2 + voice_chunk_size = 240 * 1024 + voice_step_size = 180 * 1024 + + if is_audio(audio_path): + audio_buffer = read_audio_buffer(audio_path, voice_sample_rate, voice_sample_size, voice_channel_total) + audio = numpy.frombuffer(audio_buffer, dtype = numpy.int16).reshape(-1, 2) + audio = batch_extract_voice(audio, voice_chunk_size, voice_step_size) + audio = prepare_voice(audio) + spectrogram = create_spectrogram(audio) + audio_frames = extract_audio_frames(spectrogram, fps) + return audio_frames + return None + + +def get_audio_frame(audio_path : str, fps : Fps, frame_number : int = 0) -> Optional[AudioFrame]: + if is_audio(audio_path): + audio_frames = read_static_audio(audio_path, fps) + if frame_number in range(len(audio_frames)): + return audio_frames[frame_number] + return None + + +def extract_audio_frames(spectrogram : Spectrogram, fps : Fps) -> List[AudioFrame]: + audio_frames = [] + mel_filter_total = 80 + audio_step_size = 16 + indices = numpy.arange(0, spectrogram.shape[1], mel_filter_total / fps).astype(numpy.int16) + indices = indices[indices >= audio_step_size] + + for index in indices: + start = max(0, index - audio_step_size) + audio_frames.append(spectrogram[:, start:index]) + + return audio_frames + + +def get_voice_frame(audio_path : str, fps : 
Fps, frame_number : int = 0) -> Optional[AudioFrame]: + if is_audio(audio_path): + voice_frames = read_static_voice(audio_path, fps) + if frame_number in range(len(voice_frames)): + return voice_frames[frame_number] + return None + + +def create_empty_audio_frame() -> AudioFrame: + mel_filter_total = 80 + audio_step_size = 16 + audio_frame = numpy.zeros((mel_filter_total, audio_step_size)).astype(numpy.int16) + return audio_frame + + +def prepare_audio(audio : Audio) -> Audio: + if audio.ndim > 1: + audio = numpy.mean(audio, axis = 1) + audio = audio / numpy.max(numpy.abs(audio), axis = 0) + audio = scipy.signal.lfilter([ 1.0, -0.97 ], [ 1.0 ], audio) + return audio + + +def prepare_voice(audio : Audio) -> Audio: + audio_sample_rate = 48000 + audio_resample_rate = 16000 + audio_resample_factor = round(len(audio) * audio_resample_rate / audio_sample_rate) + audio = scipy.signal.resample(audio, audio_resample_factor) + audio = prepare_audio(audio) + return audio + + +def convert_hertz_to_mel(hertz : float) -> float: + return 2595 * numpy.log10(1 + hertz / 700) + + +def convert_mel_to_hertz(mel : Mel) -> NDArray[Any]: + return 700 * (10 ** (mel / 2595) - 1) + + +def create_mel_filter_bank() -> MelFilterBank: + audio_sample_rate = 16000 + audio_min_frequency = 55.0 + audio_max_frequency = 7600.0 + mel_filter_total = 80 + mel_bin_total = 800 + mel_filter_bank = numpy.zeros((mel_filter_total, mel_bin_total // 2 + 1)) + mel_frequency_range = numpy.linspace(convert_hertz_to_mel(audio_min_frequency), convert_hertz_to_mel(audio_max_frequency), mel_filter_total + 2) + indices = numpy.floor((mel_bin_total + 1) * convert_mel_to_hertz(mel_frequency_range) / audio_sample_rate).astype(numpy.int16) + + for index in range(mel_filter_total): + start = indices[index] + end = indices[index + 1] + mel_filter_bank[index, start:end] = scipy.signal.windows.triang(end - start) + + return mel_filter_bank + + +def create_spectrogram(audio : Audio) -> Spectrogram: + mel_bin_total = 800 + mel_bin_overlap = 600 + mel_filter_bank = create_mel_filter_bank() + spectrogram = scipy.signal.stft(audio, nperseg = mel_bin_total, nfft = mel_bin_total, noverlap = mel_bin_overlap)[2] + spectrogram = numpy.dot(mel_filter_bank, numpy.abs(spectrogram)) + return spectrogram diff --git a/facefusion/benchmarker.py b/facefusion/benchmarker.py new file mode 100644 index 0000000000000000000000000000000000000000..762fbb2c4d3f639ac6bf77d677423be889cf408e --- /dev/null +++ b/facefusion/benchmarker.py @@ -0,0 +1,106 @@ +import hashlib +import os +import statistics +import tempfile +from time import perf_counter +from typing import Generator, List + +import facefusion.choices +from facefusion import core, state_manager +from facefusion.cli_helper import render_table +from facefusion.download import conditional_download, resolve_download_url +from facefusion.filesystem import get_file_extension +from facefusion.types import BenchmarkCycleSet +from facefusion.vision import count_video_frame_total, detect_video_fps, detect_video_resolution, pack_resolution + + +def pre_check() -> bool: + conditional_download('.assets/examples', + [ + resolve_download_url('examples-3.0.0', 'source.jpg'), + resolve_download_url('examples-3.0.0', 'source.mp3'), + resolve_download_url('examples-3.0.0', 'target-240p.mp4'), + resolve_download_url('examples-3.0.0', 'target-360p.mp4'), + resolve_download_url('examples-3.0.0', 'target-540p.mp4'), + resolve_download_url('examples-3.0.0', 'target-720p.mp4'), + resolve_download_url('examples-3.0.0', 'target-1080p.mp4'), + 
resolve_download_url('examples-3.0.0', 'target-1440p.mp4'), + resolve_download_url('examples-3.0.0', 'target-2160p.mp4') + ]) + return True + + +def run() -> Generator[List[BenchmarkCycleSet], None, None]: + benchmark_resolutions = state_manager.get_item('benchmark_resolutions') + benchmark_cycle_count = state_manager.get_item('benchmark_cycle_count') + + state_manager.init_item('source_paths', [ '.assets/examples/source.jpg', '.assets/examples/source.mp3' ]) + state_manager.init_item('face_landmarker_score', 0) + state_manager.init_item('temp_frame_format', 'bmp') + state_manager.init_item('output_audio_volume', 0) + state_manager.init_item('output_video_preset', 'ultrafast') + state_manager.init_item('video_memory_strategy', 'tolerant') + + benchmarks = [] + target_paths = [facefusion.choices.benchmark_set.get(benchmark_resolution) for benchmark_resolution in benchmark_resolutions if benchmark_resolution in facefusion.choices.benchmark_set] + + for target_path in target_paths: + state_manager.set_item('target_path', target_path) + state_manager.set_item('output_path', suggest_output_path(state_manager.get_item('target_path'))) + benchmarks.append(cycle(benchmark_cycle_count)) + yield benchmarks + + +def cycle(cycle_count : int) -> BenchmarkCycleSet: + process_times = [] + video_frame_total = count_video_frame_total(state_manager.get_item('target_path')) + output_video_resolution = detect_video_resolution(state_manager.get_item('target_path')) + state_manager.set_item('output_video_resolution', pack_resolution(output_video_resolution)) + state_manager.set_item('output_video_fps', detect_video_fps(state_manager.get_item('target_path'))) + + core.conditional_process() + + for index in range(cycle_count): + start_time = perf_counter() + core.conditional_process() + end_time = perf_counter() + process_times.append(end_time - start_time) + + average_run = round(statistics.mean(process_times), 2) + fastest_run = round(min(process_times), 2) + slowest_run = round(max(process_times), 2) + relative_fps = round(video_frame_total * cycle_count / sum(process_times), 2) + + return\ + { + 'target_path': state_manager.get_item('target_path'), + 'cycle_count': cycle_count, + 'average_run': average_run, + 'fastest_run': fastest_run, + 'slowest_run': slowest_run, + 'relative_fps': relative_fps + } + + +def suggest_output_path(target_path : str) -> str: + target_file_extension = get_file_extension(target_path) + return os.path.join(tempfile.gettempdir(), hashlib.sha1().hexdigest()[:8] + target_file_extension) + + +def render() -> None: + benchmarks = [] + headers =\ + [ + 'target_path', + 'cycle_count', + 'average_run', + 'fastest_run', + 'slowest_run', + 'relative_fps' + ] + + for benchmark in run(): + benchmarks = benchmark + + contents = [ list(benchmark_set.values()) for benchmark_set in benchmarks ] + render_table(headers, contents) diff --git a/facefusion/choices.py b/facefusion/choices.py new file mode 100644 index 0000000000000000000000000000000000000000..c51463ed560efdfbd4d941c78a855a081e0d8b7b --- /dev/null +++ b/facefusion/choices.py @@ -0,0 +1,165 @@ +import logging +from typing import List, Sequence + +from facefusion.common_helper import create_float_range, create_int_range +from facefusion.types import Angle, AudioEncoder, AudioFormat, AudioTypeSet, BenchmarkResolution, BenchmarkSet, DownloadProvider, DownloadProviderSet, DownloadScope, EncoderSet, ExecutionProvider, ExecutionProviderSet, FaceDetectorModel, FaceDetectorSet, FaceLandmarkerModel, FaceMaskArea, FaceMaskAreaSet, FaceMaskRegion, 
FaceMaskRegionSet, FaceMaskType, FaceOccluderModel, FaceParserModel, FaceSelectorMode, FaceSelectorOrder, Gender, ImageFormat, ImageTypeSet, JobStatus, LogLevel, LogLevelSet, Race, Score, TempFrameFormat, UiWorkflow, VideoEncoder, VideoFormat, VideoMemoryStrategy, VideoPreset, VideoTypeSet, WebcamMode + +face_detector_set : FaceDetectorSet =\ +{ + 'many': [ '640x640' ], + 'retinaface': [ '160x160', '320x320', '480x480', '512x512', '640x640' ], + 'scrfd': [ '160x160', '320x320', '480x480', '512x512', '640x640' ], + 'yolo_face': [ '640x640' ] +} +face_detector_models : List[FaceDetectorModel] = list(face_detector_set.keys()) +face_landmarker_models : List[FaceLandmarkerModel] = [ 'many', '2dfan4', 'peppa_wutz' ] +face_selector_modes : List[FaceSelectorMode] = [ 'many', 'one', 'reference' ] +face_selector_orders : List[FaceSelectorOrder] = [ 'left-right', 'right-left', 'top-bottom', 'bottom-top', 'small-large', 'large-small', 'best-worst', 'worst-best' ] +face_selector_genders : List[Gender] = [ 'female', 'male' ] +face_selector_races : List[Race] = [ 'white', 'black', 'latino', 'asian', 'indian', 'arabic' ] +face_occluder_models : List[FaceOccluderModel] = [ 'xseg_1', 'xseg_2', 'xseg_3' ] +face_parser_models : List[FaceParserModel] = [ 'bisenet_resnet_18', 'bisenet_resnet_34' ] +face_mask_types : List[FaceMaskType] = [ 'box', 'occlusion', 'area', 'region' ] +face_mask_area_set : FaceMaskAreaSet =\ +{ + 'upper-face': [ 0, 1, 2, 31, 32, 33, 34, 35, 14, 15, 16, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17 ], + 'lower-face': [ 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 35, 34, 33, 32, 31 ], + 'mouth': [ 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67 ] +} +face_mask_region_set : FaceMaskRegionSet =\ +{ + 'skin': 1, + 'left-eyebrow': 2, + 'right-eyebrow': 3, + 'left-eye': 4, + 'right-eye': 5, + 'glasses': 6, + 'nose': 10, + 'mouth': 11, + 'upper-lip': 12, + 'lower-lip': 13 +} +face_mask_areas : List[FaceMaskArea] = list(face_mask_area_set.keys()) +face_mask_regions : List[FaceMaskRegion] = list(face_mask_region_set.keys()) + +audio_type_set : AudioTypeSet =\ +{ + 'flac': 'audio/flac', + 'm4a': 'audio/mp4', + 'mp3': 'audio/mpeg', + 'ogg': 'audio/ogg', + 'opus': 'audio/opus', + 'wav': 'audio/x-wav' +} +image_type_set : ImageTypeSet =\ +{ + 'bmp': 'image/bmp', + 'jpeg': 'image/jpeg', + 'png': 'image/png', + 'tiff': 'image/tiff', + 'webp': 'image/webp' +} +video_type_set : VideoTypeSet =\ +{ + 'avi': 'video/x-msvideo', + 'm4v': 'video/mp4', + 'mkv': 'video/x-matroska', + 'mp4': 'video/mp4', + 'mov': 'video/quicktime', + 'webm': 'video/webm' +} +audio_formats : List[AudioFormat] = list(audio_type_set.keys()) +image_formats : List[ImageFormat] = list(image_type_set.keys()) +video_formats : List[VideoFormat] = list(video_type_set.keys()) +temp_frame_formats : List[TempFrameFormat] = [ 'bmp', 'jpeg', 'png', 'tiff' ] + +output_encoder_set : EncoderSet =\ +{ + 'audio': [ 'flac', 'aac', 'libmp3lame', 'libopus', 'libvorbis', 'pcm_s16le', 'pcm_s32le' ], + 'video': [ 'libx264', 'libx265', 'libvpx-vp9', 'h264_nvenc', 'hevc_nvenc', 'h264_amf', 'hevc_amf', 'h264_qsv', 'hevc_qsv', 'h264_videotoolbox', 'hevc_videotoolbox', 'rawvideo' ] +} +output_audio_encoders : List[AudioEncoder] = output_encoder_set.get('audio') +output_video_encoders : List[VideoEncoder] = output_encoder_set.get('video') +output_video_presets : List[VideoPreset] = [ 'ultrafast', 'superfast', 'veryfast', 'faster', 'fast', 'medium', 'slow', 'slower', 'veryslow' ] + +image_template_sizes : List[float] = [ 0.25, 0.5, 0.75, 1, 
1.5, 2, 2.5, 3, 3.5, 4 ] +video_template_sizes : List[int] = [ 240, 360, 480, 540, 720, 1080, 1440, 2160, 4320 ] + +benchmark_set : BenchmarkSet =\ +{ + '240p': '.assets/examples/target-240p.mp4', + '360p': '.assets/examples/target-360p.mp4', + '540p': '.assets/examples/target-540p.mp4', + '720p': '.assets/examples/target-720p.mp4', + '1080p': '.assets/examples/target-1080p.mp4', + '1440p': '.assets/examples/target-1440p.mp4', + '2160p': '.assets/examples/target-2160p.mp4' +} +benchmark_resolutions : List[BenchmarkResolution] = list(benchmark_set.keys()) + +webcam_modes : List[WebcamMode] = [ 'inline', 'udp', 'v4l2' ] +webcam_resolutions : List[str] = [ '320x240', '640x480', '800x600', '1024x768', '1280x720', '1280x960', '1920x1080', '2560x1440', '3840x2160' ] + +execution_provider_set : ExecutionProviderSet =\ +{ + 'cuda': 'CUDAExecutionProvider', + 'tensorrt': 'TensorrtExecutionProvider', + 'directml': 'DmlExecutionProvider', + 'rocm': 'ROCMExecutionProvider', + 'openvino': 'OpenVINOExecutionProvider', + 'coreml': 'CoreMLExecutionProvider', + 'cpu': 'CPUExecutionProvider' +} +execution_providers : List[ExecutionProvider] = list(execution_provider_set.keys()) +download_provider_set : DownloadProviderSet =\ +{ + 'github': + { + 'urls': + [ + 'https://github.com' + ], + 'path': '/facefusion/facefusion-assets/releases/download/{base_name}/{file_name}' + }, + 'huggingface': + { + 'urls': + [ + 'https://huggingface.co', + 'https://hf-mirror.com' + ], + 'path': '/facefusion/{base_name}/resolve/main/{file_name}' + } +} +download_providers : List[DownloadProvider] = list(download_provider_set.keys()) +download_scopes : List[DownloadScope] = [ 'lite', 'full' ] + +video_memory_strategies : List[VideoMemoryStrategy] = [ 'strict', 'moderate', 'tolerant' ] + +log_level_set : LogLevelSet =\ +{ + 'error': logging.ERROR, + 'warn': logging.WARNING, + 'info': logging.INFO, + 'debug': logging.DEBUG +} +log_levels : List[LogLevel] = list(log_level_set.keys()) + +ui_workflows : List[UiWorkflow] = [ 'instant_runner', 'job_runner', 'job_manager' ] +job_statuses : List[JobStatus] = [ 'drafted', 'queued', 'completed', 'failed' ] + +benchmark_cycle_count_range : Sequence[int] = create_int_range(1, 10, 1) +execution_thread_count_range : Sequence[int] = create_int_range(1, 32, 1) +execution_queue_count_range : Sequence[int] = create_int_range(1, 4, 1) +system_memory_limit_range : Sequence[int] = create_int_range(0, 128, 4) +face_detector_angles : Sequence[Angle] = create_int_range(0, 270, 90) +face_detector_score_range : Sequence[Score] = create_float_range(0.0, 1.0, 0.05) +face_landmarker_score_range : Sequence[Score] = create_float_range(0.0, 1.0, 0.05) +face_mask_blur_range : Sequence[float] = create_float_range(0.0, 1.0, 0.05) +face_mask_padding_range : Sequence[int] = create_int_range(0, 100, 1) +face_selector_age_range : Sequence[int] = create_int_range(0, 100, 1) +reference_face_distance_range : Sequence[float] = create_float_range(0.0, 1.0, 0.05) +output_image_quality_range : Sequence[int] = create_int_range(0, 100, 1) +output_audio_quality_range : Sequence[int] = create_int_range(0, 100, 1) +output_audio_volume_range : Sequence[int] = create_int_range(0, 100, 1) +output_video_quality_range : Sequence[int] = create_int_range(0, 100, 1) diff --git a/facefusion/cli_helper.py b/facefusion/cli_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..369cc0a6903dc0671494c009ddae2bc4d8c41d0f --- /dev/null +++ b/facefusion/cli_helper.py @@ -0,0 +1,35 @@ +from typing import Tuple + +from 
facefusion.logger import get_package_logger +from facefusion.types import TableContents, TableHeaders + + +def render_table(headers : TableHeaders, contents : TableContents) -> None: + package_logger = get_package_logger() + table_column, table_separator = create_table_parts(headers, contents) + + package_logger.critical(table_separator) + package_logger.critical(table_column.format(*headers)) + package_logger.critical(table_separator) + + for content in contents: + content = [ str(value) for value in content ] + package_logger.critical(table_column.format(*content)) + + package_logger.critical(table_separator) + + +def create_table_parts(headers : TableHeaders, contents : TableContents) -> Tuple[str, str]: + column_parts = [] + separator_parts = [] + widths = [ len(header) for header in headers ] + + for content in contents: + for index, value in enumerate(content): + widths[index] = max(widths[index], len(str(value))) + + for width in widths: + column_parts.append('{:<' + str(width) + '}') + separator_parts.append('-' * width) + + return '| ' + ' | '.join(column_parts) + ' |', '+-' + '-+-'.join(separator_parts) + '-+' diff --git a/facefusion/common_helper.py b/facefusion/common_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..b38ceb7090fd54553678e8f86d7d2d1062152cbb --- /dev/null +++ b/facefusion/common_helper.py @@ -0,0 +1,84 @@ +import platform +from typing import Any, Iterable, Optional, Reversible, Sequence + + +def is_linux() -> bool: + return platform.system().lower() == 'linux' + + +def is_macos() -> bool: + return platform.system().lower() == 'darwin' + + +def is_windows() -> bool: + return platform.system().lower() == 'windows' + + +def create_int_metavar(int_range : Sequence[int]) -> str: + return '[' + str(int_range[0]) + '..' + str(int_range[-1]) + ':' + str(calc_int_step(int_range)) + ']' + + +def create_float_metavar(float_range : Sequence[float]) -> str: + return '[' + str(float_range[0]) + '..' 
+ str(float_range[-1]) + ':' + str(calc_float_step(float_range)) + ']' + + +def create_int_range(start : int, end : int, step : int) -> Sequence[int]: + int_range = [] + current = start + + while current <= end: + int_range.append(current) + current += step + return int_range + + +def create_float_range(start : float, end : float, step : float) -> Sequence[float]: + float_range = [] + current = start + + while current <= end: + float_range.append(round(current, 2)) + current = round(current + step, 2) + return float_range + + +def calc_int_step(int_range : Sequence[int]) -> int: + return int_range[1] - int_range[0] + + +def calc_float_step(float_range : Sequence[float]) -> float: + return round(float_range[1] - float_range[0], 2) + + +def cast_int(value : Any) -> Optional[int]: + try: + return int(value) + except (ValueError, TypeError): + return None + + +def cast_float(value : Any) -> Optional[float]: + try: + return float(value) + except (ValueError, TypeError): + return None + + +def cast_bool(value : Any) -> Optional[bool]: + if value == 'True': + return True + if value == 'False': + return False + return None + + +def get_first(__list__ : Any) -> Any: + if isinstance(__list__, Iterable): + return next(iter(__list__), None) + return None + + +def get_last(__list__ : Any) -> Any: + if isinstance(__list__, Reversible): + return next(reversed(__list__), None) + return None diff --git a/facefusion/config.py b/facefusion/config.py new file mode 100644 index 0000000000000000000000000000000000000000..e8e307ce1964e97f17d134568fa96a4c9a6554eb --- /dev/null +++ b/facefusion/config.py @@ -0,0 +1,74 @@ +from configparser import ConfigParser +from typing import List, Optional + +from facefusion import state_manager +from facefusion.common_helper import cast_bool, cast_float, cast_int + +CONFIG_PARSER = None + + +def get_config_parser() -> ConfigParser: + global CONFIG_PARSER + + if CONFIG_PARSER is None: + CONFIG_PARSER = ConfigParser() + CONFIG_PARSER.read(state_manager.get_item('config_path'), encoding = 'utf-8') + return CONFIG_PARSER + + +def clear_config_parser() -> None: + global CONFIG_PARSER + + CONFIG_PARSER = None + + +def get_str_value(section : str, option : str, fallback : Optional[str] = None) -> Optional[str]: + config_parser = get_config_parser() + + if config_parser.has_option(section, option) and config_parser.get(section, option).strip(): + return config_parser.get(section, option) + return fallback + + +def get_int_value(section : str, option : str, fallback : Optional[str] = None) -> Optional[int]: + config_parser = get_config_parser() + + if config_parser.has_option(section, option) and config_parser.get(section, option).strip(): + return config_parser.getint(section, option) + return cast_int(fallback) + + +def get_float_value(section : str, option : str, fallback : Optional[str] = None) -> Optional[float]: + config_parser = get_config_parser() + + if config_parser.has_option(section, option) and config_parser.get(section, option).strip(): + return config_parser.getfloat(section, option) + return cast_float(fallback) + + +def get_bool_value(section : str, option : str, fallback : Optional[str] = None) -> Optional[bool]: + config_parser = get_config_parser() + + if config_parser.has_option(section, option) and config_parser.get(section, option).strip(): + return config_parser.getboolean(section, option) + return cast_bool(fallback) + + +def get_str_list(section : str, option : str, fallback : Optional[str] = None) -> Optional[List[str]]: + config_parser = get_config_parser() 
+ + if config_parser.has_option(section, option) and config_parser.get(section, option).strip(): + return config_parser.get(section, option).split() + if fallback: + return fallback.split() + return None + + +def get_int_list(section : str, option : str, fallback : Optional[str] = None) -> Optional[List[int]]: + config_parser = get_config_parser() + + if config_parser.has_option(section, option) and config_parser.get(section, option).strip(): + return list(map(int, config_parser.get(section, option).split())) + if fallback: + return list(map(int, fallback.split())) + return None diff --git a/facefusion/content_analyser.py b/facefusion/content_analyser.py new file mode 100644 index 0000000000000000000000000000000000000000..8c4cfa1e9d1d1dab1b207401aa68ce363f7beb66 --- /dev/null +++ b/facefusion/content_analyser.py @@ -0,0 +1,225 @@ +from functools import lru_cache +from typing import List, Tuple + +import numpy +from tqdm import tqdm + +from facefusion import inference_manager, state_manager, wording +from facefusion.download import conditional_download_hashes, conditional_download_sources, resolve_download_url +from facefusion.execution import has_execution_provider +from facefusion.filesystem import resolve_relative_path +from facefusion.thread_helper import conditional_thread_semaphore +from facefusion.types import Detection, DownloadScope, DownloadSet, ExecutionProvider, Fps, InferencePool, ModelSet, VisionFrame +from facefusion.vision import detect_video_fps, fit_frame, read_image, read_video_frame + +STREAM_COUNTER = 0 + + +@lru_cache(maxsize = None) +def create_static_model_set(download_scope : DownloadScope) -> ModelSet: + return\ + { + 'nsfw_1': + { + 'hashes': + { + 'content_analyser': + { + 'url': resolve_download_url('models-3.3.0', 'nsfw_1.hash'), + 'path': resolve_relative_path('../.assets/models/nsfw_1.hash') + } + }, + 'sources': + { + 'content_analyser': + { + 'url': resolve_download_url('models-3.3.0', 'nsfw_1.onnx'), + 'path': resolve_relative_path('../.assets/models/nsfw_1.onnx') + } + }, + 'size': (640, 640), + 'mean': (0.0, 0.0, 0.0), + 'standard_deviation': (1.0, 1.0, 1.0) + }, + 'nsfw_2': + { + 'hashes': + { + 'content_analyser': + { + 'url': resolve_download_url('models-3.3.0', 'nsfw_2.hash'), + 'path': resolve_relative_path('../.assets/models/nsfw_2.hash') + } + }, + 'sources': + { + 'content_analyser': + { + 'url': resolve_download_url('models-3.3.0', 'nsfw_2.onnx'), + 'path': resolve_relative_path('../.assets/models/nsfw_2.onnx') + } + }, + 'size': (384, 384), + 'mean': (0.5, 0.5, 0.5), + 'standard_deviation': (0.5, 0.5, 0.5) + }, + 'nsfw_3': + { + 'hashes': + { + 'content_analyser': + { + 'url': resolve_download_url('models-3.3.0', 'nsfw_3.hash'), + 'path': resolve_relative_path('../.assets/models/nsfw_3.hash') + } + }, + 'sources': + { + 'content_analyser': + { + 'url': resolve_download_url('models-3.3.0', 'nsfw_3.onnx'), + 'path': resolve_relative_path('../.assets/models/nsfw_3.onnx') + } + }, + 'size': (448, 448), + 'mean': (0.48145466, 0.4578275, 0.40821073), + 'standard_deviation': (0.26862954, 0.26130258, 0.27577711) + } + } + + +def get_inference_pool() -> InferencePool: + model_names = [ 'nsfw_1', 'nsfw_2', 'nsfw_3' ] + _, model_source_set = collect_model_downloads() + + return inference_manager.get_inference_pool(__name__, model_names, model_source_set) + + +def clear_inference_pool() -> None: + model_names = [ 'nsfw_1', 'nsfw_2', 'nsfw_3' ] + inference_manager.clear_inference_pool(__name__, model_names) + + +def resolve_execution_providers() -> 
List[ExecutionProvider]: + if has_execution_provider('coreml'): + return [ 'cpu' ] + return state_manager.get_item('execution_providers') + + +def collect_model_downloads() -> Tuple[DownloadSet, DownloadSet]: + model_set = create_static_model_set('full') + model_hash_set = {} + model_source_set = {} + + for content_analyser_model in [ 'nsfw_1', 'nsfw_2', 'nsfw_3' ]: + model_hash_set[content_analyser_model] = model_set.get(content_analyser_model).get('hashes').get('content_analyser') + model_source_set[content_analyser_model] = model_set.get(content_analyser_model).get('sources').get('content_analyser') + + return model_hash_set, model_source_set + + +def pre_check() -> bool: + model_hash_set, model_source_set = collect_model_downloads() + + return conditional_download_hashes(model_hash_set) and conditional_download_sources(model_source_set) + + +def analyse_stream(vision_frame : VisionFrame, video_fps : Fps) -> bool: + global STREAM_COUNTER + + STREAM_COUNTER = STREAM_COUNTER + 1 + if STREAM_COUNTER % int(video_fps) == 0: + return analyse_frame(vision_frame) + return False + + +def analyse_frame(vision_frame : VisionFrame) -> bool: + return detect_nsfw(vision_frame) + + +@lru_cache(maxsize = None) +def analyse_image(image_path : str) -> bool: + vision_frame = read_image(image_path) + return analyse_frame(vision_frame) + + +@lru_cache(maxsize = None) +def analyse_video(video_path : str, trim_frame_start : int, trim_frame_end : int) -> bool: + video_fps = detect_video_fps(video_path) + frame_range = range(trim_frame_start, trim_frame_end) + rate = 0.0 + total = 0 + counter = 0 + + with tqdm(total = len(frame_range), desc = wording.get('analysing'), unit = 'frame', ascii = ' =', disable = state_manager.get_item('log_level') in [ 'warn', 'error' ]) as progress: + + for frame_number in frame_range: + if frame_number % int(video_fps) == 0: + vision_frame = read_video_frame(video_path, frame_number) + total += 1 + if analyse_frame(vision_frame): + counter += 1 + if counter > 0 and total > 0: + rate = counter / total * 100 + progress.set_postfix(rate = rate) + progress.update() + + return bool(rate > 10.0) + + +def detect_nsfw(vision_frame : VisionFrame) -> bool: + is_nsfw_1 = detect_with_nsfw_1(vision_frame) + is_nsfw_2 = detect_with_nsfw_2(vision_frame) + is_nsfw_3 = detect_with_nsfw_3(vision_frame) + + return is_nsfw_1 and is_nsfw_2 or is_nsfw_1 and is_nsfw_3 or is_nsfw_2 and is_nsfw_3 + + +def detect_with_nsfw_1(vision_frame : VisionFrame) -> bool: + detect_vision_frame = prepare_detect_frame(vision_frame, 'nsfw_1') + detection = forward_nsfw(detect_vision_frame, 'nsfw_1') + detection_score = numpy.max(numpy.amax(detection[:, 4:], axis = 1)) + return bool(detection_score > 0.2) + + +def detect_with_nsfw_2(vision_frame : VisionFrame) -> bool: + detect_vision_frame = prepare_detect_frame(vision_frame, 'nsfw_2') + detection = forward_nsfw(detect_vision_frame, 'nsfw_2') + detection_score = detection[0] - detection[1] + return bool(detection_score > 0.25) + + +def detect_with_nsfw_3(vision_frame : VisionFrame) -> bool: + detect_vision_frame = prepare_detect_frame(vision_frame, 'nsfw_3') + detection = forward_nsfw(detect_vision_frame, 'nsfw_3') + detection_score = (detection[2] + detection[3]) - (detection[0] + detection[1]) + return bool(detection_score > 10.5) + + +def forward_nsfw(vision_frame : VisionFrame, nsfw_model : str) -> Detection: + content_analyser = get_inference_pool().get(nsfw_model) + + with conditional_thread_semaphore(): + detection = content_analyser.run(None, + { + 'input': 
vision_frame + })[0] + + if nsfw_model in [ 'nsfw_2', 'nsfw_3' ]: + return detection[0] + + return detection + + +def prepare_detect_frame(temp_vision_frame : VisionFrame, model_name : str) -> VisionFrame: + model_set = create_static_model_set('full').get(model_name) + model_size = model_set.get('size') + model_mean = model_set.get('mean') + model_standard_deviation = model_set.get('standard_deviation') + + detect_vision_frame = fit_frame(temp_vision_frame, model_size) + detect_vision_frame = detect_vision_frame[:, :, ::-1] / 255.0 + detect_vision_frame -= model_mean + detect_vision_frame /= model_standard_deviation + detect_vision_frame = numpy.expand_dims(detect_vision_frame.transpose(2, 0, 1), axis = 0).astype(numpy.float32) + return detect_vision_frame diff --git a/facefusion/core.py b/facefusion/core.py new file mode 100644 index 0000000000000000000000000000000000000000..abc28ef7f0eebddf09e5a93740f63cd228202bdf --- /dev/null +++ b/facefusion/core.py @@ -0,0 +1,517 @@ +import inspect +import itertools +import shutil +import signal +import sys +from time import time + +import numpy + +from facefusion import benchmarker, cli_helper, content_analyser, face_classifier, face_detector, face_landmarker, face_masker, face_recognizer, hash_helper, logger, process_manager, state_manager, video_manager, voice_extractor, wording +from facefusion.args import apply_args, collect_job_args, reduce_job_args, reduce_step_args +from facefusion.common_helper import get_first +from facefusion.content_analyser import analyse_image, analyse_video +from facefusion.download import conditional_download_hashes, conditional_download_sources +from facefusion.exit_helper import hard_exit, signal_exit +from facefusion.face_analyser import get_average_face, get_many_faces, get_one_face +from facefusion.face_selector import sort_and_filter_faces +from facefusion.face_store import append_reference_face, clear_reference_faces, get_reference_faces +from facefusion.ffmpeg import copy_image, extract_frames, finalize_image, merge_video, replace_audio, restore_audio +from facefusion.filesystem import filter_audio_paths, get_file_name, is_image, is_video, resolve_file_paths, resolve_file_pattern +from facefusion.jobs import job_helper, job_manager, job_runner +from facefusion.jobs.job_list import compose_job_list +from facefusion.memory import limit_system_memory +from facefusion.processors.core import get_processors_modules +from facefusion.program import create_program +from facefusion.program_helper import validate_args +from facefusion.temp_helper import clear_temp_directory, create_temp_directory, get_temp_file_path, move_temp_file, resolve_temp_frame_paths +from facefusion.types import Args, ErrorCode +from facefusion.vision import pack_resolution, read_image, read_static_images, read_video_frame, restrict_image_resolution, restrict_trim_frame, restrict_video_fps, restrict_video_resolution, unpack_resolution + + +def cli() -> None: + if pre_check(): + signal.signal(signal.SIGINT, signal_exit) + program = create_program() + + if validate_args(program): + args = vars(program.parse_args()) + apply_args(args, state_manager.init_item) + + if state_manager.get_item('command'): + logger.init(state_manager.get_item('log_level')) + route(args) + else: + program.print_help() + else: + hard_exit(2) + else: + hard_exit(2) + + +def route(args : Args) -> None: + system_memory_limit = state_manager.get_item('system_memory_limit') + + if system_memory_limit and system_memory_limit > 0: + limit_system_memory(system_memory_limit) + + if 
state_manager.get_item('command') == 'force-download': + error_code = force_download() + return hard_exit(error_code) + + if state_manager.get_item('command') == 'benchmark': + if not common_pre_check() or not processors_pre_check() or not benchmarker.pre_check(): + return hard_exit(2) + benchmarker.render() + + if state_manager.get_item('command') in [ 'job-list', 'job-create', 'job-submit', 'job-submit-all', 'job-delete', 'job-delete-all', 'job-add-step', 'job-remix-step', 'job-insert-step', 'job-remove-step' ]: + if not job_manager.init_jobs(state_manager.get_item('jobs_path')): + hard_exit(1) + error_code = route_job_manager(args) + hard_exit(error_code) + + if state_manager.get_item('command') == 'run': + import facefusion.uis.core as ui + + if not common_pre_check() or not processors_pre_check(): + return hard_exit(2) + for ui_layout in ui.get_ui_layouts_modules(state_manager.get_item('ui_layouts')): + if not ui_layout.pre_check(): + return hard_exit(2) + ui.init() + ui.launch() + + if state_manager.get_item('command') == 'headless-run': + if not job_manager.init_jobs(state_manager.get_item('jobs_path')): + hard_exit(1) + error_core = process_headless(args) + hard_exit(error_core) + + if state_manager.get_item('command') == 'batch-run': + if not job_manager.init_jobs(state_manager.get_item('jobs_path')): + hard_exit(1) + error_core = process_batch(args) + hard_exit(error_core) + + if state_manager.get_item('command') in [ 'job-run', 'job-run-all', 'job-retry', 'job-retry-all' ]: + if not job_manager.init_jobs(state_manager.get_item('jobs_path')): + hard_exit(1) + error_code = route_job_runner() + hard_exit(error_code) + + +def pre_check() -> bool: + if sys.version_info < (3, 10): + logger.error(wording.get('python_not_supported').format(version = '3.10'), __name__) + return False + + if not shutil.which('curl'): + logger.error(wording.get('curl_not_installed'), __name__) + return False + + if not shutil.which('ffmpeg'): + logger.error(wording.get('ffmpeg_not_installed'), __name__) + return False + return True + + +def common_pre_check() -> bool: + common_modules =\ + [ + content_analyser, + face_classifier, + face_detector, + face_landmarker, + face_masker, + face_recognizer, + voice_extractor + ] + + content_analyser_content = inspect.getsource(content_analyser).encode() + is_valid = hash_helper.create_hash(content_analyser_content) == 'b159fd9d' + + return all(module.pre_check() for module in common_modules) and is_valid + + +def processors_pre_check() -> bool: + for processor_module in get_processors_modules(state_manager.get_item('processors')): + if not processor_module.pre_check(): + return False + return True + + +def force_download() -> ErrorCode: + common_modules =\ + [ + content_analyser, + face_classifier, + face_detector, + face_landmarker, + face_masker, + face_recognizer, + voice_extractor + ] + available_processors = [ get_file_name(file_path) for file_path in resolve_file_paths('facefusion/processors/modules') ] + processor_modules = get_processors_modules(available_processors) + + for module in common_modules + processor_modules: + if hasattr(module, 'create_static_model_set'): + for model in module.create_static_model_set(state_manager.get_item('download_scope')).values(): + model_hash_set = model.get('hashes') + model_source_set = model.get('sources') + + if model_hash_set and model_source_set: + if not conditional_download_hashes(model_hash_set) or not conditional_download_sources(model_source_set): + return 1 + + return 0 + + +def route_job_manager(args : Args) 
-> ErrorCode: + if state_manager.get_item('command') == 'job-list': + job_headers, job_contents = compose_job_list(state_manager.get_item('job_status')) + + if job_contents: + cli_helper.render_table(job_headers, job_contents) + return 0 + return 1 + + if state_manager.get_item('command') == 'job-create': + if job_manager.create_job(state_manager.get_item('job_id')): + logger.info(wording.get('job_created').format(job_id = state_manager.get_item('job_id')), __name__) + return 0 + logger.error(wording.get('job_not_created').format(job_id = state_manager.get_item('job_id')), __name__) + return 1 + + if state_manager.get_item('command') == 'job-submit': + if job_manager.submit_job(state_manager.get_item('job_id')): + logger.info(wording.get('job_submitted').format(job_id = state_manager.get_item('job_id')), __name__) + return 0 + logger.error(wording.get('job_not_submitted').format(job_id = state_manager.get_item('job_id')), __name__) + return 1 + + if state_manager.get_item('command') == 'job-submit-all': + if job_manager.submit_jobs(state_manager.get_item('halt_on_error')): + logger.info(wording.get('job_all_submitted'), __name__) + return 0 + logger.error(wording.get('job_all_not_submitted'), __name__) + return 1 + + if state_manager.get_item('command') == 'job-delete': + if job_manager.delete_job(state_manager.get_item('job_id')): + logger.info(wording.get('job_deleted').format(job_id = state_manager.get_item('job_id')), __name__) + return 0 + logger.error(wording.get('job_not_deleted').format(job_id = state_manager.get_item('job_id')), __name__) + return 1 + + if state_manager.get_item('command') == 'job-delete-all': + if job_manager.delete_jobs(state_manager.get_item('halt_on_error')): + logger.info(wording.get('job_all_deleted'), __name__) + return 0 + logger.error(wording.get('job_all_not_deleted'), __name__) + return 1 + + if state_manager.get_item('command') == 'job-add-step': + step_args = reduce_step_args(args) + + if job_manager.add_step(state_manager.get_item('job_id'), step_args): + logger.info(wording.get('job_step_added').format(job_id = state_manager.get_item('job_id')), __name__) + return 0 + logger.error(wording.get('job_step_not_added').format(job_id = state_manager.get_item('job_id')), __name__) + return 1 + + if state_manager.get_item('command') == 'job-remix-step': + step_args = reduce_step_args(args) + + if job_manager.remix_step(state_manager.get_item('job_id'), state_manager.get_item('step_index'), step_args): + logger.info(wording.get('job_remix_step_added').format(job_id = state_manager.get_item('job_id'), step_index = state_manager.get_item('step_index')), __name__) + return 0 + logger.error(wording.get('job_remix_step_not_added').format(job_id = state_manager.get_item('job_id'), step_index = state_manager.get_item('step_index')), __name__) + return 1 + + if state_manager.get_item('command') == 'job-insert-step': + step_args = reduce_step_args(args) + + if job_manager.insert_step(state_manager.get_item('job_id'), state_manager.get_item('step_index'), step_args): + logger.info(wording.get('job_step_inserted').format(job_id = state_manager.get_item('job_id'), step_index = state_manager.get_item('step_index')), __name__) + return 0 + logger.error(wording.get('job_step_not_inserted').format(job_id = state_manager.get_item('job_id'), step_index = state_manager.get_item('step_index')), __name__) + return 1 + + if state_manager.get_item('command') == 'job-remove-step': + if job_manager.remove_step(state_manager.get_item('job_id'), state_manager.get_item('step_index')): 
+ logger.info(wording.get('job_step_removed').format(job_id = state_manager.get_item('job_id'), step_index = state_manager.get_item('step_index')), __name__) + return 0 + logger.error(wording.get('job_step_not_removed').format(job_id = state_manager.get_item('job_id'), step_index = state_manager.get_item('step_index')), __name__) + return 1 + return 1 + + +def route_job_runner() -> ErrorCode: + if state_manager.get_item('command') == 'job-run': + logger.info(wording.get('running_job').format(job_id = state_manager.get_item('job_id')), __name__) + if job_runner.run_job(state_manager.get_item('job_id'), process_step): + logger.info(wording.get('processing_job_succeed').format(job_id = state_manager.get_item('job_id')), __name__) + return 0 + logger.info(wording.get('processing_job_failed').format(job_id = state_manager.get_item('job_id')), __name__) + return 1 + + if state_manager.get_item('command') == 'job-run-all': + logger.info(wording.get('running_jobs'), __name__) + if job_runner.run_jobs(process_step, state_manager.get_item('halt_on_error')): + logger.info(wording.get('processing_jobs_succeed'), __name__) + return 0 + logger.info(wording.get('processing_jobs_failed'), __name__) + return 1 + + if state_manager.get_item('command') == 'job-retry': + logger.info(wording.get('retrying_job').format(job_id = state_manager.get_item('job_id')), __name__) + if job_runner.retry_job(state_manager.get_item('job_id'), process_step): + logger.info(wording.get('processing_job_succeed').format(job_id = state_manager.get_item('job_id')), __name__) + return 0 + logger.info(wording.get('processing_job_failed').format(job_id = state_manager.get_item('job_id')), __name__) + return 1 + + if state_manager.get_item('command') == 'job-retry-all': + logger.info(wording.get('retrying_jobs'), __name__) + if job_runner.retry_jobs(process_step, state_manager.get_item('halt_on_error')): + logger.info(wording.get('processing_jobs_succeed'), __name__) + return 0 + logger.info(wording.get('processing_jobs_failed'), __name__) + return 1 + return 2 + + +def process_headless(args : Args) -> ErrorCode: + job_id = job_helper.suggest_job_id('headless') + step_args = reduce_step_args(args) + + if job_manager.create_job(job_id) and job_manager.add_step(job_id, step_args) and job_manager.submit_job(job_id) and job_runner.run_job(job_id, process_step): + return 0 + return 1 + + +def process_batch(args : Args) -> ErrorCode: + job_id = job_helper.suggest_job_id('batch') + step_args = reduce_step_args(args) + job_args = reduce_job_args(args) + source_paths = resolve_file_pattern(job_args.get('source_pattern')) + target_paths = resolve_file_pattern(job_args.get('target_pattern')) + + if job_manager.create_job(job_id): + if source_paths and target_paths: + for index, (source_path, target_path) in enumerate(itertools.product(source_paths, target_paths)): + step_args['source_paths'] = [ source_path ] + step_args['target_path'] = target_path + step_args['output_path'] = job_args.get('output_pattern').format(index = index) + if not job_manager.add_step(job_id, step_args): + return 1 + if job_manager.submit_job(job_id) and job_runner.run_job(job_id, process_step): + return 0 + + if not source_paths and target_paths: + for index, target_path in enumerate(target_paths): + step_args['target_path'] = target_path + step_args['output_path'] = job_args.get('output_pattern').format(index = index) + if not job_manager.add_step(job_id, step_args): + return 1 + if job_manager.submit_job(job_id) and job_runner.run_job(job_id, process_step): + return 
0 + return 1 + + +def process_step(job_id : str, step_index : int, step_args : Args) -> bool: + clear_reference_faces() + step_total = job_manager.count_step_total(job_id) + step_args.update(collect_job_args()) + apply_args(step_args, state_manager.set_item) + + logger.info(wording.get('processing_step').format(step_current = step_index + 1, step_total = step_total), __name__) + if common_pre_check() and processors_pre_check(): + error_code = conditional_process() + return error_code == 0 + return False + + +def conditional_process() -> ErrorCode: + start_time = time() + + for processor_module in get_processors_modules(state_manager.get_item('processors')): + if not processor_module.pre_process('output'): + return 2 + + conditional_append_reference_faces() + + if is_image(state_manager.get_item('target_path')): + return process_image(start_time) + if is_video(state_manager.get_item('target_path')): + return process_video(start_time) + + return 0 + + +def conditional_append_reference_faces() -> None: + if 'reference' in state_manager.get_item('face_selector_mode') and not get_reference_faces(): + source_frames = read_static_images(state_manager.get_item('source_paths')) + source_faces = get_many_faces(source_frames) + source_face = get_average_face(source_faces) + if is_video(state_manager.get_item('target_path')): + reference_frame = read_video_frame(state_manager.get_item('target_path'), state_manager.get_item('reference_frame_number')) + else: + reference_frame = read_image(state_manager.get_item('target_path')) + reference_faces = sort_and_filter_faces(get_many_faces([ reference_frame ])) + reference_face = get_one_face(reference_faces, state_manager.get_item('reference_face_position')) + append_reference_face('origin', reference_face) + + if source_face and reference_face: + for processor_module in get_processors_modules(state_manager.get_item('processors')): + abstract_reference_frame = processor_module.get_reference_frame(source_face, reference_face, reference_frame) + if numpy.any(abstract_reference_frame): + abstract_reference_faces = sort_and_filter_faces(get_many_faces([ abstract_reference_frame ])) + abstract_reference_face = get_one_face(abstract_reference_faces, state_manager.get_item('reference_face_position')) + append_reference_face(processor_module.__name__, abstract_reference_face) + + +def process_image(start_time : float) -> ErrorCode: + if analyse_image(state_manager.get_item('target_path')): + return 3 + + logger.debug(wording.get('clearing_temp'), __name__) + clear_temp_directory(state_manager.get_item('target_path')) + logger.debug(wording.get('creating_temp'), __name__) + create_temp_directory(state_manager.get_item('target_path')) + + process_manager.start() + temp_image_resolution = pack_resolution(restrict_image_resolution(state_manager.get_item('target_path'), unpack_resolution(state_manager.get_item('output_image_resolution')))) + logger.info(wording.get('copying_image').format(resolution = temp_image_resolution), __name__) + if copy_image(state_manager.get_item('target_path'), temp_image_resolution): + logger.debug(wording.get('copying_image_succeed'), __name__) + else: + logger.error(wording.get('copying_image_failed'), __name__) + process_manager.end() + return 1 + + temp_image_path = get_temp_file_path(state_manager.get_item('target_path')) + for processor_module in get_processors_modules(state_manager.get_item('processors')): + logger.info(wording.get('processing'), processor_module.__name__) + 
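# each processor rewrites the temp image in place, then runs its post processing +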
processor_module.process_image(state_manager.get_item('source_paths'), temp_image_path, temp_image_path) + processor_module.post_process() + if is_process_stopping(): + process_manager.end() + return 4 + + logger.info(wording.get('finalizing_image').format(resolution = state_manager.get_item('output_image_resolution')), __name__) + if finalize_image(state_manager.get_item('target_path'), state_manager.get_item('output_path'), state_manager.get_item('output_image_resolution')): + logger.debug(wording.get('finalizing_image_succeed'), __name__) + else: + logger.warn(wording.get('finalizing_image_skipped'), __name__) + + logger.debug(wording.get('clearing_temp'), __name__) + clear_temp_directory(state_manager.get_item('target_path')) + + if is_image(state_manager.get_item('output_path')): + seconds = '{:.2f}'.format((time() - start_time) % 60) + logger.info(wording.get('processing_image_succeed').format(seconds = seconds), __name__) + else: + logger.error(wording.get('processing_image_failed'), __name__) + process_manager.end() + return 1 + process_manager.end() + return 0 + + +def process_video(start_time : float) -> ErrorCode: + trim_frame_start, trim_frame_end = restrict_trim_frame(state_manager.get_item('target_path'), state_manager.get_item('trim_frame_start'), state_manager.get_item('trim_frame_end')) + if analyse_video(state_manager.get_item('target_path'), trim_frame_start, trim_frame_end): + return 3 + + logger.debug(wording.get('clearing_temp'), __name__) + clear_temp_directory(state_manager.get_item('target_path')) + logger.debug(wording.get('creating_temp'), __name__) + create_temp_directory(state_manager.get_item('target_path')) + + process_manager.start() + temp_video_resolution = pack_resolution(restrict_video_resolution(state_manager.get_item('target_path'), unpack_resolution(state_manager.get_item('output_video_resolution')))) + temp_video_fps = restrict_video_fps(state_manager.get_item('target_path'), state_manager.get_item('output_video_fps')) + logger.info(wording.get('extracting_frames').format(resolution = temp_video_resolution, fps = temp_video_fps), __name__) + if extract_frames(state_manager.get_item('target_path'), temp_video_resolution, temp_video_fps, trim_frame_start, trim_frame_end): + logger.debug(wording.get('extracting_frames_succeed'), __name__) + else: + if is_process_stopping(): + process_manager.end() + return 4 + logger.error(wording.get('extracting_frames_failed'), __name__) + process_manager.end() + return 1 + + temp_frame_paths = resolve_temp_frame_paths(state_manager.get_item('target_path')) + if temp_frame_paths: + for processor_module in get_processors_modules(state_manager.get_item('processors')): + logger.info(wording.get('processing'), processor_module.__name__) + processor_module.process_video(state_manager.get_item('source_paths'), temp_frame_paths) + processor_module.post_process() + if is_process_stopping(): + return 4 + else: + logger.error(wording.get('temp_frames_not_found'), __name__) + process_manager.end() + return 1 + + logger.info(wording.get('merging_video').format(resolution = state_manager.get_item('output_video_resolution'), fps = state_manager.get_item('output_video_fps')), __name__) + if merge_video(state_manager.get_item('target_path'), temp_video_fps, state_manager.get_item('output_video_resolution'), state_manager.get_item('output_video_fps'), trim_frame_start, trim_frame_end): + logger.debug(wording.get('merging_video_succeed'), __name__) + else: + if is_process_stopping(): + process_manager.end() + return 4 + 
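# a merge failure that was not caused by a stop request counts as a hard error +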
logger.error(wording.get('merging_video_failed'), __name__) + process_manager.end() + return 1 + + if state_manager.get_item('output_audio_volume') == 0: + logger.info(wording.get('skipping_audio'), __name__) + move_temp_file(state_manager.get_item('target_path'), state_manager.get_item('output_path')) + else: + source_audio_path = get_first(filter_audio_paths(state_manager.get_item('source_paths'))) + if source_audio_path: + if replace_audio(state_manager.get_item('target_path'), source_audio_path, state_manager.get_item('output_path')): + video_manager.clear_video_pool() + logger.debug(wording.get('replacing_audio_succeed'), __name__) + else: + video_manager.clear_video_pool() + if is_process_stopping(): + process_manager.end() + return 4 + logger.warn(wording.get('replacing_audio_skipped'), __name__) + move_temp_file(state_manager.get_item('target_path'), state_manager.get_item('output_path')) + else: + if restore_audio(state_manager.get_item('target_path'), state_manager.get_item('output_path'), trim_frame_start, trim_frame_end): + video_manager.clear_video_pool() + logger.debug(wording.get('restoring_audio_succeed'), __name__) + else: + video_manager.clear_video_pool() + if is_process_stopping(): + process_manager.end() + return 4 + logger.warn(wording.get('restoring_audio_skipped'), __name__) + move_temp_file(state_manager.get_item('target_path'), state_manager.get_item('output_path')) + + logger.debug(wording.get('clearing_temp'), __name__) + clear_temp_directory(state_manager.get_item('target_path')) + + if is_video(state_manager.get_item('output_path')): + seconds = '{:.2f}'.format((time() - start_time)) + logger.info(wording.get('processing_video_succeed').format(seconds = seconds), __name__) + else: + logger.error(wording.get('processing_video_failed'), __name__) + process_manager.end() + return 1 + process_manager.end() + return 0 + + +def is_process_stopping() -> bool: + if process_manager.is_stopping(): + process_manager.end() + logger.info(wording.get('processing_stopped'), __name__) + return process_manager.is_pending() diff --git a/facefusion/curl_builder.py b/facefusion/curl_builder.py new file mode 100644 index 0000000000000000000000000000000000000000..a720353b61197a14331e534ff78d4fa1a56a6051 --- /dev/null +++ b/facefusion/curl_builder.py @@ -0,0 +1,27 @@ +import itertools +import shutil + +from facefusion import metadata +from facefusion.types import Commands + + +def run(commands : Commands) -> Commands: + user_agent = metadata.get('name') + '/' + metadata.get('version') + + return [ shutil.which('curl'), '--user-agent', user_agent, '--insecure', '--location', '--silent' ] + commands + + +def chain(*commands : Commands) -> Commands: + return list(itertools.chain(*commands)) + + +def head(url : str) -> Commands: + return [ '-I', url ] + + +def download(url : str, download_file_path : str) -> Commands: + return [ '--create-dirs', '--continue-at', '-', '--output', download_file_path, url ] + + +def set_timeout(timeout : int) -> Commands: + return [ '--connect-timeout', str(timeout) ] diff --git a/facefusion/date_helper.py b/facefusion/date_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..c60e2f6109e0fefc5d9973cf8e63413e5a2cbdf0 --- /dev/null +++ b/facefusion/date_helper.py @@ -0,0 +1,28 @@ +from datetime import datetime, timedelta +from typing import Optional, Tuple + +from facefusion import wording + + +def get_current_date_time() -> datetime: + return datetime.now().astimezone() + + +def split_time_delta(time_delta : timedelta) -> 
Tuple[int, int, int, int]: + days, hours = divmod(time_delta.total_seconds(), 86400) + hours, minutes = divmod(hours, 3600) + minutes, seconds = divmod(minutes, 60) + return int(days), int(hours), int(minutes), int(seconds) + + +def describe_time_ago(date_time : datetime) -> Optional[str]: + time_ago = datetime.now(date_time.tzinfo) - date_time + days, hours, minutes, _ = split_time_delta(time_ago) + + if timedelta(days = 1) < time_ago: + return wording.get('time_ago_days').format(days = days, hours = hours, minutes = minutes) + if timedelta(hours = 1) < time_ago: + return wording.get('time_ago_hours').format(hours = hours, minutes = minutes) + if timedelta(minutes = 1) < time_ago: + return wording.get('time_ago_minutes').format(minutes = minutes) + return wording.get('time_ago_now') diff --git a/facefusion/download.py b/facefusion/download.py new file mode 100644 index 0000000000000000000000000000000000000000..f0c92f4afde5cee3ad3f068877aea0002a6f5dc6 --- /dev/null +++ b/facefusion/download.py @@ -0,0 +1,174 @@ +import os +import subprocess +from functools import lru_cache +from typing import List, Optional, Tuple +from urllib.parse import urlparse + +from tqdm import tqdm + +import facefusion.choices +from facefusion import curl_builder, logger, process_manager, state_manager, wording +from facefusion.filesystem import get_file_name, get_file_size, is_file, remove_file +from facefusion.hash_helper import validate_hash +from facefusion.types import Commands, DownloadProvider, DownloadSet + + +def open_curl(commands : Commands) -> subprocess.Popen[bytes]: + commands = curl_builder.run(commands) + return subprocess.Popen(commands, stdin = subprocess.PIPE, stdout = subprocess.PIPE) + + +def conditional_download(download_directory_path : str, urls : List[str]) -> None: + for url in urls: + download_file_name = os.path.basename(urlparse(url).path) + download_file_path = os.path.join(download_directory_path, download_file_name) + initial_size = get_file_size(download_file_path) + download_size = get_static_download_size(url) + + if initial_size < download_size: + with tqdm(total = download_size, initial = initial_size, desc = wording.get('downloading'), unit = 'B', unit_scale = True, unit_divisor = 1024, ascii = ' =', disable = state_manager.get_item('log_level') in [ 'warn', 'error' ]) as progress: + commands = curl_builder.chain( + curl_builder.download(url, download_file_path), + curl_builder.set_timeout(10) + ) + open_curl(commands) + current_size = initial_size + progress.set_postfix(download_providers = state_manager.get_item('download_providers'), file_name = download_file_name) + + while current_size < download_size: + if is_file(download_file_path): + current_size = get_file_size(download_file_path) + progress.update(current_size - progress.n) + + +@lru_cache(maxsize = None) +def get_static_download_size(url : str) -> int: + commands = curl_builder.chain( + curl_builder.head(url), + curl_builder.set_timeout(5) + ) + process = open_curl(commands) + lines = reversed(process.stdout.readlines()) + + for line in lines: + __line__ = line.decode().lower() + if 'content-length:' in __line__: + _, content_length = __line__.split('content-length:') + return int(content_length) + + return 0 + + +@lru_cache(maxsize = None) +def ping_static_url(url : str) -> bool: + commands = curl_builder.chain( + curl_builder.head(url), + curl_builder.set_timeout(5) + ) + process = open_curl(commands) + process.communicate() + return process.returncode == 0 + + +def conditional_download_hashes(hash_set : 
DownloadSet) -> bool: + hash_paths = [ hash_set.get(hash_key).get('path') for hash_key in hash_set.keys() ] + + process_manager.check() + _, invalid_hash_paths = validate_hash_paths(hash_paths) + if invalid_hash_paths: + for index in hash_set: + if hash_set.get(index).get('path') in invalid_hash_paths: + invalid_hash_url = hash_set.get(index).get('url') + if invalid_hash_url: + download_directory_path = os.path.dirname(hash_set.get(index).get('path')) + conditional_download(download_directory_path, [ invalid_hash_url ]) + + valid_hash_paths, invalid_hash_paths = validate_hash_paths(hash_paths) + + for valid_hash_path in valid_hash_paths: + valid_hash_file_name = get_file_name(valid_hash_path) + logger.debug(wording.get('validating_hash_succeed').format(hash_file_name = valid_hash_file_name), __name__) + for invalid_hash_path in invalid_hash_paths: + invalid_hash_file_name = get_file_name(invalid_hash_path) + logger.error(wording.get('validating_hash_failed').format(hash_file_name = invalid_hash_file_name), __name__) + + if not invalid_hash_paths: + process_manager.end() + return not invalid_hash_paths + + +def conditional_download_sources(source_set : DownloadSet) -> bool: + source_paths = [ source_set.get(source_key).get('path') for source_key in source_set.keys() ] + + process_manager.check() + _, invalid_source_paths = validate_source_paths(source_paths) + if invalid_source_paths: + for index in source_set: + if source_set.get(index).get('path') in invalid_source_paths: + invalid_source_url = source_set.get(index).get('url') + if invalid_source_url: + download_directory_path = os.path.dirname(source_set.get(index).get('path')) + conditional_download(download_directory_path, [ invalid_source_url ]) + + valid_source_paths, invalid_source_paths = validate_source_paths(source_paths) + + for valid_source_path in valid_source_paths: + valid_source_file_name = get_file_name(valid_source_path) + logger.debug(wording.get('validating_source_succeed').format(source_file_name = valid_source_file_name), __name__) + for invalid_source_path in invalid_source_paths: + invalid_source_file_name = get_file_name(invalid_source_path) + logger.error(wording.get('validating_source_failed').format(source_file_name = invalid_source_file_name), __name__) + + if remove_file(invalid_source_path): + logger.error(wording.get('deleting_corrupt_source').format(source_file_name = invalid_source_file_name), __name__) + + if not invalid_source_paths: + process_manager.end() + return not invalid_source_paths + + +def validate_hash_paths(hash_paths : List[str]) -> Tuple[List[str], List[str]]: + valid_hash_paths = [] + invalid_hash_paths = [] + + for hash_path in hash_paths: + if is_file(hash_path): + valid_hash_paths.append(hash_path) + else: + invalid_hash_paths.append(hash_path) + + return valid_hash_paths, invalid_hash_paths + + +def validate_source_paths(source_paths : List[str]) -> Tuple[List[str], List[str]]: + valid_source_paths = [] + invalid_source_paths = [] + + for source_path in source_paths: + if validate_hash(source_path): + valid_source_paths.append(source_path) + else: + invalid_source_paths.append(source_path) + + return valid_source_paths, invalid_source_paths + + +def resolve_download_url(base_name : str, file_name : str) -> Optional[str]: + download_providers = state_manager.get_item('download_providers') + + for download_provider in download_providers: + download_url = resolve_download_url_by_provider(download_provider, base_name, file_name) + if download_url: + return download_url + + return None + + 
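Taken together, `resolve_download_url` and `conditional_download` form the public surface of this module: the first turns a provider-relative base name and file name into a concrete URL, the second fetches that URL only when the local copy is missing or incomplete. The snippet below is a minimal usage sketch and not part of the diff itself; the provider key `'github'`, the `'info'` log level and the `.assets/models` directory are illustrative assumptions.

```
from facefusion import state_manager
from facefusion.download import conditional_download, resolve_download_url

# illustrative assumptions: provider key, log level and cache directory
state_manager.set_item('download_providers', [ 'github' ])
state_manager.set_item('log_level', 'info')

model_url = resolve_download_url('models-3.0.0', 'fairface.onnx')
# conditional_download only fetches when the local file is smaller than the remote content length
if model_url:
	conditional_download('.assets/models', [ model_url ])
```

Note that `get_static_download_size` and `ping_static_url` are wrapped in `lru_cache`, so repeated resolutions against the same provider URL do not issue additional HEAD requests.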
+def resolve_download_url_by_provider(download_provider : DownloadProvider, base_name : str, file_name : str) -> Optional[str]: + download_provider_value = facefusion.choices.download_provider_set.get(download_provider) + + for download_provider_url in download_provider_value.get('urls'): + if ping_static_url(download_provider_url): + return download_provider_url + download_provider_value.get('path').format(base_name = base_name, file_name = file_name) + + return None diff --git a/facefusion/execution.py b/facefusion/execution.py new file mode 100644 index 0000000000000000000000000000000000000000..dbec8bfc268f6fb22491c37bcc153203e1a227e7 --- /dev/null +++ b/facefusion/execution.py @@ -0,0 +1,156 @@ +import shutil +import subprocess +import xml.etree.ElementTree as ElementTree +from functools import lru_cache +from typing import List, Optional + +from onnxruntime import get_available_providers, set_default_logger_severity + +import facefusion.choices +from facefusion.types import ExecutionDevice, ExecutionProvider, InferenceSessionProvider, ValueAndUnit + +set_default_logger_severity(3) + + +def has_execution_provider(execution_provider : ExecutionProvider) -> bool: + return execution_provider in get_available_execution_providers() + + +def get_available_execution_providers() -> List[ExecutionProvider]: + inference_session_providers = get_available_providers() + available_execution_providers : List[ExecutionProvider] = [] + + for execution_provider, execution_provider_value in facefusion.choices.execution_provider_set.items(): + if execution_provider_value in inference_session_providers: + index = facefusion.choices.execution_providers.index(execution_provider) + available_execution_providers.insert(index, execution_provider) + + return available_execution_providers + + +def create_inference_session_providers(execution_device_id : str, execution_providers : List[ExecutionProvider]) -> List[InferenceSessionProvider]: + inference_session_providers : List[InferenceSessionProvider] = [] + + for execution_provider in execution_providers: + if execution_provider == 'cuda': + inference_session_providers.append((facefusion.choices.execution_provider_set.get(execution_provider), + { + 'device_id': execution_device_id, + 'cudnn_conv_algo_search': resolve_cudnn_conv_algo_search() + })) + if execution_provider == 'tensorrt': + inference_session_providers.append((facefusion.choices.execution_provider_set.get(execution_provider), + { + 'device_id': execution_device_id, + 'trt_engine_cache_enable': True, + 'trt_engine_cache_path': '.caches', + 'trt_timing_cache_enable': True, + 'trt_timing_cache_path': '.caches', + 'trt_builder_optimization_level': 5 + })) + if execution_provider in [ 'directml', 'rocm' ]: + inference_session_providers.append((facefusion.choices.execution_provider_set.get(execution_provider), + { + 'device_id': execution_device_id + })) + if execution_provider == 'openvino': + inference_session_providers.append((facefusion.choices.execution_provider_set.get(execution_provider), + { + 'device_type': resolve_openvino_device_type(execution_device_id), + 'precision': 'FP32' + })) + if execution_provider == 'coreml': + inference_session_providers.append((facefusion.choices.execution_provider_set.get(execution_provider), + { + 'SpecializationStrategy': 'FastPrediction', + 'ModelCacheDirectory': '.caches' + })) + + if 'cpu' in execution_providers: + inference_session_providers.append(facefusion.choices.execution_provider_set.get('cpu')) + + return inference_session_providers + + +def 
resolve_cudnn_conv_algo_search() -> str: + execution_devices = detect_static_execution_devices() + product_names = ('GeForce GTX 1630', 'GeForce GTX 1650', 'GeForce GTX 1660') + + for execution_device in execution_devices: + if execution_device.get('product').get('name').startswith(product_names): + return 'DEFAULT' + + return 'EXHAUSTIVE' + + +def resolve_openvino_device_type(execution_device_id : str) -> str: + if execution_device_id == '0': + return 'GPU' + if execution_device_id == '∞': + return 'MULTI:GPU' + return 'GPU.' + execution_device_id + + +def run_nvidia_smi() -> subprocess.Popen[bytes]: + commands = [ shutil.which('nvidia-smi'), '--query', '--xml-format' ] + return subprocess.Popen(commands, stdout = subprocess.PIPE) + + +@lru_cache(maxsize = None) +def detect_static_execution_devices() -> List[ExecutionDevice]: + return detect_execution_devices() + + +def detect_execution_devices() -> List[ExecutionDevice]: + execution_devices : List[ExecutionDevice] = [] + + try: + output, _ = run_nvidia_smi().communicate() + root_element = ElementTree.fromstring(output) + except Exception: + root_element = ElementTree.Element('xml') + + for gpu_element in root_element.findall('gpu'): + execution_devices.append( + { + 'driver_version': root_element.findtext('driver_version'), + 'framework': + { + 'name': 'CUDA', + 'version': root_element.findtext('cuda_version') + }, + 'product': + { + 'vendor': 'NVIDIA', + 'name': gpu_element.findtext('product_name').replace('NVIDIA', '').strip() + }, + 'video_memory': + { + 'total': create_value_and_unit(gpu_element.findtext('fb_memory_usage/total')), + 'free': create_value_and_unit(gpu_element.findtext('fb_memory_usage/free')) + }, + 'temperature': + { + 'gpu': create_value_and_unit(gpu_element.findtext('temperature/gpu_temp')), + 'memory': create_value_and_unit(gpu_element.findtext('temperature/memory_temp')) + }, + 'utilization': + { + 'gpu': create_value_and_unit(gpu_element.findtext('utilization/gpu_util')), + 'memory': create_value_and_unit(gpu_element.findtext('utilization/memory_util')) + } + }) + + return execution_devices + + +def create_value_and_unit(text : str) -> Optional[ValueAndUnit]: + if ' ' in text: + value, unit = text.split() + + return\ + { + 'value': int(value), + 'unit': str(unit) + } + return None diff --git a/facefusion/exit_helper.py b/facefusion/exit_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..23da74a99d17882ac21a4d0770664c6354a8452c --- /dev/null +++ b/facefusion/exit_helper.py @@ -0,0 +1,26 @@ +import signal +import sys +from time import sleep +from types import FrameType + +from facefusion import process_manager, state_manager +from facefusion.temp_helper import clear_temp_directory +from facefusion.types import ErrorCode + + +def hard_exit(error_code : ErrorCode) -> None: + signal.signal(signal.SIGINT, signal.SIG_IGN) + sys.exit(error_code) + + +def signal_exit(signum : int, frame : FrameType) -> None: + graceful_exit(0) + + +def graceful_exit(error_code : ErrorCode) -> None: + process_manager.stop() + while process_manager.is_processing(): + sleep(0.5) + if state_manager.get_item('target_path'): + clear_temp_directory(state_manager.get_item('target_path')) + hard_exit(error_code) diff --git a/facefusion/face_analyser.py b/facefusion/face_analyser.py new file mode 100644 index 0000000000000000000000000000000000000000..673ecbea92d140440bc950adac69278b9214de3f --- /dev/null +++ b/facefusion/face_analyser.py @@ -0,0 +1,124 @@ +from typing import List, Optional + +import numpy + +from 
facefusion import state_manager +from facefusion.common_helper import get_first +from facefusion.face_classifier import classify_face +from facefusion.face_detector import detect_faces, detect_rotated_faces +from facefusion.face_helper import apply_nms, convert_to_face_landmark_5, estimate_face_angle, get_nms_threshold +from facefusion.face_landmarker import detect_face_landmark, estimate_face_landmark_68_5 +from facefusion.face_recognizer import calc_embedding +from facefusion.face_store import get_static_faces, set_static_faces +from facefusion.types import BoundingBox, Face, FaceLandmark5, FaceLandmarkSet, FaceScoreSet, Score, VisionFrame + + +def create_faces(vision_frame : VisionFrame, bounding_boxes : List[BoundingBox], face_scores : List[Score], face_landmarks_5 : List[FaceLandmark5]) -> List[Face]: + faces = [] + nms_threshold = get_nms_threshold(state_manager.get_item('face_detector_model'), state_manager.get_item('face_detector_angles')) + keep_indices = apply_nms(bounding_boxes, face_scores, state_manager.get_item('face_detector_score'), nms_threshold) + + for index in keep_indices: + bounding_box = bounding_boxes[index] + face_score = face_scores[index] + face_landmark_5 = face_landmarks_5[index] + face_landmark_5_68 = face_landmark_5 + face_landmark_68_5 = estimate_face_landmark_68_5(face_landmark_5_68) + face_landmark_68 = face_landmark_68_5 + face_landmark_score_68 = 0.0 + face_angle = estimate_face_angle(face_landmark_68_5) + + if state_manager.get_item('face_landmarker_score') > 0: + face_landmark_68, face_landmark_score_68 = detect_face_landmark(vision_frame, bounding_box, face_angle) + if face_landmark_score_68 > state_manager.get_item('face_landmarker_score'): + face_landmark_5_68 = convert_to_face_landmark_5(face_landmark_68) + + face_landmark_set : FaceLandmarkSet =\ + { + '5': face_landmark_5, + '5/68': face_landmark_5_68, + '68': face_landmark_68, + '68/5': face_landmark_68_5 + } + face_score_set : FaceScoreSet =\ + { + 'detector': face_score, + 'landmarker': face_landmark_score_68 + } + embedding, normed_embedding = calc_embedding(vision_frame, face_landmark_set.get('5/68')) + gender, age, race = classify_face(vision_frame, face_landmark_set.get('5/68')) + faces.append(Face( + bounding_box = bounding_box, + score_set = face_score_set, + landmark_set = face_landmark_set, + angle = face_angle, + embedding = embedding, + normed_embedding = normed_embedding, + gender = gender, + age = age, + race = race + )) + return faces + + +def get_one_face(faces : List[Face], position : int = 0) -> Optional[Face]: + if faces: + position = min(position, len(faces) - 1) + return faces[position] + return None + + +def get_average_face(faces : List[Face]) -> Optional[Face]: + embeddings = [] + normed_embeddings = [] + + if faces: + first_face = get_first(faces) + + for face in faces: + embeddings.append(face.embedding) + normed_embeddings.append(face.normed_embedding) + + return Face( + bounding_box = first_face.bounding_box, + score_set = first_face.score_set, + landmark_set = first_face.landmark_set, + angle = first_face.angle, + embedding = numpy.mean(embeddings, axis = 0), + normed_embedding = numpy.mean(normed_embeddings, axis = 0), + gender = first_face.gender, + age = first_face.age, + race = first_face.race + ) + return None + + +def get_many_faces(vision_frames : List[VisionFrame]) -> List[Face]: + many_faces : List[Face] = [] + + for vision_frame in vision_frames: + if numpy.any(vision_frame): + static_faces = get_static_faces(vision_frame) + if static_faces: + 
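# reuse faces previously detected for this exact frame from the face store cache +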
many_faces.extend(static_faces) + else: + all_bounding_boxes = [] + all_face_scores = [] + all_face_landmarks_5 = [] + + for face_detector_angle in state_manager.get_item('face_detector_angles'): + if face_detector_angle == 0: + bounding_boxes, face_scores, face_landmarks_5 = detect_faces(vision_frame) + else: + bounding_boxes, face_scores, face_landmarks_5 = detect_rotated_faces(vision_frame, face_detector_angle) + all_bounding_boxes.extend(bounding_boxes) + all_face_scores.extend(face_scores) + all_face_landmarks_5.extend(face_landmarks_5) + + if all_bounding_boxes and all_face_scores and all_face_landmarks_5 and state_manager.get_item('face_detector_score') > 0: + faces = create_faces(vision_frame, all_bounding_boxes, all_face_scores, all_face_landmarks_5) + + if faces: + many_faces.extend(faces) + set_static_faces(vision_frame, faces) + return many_faces diff --git a/facefusion/face_classifier.py b/facefusion/face_classifier.py new file mode 100644 index 0000000000000000000000000000000000000000..3b0999056597bd3f3520a805e29d5e12d9bbdf4b --- /dev/null +++ b/facefusion/face_classifier.py @@ -0,0 +1,134 @@ +from functools import lru_cache +from typing import List, Tuple + +import numpy + +from facefusion import inference_manager +from facefusion.download import conditional_download_hashes, conditional_download_sources, resolve_download_url +from facefusion.face_helper import warp_face_by_face_landmark_5 +from facefusion.filesystem import resolve_relative_path +from facefusion.thread_helper import conditional_thread_semaphore +from facefusion.types import Age, DownloadScope, FaceLandmark5, Gender, InferencePool, ModelOptions, ModelSet, Race, VisionFrame + + +@lru_cache(maxsize = None) +def create_static_model_set(download_scope : DownloadScope) -> ModelSet: + return\ + { + 'fairface': + { + 'hashes': + { + 'face_classifier': + { + 'url': resolve_download_url('models-3.0.0', 'fairface.hash'), + 'path': resolve_relative_path('../.assets/models/fairface.hash') + } + }, + 'sources': + { + 'face_classifier': + { + 'url': resolve_download_url('models-3.0.0', 'fairface.onnx'), + 'path': resolve_relative_path('../.assets/models/fairface.onnx') + } + }, + 'template': 'arcface_112_v2', + 'size': (224, 224), + 'mean': [ 0.485, 0.456, 0.406 ], + 'standard_deviation': [ 0.229, 0.224, 0.225 ] + } + } + + +def get_inference_pool() -> InferencePool: + model_names = [ 'fairface' ] + model_source_set = get_model_options().get('sources') + + return inference_manager.get_inference_pool(__name__, model_names, model_source_set) + + +def clear_inference_pool() -> None: + model_names = [ 'fairface' ] + inference_manager.clear_inference_pool(__name__, model_names) + + +def get_model_options() -> ModelOptions: + return create_static_model_set('full').get('fairface') + + +def pre_check() -> bool: + model_hash_set = get_model_options().get('hashes') + model_source_set = get_model_options().get('sources') + + return conditional_download_hashes(model_hash_set) and conditional_download_sources(model_source_set) + + +def classify_face(temp_vision_frame : VisionFrame, face_landmark_5 : FaceLandmark5) -> Tuple[Gender, Age, Race]: + model_template = get_model_options().get('template') + model_size = get_model_options().get('size') + model_mean = get_model_options().get('mean') + model_standard_deviation = get_model_options().get('standard_deviation') + crop_vision_frame, _ = warp_face_by_face_landmark_5(temp_vision_frame, face_landmark_5, model_template, model_size) + crop_vision_frame = 
crop_vision_frame.astype(numpy.float32)[:, :, ::-1] / 255.0 + crop_vision_frame -= model_mean + crop_vision_frame /= model_standard_deviation + crop_vision_frame = crop_vision_frame.transpose(2, 0, 1) + crop_vision_frame = numpy.expand_dims(crop_vision_frame, axis = 0) + gender_id, age_id, race_id = forward(crop_vision_frame) + gender = categorize_gender(gender_id[0]) + age = categorize_age(age_id[0]) + race = categorize_race(race_id[0]) + return gender, age, race + + +def forward(crop_vision_frame : VisionFrame) -> Tuple[List[int], List[int], List[int]]: + face_classifier = get_inference_pool().get('face_classifier') + + with conditional_thread_semaphore(): + race_id, gender_id, age_id = face_classifier.run(None, + { + 'input': crop_vision_frame + }) + + return gender_id, age_id, race_id + + +def categorize_gender(gender_id : int) -> Gender: + if gender_id == 1: + return 'female' + return 'male' + + +def categorize_age(age_id : int) -> Age: + if age_id == 0: + return range(0, 2) + if age_id == 1: + return range(3, 9) + if age_id == 2: + return range(10, 19) + if age_id == 3: + return range(20, 29) + if age_id == 4: + return range(30, 39) + if age_id == 5: + return range(40, 49) + if age_id == 6: + return range(50, 59) + if age_id == 7: + return range(60, 69) + return range(70, 100) + + +def categorize_race(race_id : int) -> Race: + if race_id == 1: + return 'black' + if race_id == 2: + return 'latino' + if race_id == 3 or race_id == 4: + return 'asian' + if race_id == 5: + return 'indian' + if race_id == 6: + return 'arabic' + return 'white' diff --git a/facefusion/face_detector.py b/facefusion/face_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..c3532fdf5faa3dcf3534cd885ce02182ee643f3e --- /dev/null +++ b/facefusion/face_detector.py @@ -0,0 +1,323 @@ +from functools import lru_cache +from typing import List, Sequence, Tuple + +import cv2 +import numpy + +from facefusion import inference_manager, state_manager +from facefusion.download import conditional_download_hashes, conditional_download_sources, resolve_download_url +from facefusion.face_helper import create_rotated_matrix_and_size, create_static_anchors, distance_to_bounding_box, distance_to_face_landmark_5, normalize_bounding_box, transform_bounding_box, transform_points +from facefusion.filesystem import resolve_relative_path +from facefusion.thread_helper import thread_semaphore +from facefusion.types import Angle, BoundingBox, Detection, DownloadScope, DownloadSet, FaceLandmark5, InferencePool, ModelSet, Score, VisionFrame +from facefusion.vision import restrict_frame, unpack_resolution + + +@lru_cache(maxsize = None) +def create_static_model_set(download_scope : DownloadScope) -> ModelSet: + return\ + { + 'retinaface': + { + 'hashes': + { + 'retinaface': + { + 'url': resolve_download_url('models-3.0.0', 'retinaface_10g.hash'), + 'path': resolve_relative_path('../.assets/models/retinaface_10g.hash') + } + }, + 'sources': + { + 'retinaface': + { + 'url': resolve_download_url('models-3.0.0', 'retinaface_10g.onnx'), + 'path': resolve_relative_path('../.assets/models/retinaface_10g.onnx') + } + } + }, + 'scrfd': + { + 'hashes': + { + 'scrfd': + { + 'url': resolve_download_url('models-3.0.0', 'scrfd_2.5g.hash'), + 'path': resolve_relative_path('../.assets/models/scrfd_2.5g.hash') + } + }, + 'sources': + { + 'scrfd': + { + 'url': resolve_download_url('models-3.0.0', 'scrfd_2.5g.onnx'), + 'path': resolve_relative_path('../.assets/models/scrfd_2.5g.onnx') + } + } + }, + 'yolo_face': + { + 'hashes': + { + 
'yolo_face': + { + 'url': resolve_download_url('models-3.0.0', 'yoloface_8n.hash'), + 'path': resolve_relative_path('../.assets/models/yoloface_8n.hash') + } + }, + 'sources': + { + 'yolo_face': + { + 'url': resolve_download_url('models-3.0.0', 'yoloface_8n.onnx'), + 'path': resolve_relative_path('../.assets/models/yoloface_8n.onnx') + } + } + } + } + + +def get_inference_pool() -> InferencePool: + model_names = [ state_manager.get_item('face_detector_model') ] + _, model_source_set = collect_model_downloads() + + return inference_manager.get_inference_pool(__name__, model_names, model_source_set) + + +def clear_inference_pool() -> None: + model_names = [ state_manager.get_item('face_detector_model') ] + inference_manager.clear_inference_pool(__name__, model_names) + + +def collect_model_downloads() -> Tuple[DownloadSet, DownloadSet]: + model_set = create_static_model_set('full') + model_hash_set = {} + model_source_set = {} + + for face_detector_model in [ 'retinaface', 'scrfd', 'yolo_face' ]: + if state_manager.get_item('face_detector_model') in [ 'many', face_detector_model ]: + model_hash_set[face_detector_model] = model_set.get(face_detector_model).get('hashes').get(face_detector_model) + model_source_set[face_detector_model] = model_set.get(face_detector_model).get('sources').get(face_detector_model) + + return model_hash_set, model_source_set + + +def pre_check() -> bool: + model_hash_set, model_source_set = collect_model_downloads() + + return conditional_download_hashes(model_hash_set) and conditional_download_sources(model_source_set) + + +def detect_faces(vision_frame : VisionFrame) -> Tuple[List[BoundingBox], List[Score], List[FaceLandmark5]]: + all_bounding_boxes : List[BoundingBox] = [] + all_face_scores : List[Score] = [] + all_face_landmarks_5 : List[FaceLandmark5] = [] + + if state_manager.get_item('face_detector_model') in [ 'many', 'retinaface' ]: + bounding_boxes, face_scores, face_landmarks_5 = detect_with_retinaface(vision_frame, state_manager.get_item('face_detector_size')) + all_bounding_boxes.extend(bounding_boxes) + all_face_scores.extend(face_scores) + all_face_landmarks_5.extend(face_landmarks_5) + + if state_manager.get_item('face_detector_model') in [ 'many', 'scrfd' ]: + bounding_boxes, face_scores, face_landmarks_5 = detect_with_scrfd(vision_frame, state_manager.get_item('face_detector_size')) + all_bounding_boxes.extend(bounding_boxes) + all_face_scores.extend(face_scores) + all_face_landmarks_5.extend(face_landmarks_5) + + if state_manager.get_item('face_detector_model') in [ 'many', 'yolo_face' ]: + bounding_boxes, face_scores, face_landmarks_5 = detect_with_yolo_face(vision_frame, state_manager.get_item('face_detector_size')) + all_bounding_boxes.extend(bounding_boxes) + all_face_scores.extend(face_scores) + all_face_landmarks_5.extend(face_landmarks_5) + + all_bounding_boxes = [ normalize_bounding_box(all_bounding_box) for all_bounding_box in all_bounding_boxes ] + return all_bounding_boxes, all_face_scores, all_face_landmarks_5 + + +def detect_rotated_faces(vision_frame : VisionFrame, angle : Angle) -> Tuple[List[BoundingBox], List[Score], List[FaceLandmark5]]: + rotated_matrix, rotated_size = create_rotated_matrix_and_size(angle, vision_frame.shape[:2][::-1]) + rotated_vision_frame = cv2.warpAffine(vision_frame, rotated_matrix, rotated_size) + rotated_inverse_matrix = cv2.invertAffineTransform(rotated_matrix) + bounding_boxes, face_scores, face_landmarks_5 = detect_faces(rotated_vision_frame) + bounding_boxes = [ transform_bounding_box(bounding_box, 
rotated_inverse_matrix) for bounding_box in bounding_boxes ] + face_landmarks_5 = [ transform_points(face_landmark_5, rotated_inverse_matrix) for face_landmark_5 in face_landmarks_5 ] + return bounding_boxes, face_scores, face_landmarks_5 + + +def detect_with_retinaface(vision_frame : VisionFrame, face_detector_size : str) -> Tuple[List[BoundingBox], List[Score], List[FaceLandmark5]]: + bounding_boxes = [] + face_scores = [] + face_landmarks_5 = [] + feature_strides = [ 8, 16, 32 ] + feature_map_channel = 3 + anchor_total = 2 + face_detector_score = state_manager.get_item('face_detector_score') + face_detector_width, face_detector_height = unpack_resolution(face_detector_size) + temp_vision_frame = restrict_frame(vision_frame, (face_detector_width, face_detector_height)) + ratio_height = vision_frame.shape[0] / temp_vision_frame.shape[0] + ratio_width = vision_frame.shape[1] / temp_vision_frame.shape[1] + detect_vision_frame = prepare_detect_frame(temp_vision_frame, face_detector_size) + detect_vision_frame = normalize_detect_frame(detect_vision_frame, [ -1, 1 ]) + detection = forward_with_retinaface(detect_vision_frame) + + for index, feature_stride in enumerate(feature_strides): + keep_indices = numpy.where(detection[index] >= face_detector_score)[0] + + if numpy.any(keep_indices): + stride_height = face_detector_height // feature_stride + stride_width = face_detector_width // feature_stride + anchors = create_static_anchors(feature_stride, anchor_total, stride_height, stride_width) + bounding_boxes_raw = detection[index + feature_map_channel] * feature_stride + face_landmarks_5_raw = detection[index + feature_map_channel * 2] * feature_stride + + for bounding_box_raw in distance_to_bounding_box(anchors, bounding_boxes_raw)[keep_indices]: + bounding_boxes.append(numpy.array( + [ + bounding_box_raw[0] * ratio_width, + bounding_box_raw[1] * ratio_height, + bounding_box_raw[2] * ratio_width, + bounding_box_raw[3] * ratio_height + ])) + + for face_score_raw in detection[index][keep_indices]: + face_scores.append(face_score_raw[0]) + + for face_landmark_raw_5 in distance_to_face_landmark_5(anchors, face_landmarks_5_raw)[keep_indices]: + face_landmarks_5.append(face_landmark_raw_5 * [ ratio_width, ratio_height ]) + + return bounding_boxes, face_scores, face_landmarks_5 + + +def detect_with_scrfd(vision_frame : VisionFrame, face_detector_size : str) -> Tuple[List[BoundingBox], List[Score], List[FaceLandmark5]]: + bounding_boxes = [] + face_scores = [] + face_landmarks_5 = [] + feature_strides = [ 8, 16, 32 ] + feature_map_channel = 3 + anchor_total = 2 + face_detector_score = state_manager.get_item('face_detector_score') + face_detector_width, face_detector_height = unpack_resolution(face_detector_size) + temp_vision_frame = restrict_frame(vision_frame, (face_detector_width, face_detector_height)) + ratio_height = vision_frame.shape[0] / temp_vision_frame.shape[0] + ratio_width = vision_frame.shape[1] / temp_vision_frame.shape[1] + detect_vision_frame = prepare_detect_frame(temp_vision_frame, face_detector_size) + detect_vision_frame = normalize_detect_frame(detect_vision_frame, [ -1, 1 ]) + detection = forward_with_scrfd(detect_vision_frame) + + for index, feature_stride in enumerate(feature_strides): + keep_indices = numpy.where(detection[index] >= face_detector_score)[0] + + if numpy.any(keep_indices): + stride_height = face_detector_height // feature_stride + stride_width = face_detector_width // feature_stride + anchors = create_static_anchors(feature_stride, anchor_total, stride_height, 
stride_width) + bounding_boxes_raw = detection[index + feature_map_channel] * feature_stride + face_landmarks_5_raw = detection[index + feature_map_channel * 2] * feature_stride + + for bounding_box_raw in distance_to_bounding_box(anchors, bounding_boxes_raw)[keep_indices]: + bounding_boxes.append(numpy.array( + [ + bounding_box_raw[0] * ratio_width, + bounding_box_raw[1] * ratio_height, + bounding_box_raw[2] * ratio_width, + bounding_box_raw[3] * ratio_height + ])) + + for face_score_raw in detection[index][keep_indices]: + face_scores.append(face_score_raw[0]) + + for face_landmark_raw_5 in distance_to_face_landmark_5(anchors, face_landmarks_5_raw)[keep_indices]: + face_landmarks_5.append(face_landmark_raw_5 * [ ratio_width, ratio_height ]) + + return bounding_boxes, face_scores, face_landmarks_5 + + +def detect_with_yolo_face(vision_frame : VisionFrame, face_detector_size : str) -> Tuple[List[BoundingBox], List[Score], List[FaceLandmark5]]: + bounding_boxes = [] + face_scores = [] + face_landmarks_5 = [] + face_detector_score = state_manager.get_item('face_detector_score') + face_detector_width, face_detector_height = unpack_resolution(face_detector_size) + temp_vision_frame = restrict_frame(vision_frame, (face_detector_width, face_detector_height)) + ratio_height = vision_frame.shape[0] / temp_vision_frame.shape[0] + ratio_width = vision_frame.shape[1] / temp_vision_frame.shape[1] + detect_vision_frame = prepare_detect_frame(temp_vision_frame, face_detector_size) + detect_vision_frame = normalize_detect_frame(detect_vision_frame, [ 0, 1 ]) + detection = forward_with_yolo_face(detect_vision_frame) + detection = numpy.squeeze(detection).T + bounding_boxes_raw, face_scores_raw, face_landmarks_5_raw = numpy.split(detection, [ 4, 5 ], axis = 1) + keep_indices = numpy.where(face_scores_raw > face_detector_score)[0] + + if numpy.any(keep_indices): + bounding_boxes_raw, face_scores_raw, face_landmarks_5_raw = bounding_boxes_raw[keep_indices], face_scores_raw[keep_indices], face_landmarks_5_raw[keep_indices] + + for bounding_box_raw in bounding_boxes_raw: + bounding_boxes.append(numpy.array( + [ + (bounding_box_raw[0] - bounding_box_raw[2] / 2) * ratio_width, + (bounding_box_raw[1] - bounding_box_raw[3] / 2) * ratio_height, + (bounding_box_raw[0] + bounding_box_raw[2] / 2) * ratio_width, + (bounding_box_raw[1] + bounding_box_raw[3] / 2) * ratio_height + ])) + + face_scores = face_scores_raw.ravel().tolist() + face_landmarks_5_raw[:, 0::3] = (face_landmarks_5_raw[:, 0::3]) * ratio_width + face_landmarks_5_raw[:, 1::3] = (face_landmarks_5_raw[:, 1::3]) * ratio_height + + for face_landmark_raw_5 in face_landmarks_5_raw: + face_landmarks_5.append(numpy.array(face_landmark_raw_5.reshape(-1, 3)[:, :2])) + + return bounding_boxes, face_scores, face_landmarks_5 + + +def forward_with_retinaface(detect_vision_frame : VisionFrame) -> Detection: + face_detector = get_inference_pool().get('retinaface') + + with thread_semaphore(): + detection = face_detector.run(None, + { + 'input': detect_vision_frame + }) + + return detection + + +def forward_with_scrfd(detect_vision_frame : VisionFrame) -> Detection: + face_detector = get_inference_pool().get('scrfd') + + with thread_semaphore(): + detection = face_detector.run(None, + { + 'input': detect_vision_frame + }) + + return detection + + +def forward_with_yolo_face(detect_vision_frame : VisionFrame) -> Detection: + face_detector = get_inference_pool().get('yolo_face') + + with thread_semaphore(): + detection = face_detector.run(None, + { + 'input': 
detect_vision_frame + }) + + return detection + + +def prepare_detect_frame(temp_vision_frame : VisionFrame, face_detector_size : str) -> VisionFrame: + face_detector_width, face_detector_height = unpack_resolution(face_detector_size) + detect_vision_frame = numpy.zeros((face_detector_height, face_detector_width, 3)) + detect_vision_frame[:temp_vision_frame.shape[0], :temp_vision_frame.shape[1], :] = temp_vision_frame + detect_vision_frame = numpy.expand_dims(detect_vision_frame.transpose(2, 0, 1), axis = 0).astype(numpy.float32) + return detect_vision_frame + + +def normalize_detect_frame(detect_vision_frame : VisionFrame, normalize_range : Sequence[int]) -> VisionFrame: + if normalize_range == [ -1, 1 ]: + return (detect_vision_frame - 127.5) / 128.0 + if normalize_range == [ 0, 1 ]: + return detect_vision_frame / 255.0 + return detect_vision_frame diff --git a/facefusion/face_helper.py b/facefusion/face_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..cc37fe01f6f39e5e76a46837207d7cf04e3773a2 --- /dev/null +++ b/facefusion/face_helper.py @@ -0,0 +1,254 @@ +from functools import lru_cache +from typing import List, Sequence, Tuple + +import cv2 +import numpy +from cv2.typing import Size + +from facefusion.types import Anchors, Angle, BoundingBox, Distance, FaceDetectorModel, FaceLandmark5, FaceLandmark68, Mask, Matrix, Points, Scale, Score, Translation, VisionFrame, WarpTemplate, WarpTemplateSet + +WARP_TEMPLATE_SET : WarpTemplateSet =\ +{ + 'arcface_112_v1': numpy.array( + [ + [ 0.35473214, 0.45658929 ], + [ 0.64526786, 0.45658929 ], + [ 0.50000000, 0.61154464 ], + [ 0.37913393, 0.77687500 ], + [ 0.62086607, 0.77687500 ] + ]), + 'arcface_112_v2': numpy.array( + [ + [ 0.34191607, 0.46157411 ], + [ 0.65653393, 0.45983393 ], + [ 0.50022500, 0.64050536 ], + [ 0.37097589, 0.82469196 ], + [ 0.63151696, 0.82325089 ] + ]), + 'arcface_128': numpy.array( + [ + [ 0.36167656, 0.40387734 ], + [ 0.63696719, 0.40235469 ], + [ 0.50019687, 0.56044219 ], + [ 0.38710391, 0.72160547 ], + [ 0.61507734, 0.72034453 ] + ]), + 'dfl_whole_face': numpy.array( + [ + [ 0.35342266, 0.39285716 ], + [ 0.62797622, 0.39285716 ], + [ 0.48660713, 0.54017860 ], + [ 0.38839287, 0.68750011 ], + [ 0.59821427, 0.68750011 ] + ]), + 'ffhq_512': numpy.array( + [ + [ 0.37691676, 0.46864664 ], + [ 0.62285697, 0.46912813 ], + [ 0.50123859, 0.61331904 ], + [ 0.39308822, 0.72541100 ], + [ 0.61150205, 0.72490465 ] + ]), + 'mtcnn_512': numpy.array( + [ + [ 0.36562865, 0.46733799 ], + [ 0.63305391, 0.46585885 ], + [ 0.50019127, 0.61942959 ], + [ 0.39032951, 0.77598822 ], + [ 0.61178945, 0.77476328 ] + ]), + 'styleganex_384': numpy.array( + [ + [ 0.42353745, 0.52289879 ], + [ 0.57725008, 0.52319972 ], + [ 0.50123859, 0.61331904 ], + [ 0.43364461, 0.68337652 ], + [ 0.57015325, 0.68306005 ] + ]) +} + + +def estimate_matrix_by_face_landmark_5(face_landmark_5 : FaceLandmark5, warp_template : WarpTemplate, crop_size : Size) -> Matrix: + normed_warp_template = WARP_TEMPLATE_SET.get(warp_template) * crop_size + affine_matrix = cv2.estimateAffinePartial2D(face_landmark_5, normed_warp_template, method = cv2.RANSAC, ransacReprojThreshold = 100)[0] + return affine_matrix + + +def warp_face_by_face_landmark_5(temp_vision_frame : VisionFrame, face_landmark_5 : FaceLandmark5, warp_template : WarpTemplate, crop_size : Size) -> Tuple[VisionFrame, Matrix]: + affine_matrix = estimate_matrix_by_face_landmark_5(face_landmark_5, warp_template, crop_size) + crop_vision_frame = cv2.warpAffine(temp_vision_frame, affine_matrix, 
crop_size, borderMode = cv2.BORDER_REPLICATE, flags = cv2.INTER_AREA) + return crop_vision_frame, affine_matrix + + +def warp_face_by_bounding_box(temp_vision_frame : VisionFrame, bounding_box : BoundingBox, crop_size : Size) -> Tuple[VisionFrame, Matrix]: + source_points = numpy.array([ [ bounding_box[0], bounding_box[1] ], [bounding_box[2], bounding_box[1] ], [ bounding_box[0], bounding_box[3] ] ]).astype(numpy.float32) + target_points = numpy.array([ [ 0, 0 ], [ crop_size[0], 0 ], [ 0, crop_size[1] ] ]).astype(numpy.float32) + affine_matrix = cv2.getAffineTransform(source_points, target_points) + if bounding_box[2] - bounding_box[0] > crop_size[0] or bounding_box[3] - bounding_box[1] > crop_size[1]: + interpolation_method = cv2.INTER_AREA + else: + interpolation_method = cv2.INTER_LINEAR + crop_vision_frame = cv2.warpAffine(temp_vision_frame, affine_matrix, crop_size, flags = interpolation_method) + return crop_vision_frame, affine_matrix + + +def warp_face_by_translation(temp_vision_frame : VisionFrame, translation : Translation, scale : float, crop_size : Size) -> Tuple[VisionFrame, Matrix]: + affine_matrix = numpy.array([ [ scale, 0, translation[0] ], [ 0, scale, translation[1] ] ]) + crop_vision_frame = cv2.warpAffine(temp_vision_frame, affine_matrix, crop_size) + return crop_vision_frame, affine_matrix + + +def paste_back(temp_vision_frame : VisionFrame, crop_vision_frame : VisionFrame, crop_mask : Mask, affine_matrix : Matrix) -> VisionFrame: + paste_bounding_box, paste_matrix = calc_paste_area(temp_vision_frame, crop_vision_frame, affine_matrix) + x_min, y_min, x_max, y_max = paste_bounding_box + paste_width = x_max - x_min + paste_height = y_max - y_min + inverse_mask = cv2.warpAffine(crop_mask, paste_matrix, (paste_width, paste_height)).clip(0, 1) + inverse_mask = numpy.expand_dims(inverse_mask, axis = -1) + inverse_vision_frame = cv2.warpAffine(crop_vision_frame, paste_matrix, (paste_width, paste_height), borderMode = cv2.BORDER_REPLICATE) + temp_vision_frame = temp_vision_frame.copy() + paste_vision_frame = temp_vision_frame[y_min:y_max, x_min:x_max] + paste_vision_frame = paste_vision_frame * (1 - inverse_mask) + inverse_vision_frame * inverse_mask + temp_vision_frame[y_min:y_max, x_min:x_max] = paste_vision_frame.astype(temp_vision_frame.dtype) + return temp_vision_frame + + +def calc_paste_area(temp_vision_frame : VisionFrame, crop_vision_frame : VisionFrame, affine_matrix : Matrix) -> Tuple[BoundingBox, Matrix]: + temp_height, temp_width = temp_vision_frame.shape[:2] + crop_height, crop_width = crop_vision_frame.shape[:2] + inverse_matrix = cv2.invertAffineTransform(affine_matrix) + crop_points = numpy.array([ [ 0, 0 ], [ crop_width, 0 ], [ crop_width, crop_height ], [ 0, crop_height ] ]) + paste_region_points = transform_points(crop_points, inverse_matrix) + min_point = numpy.floor(paste_region_points.min(axis = 0)).astype(int) + max_point = numpy.ceil(paste_region_points.max(axis = 0)).astype(int) + x_min, y_min = numpy.clip(min_point, 0, [ temp_width, temp_height ]) + x_max, y_max = numpy.clip(max_point, 0, [ temp_width, temp_height ]) + paste_bounding_box = numpy.array([ x_min, y_min, x_max, y_max ]) + paste_matrix = inverse_matrix.copy() + paste_matrix[0, 2] -= x_min + paste_matrix[1, 2] -= y_min + return paste_bounding_box, paste_matrix + + +@lru_cache(maxsize = None) +def create_static_anchors(feature_stride : int, anchor_total : int, stride_height : int, stride_width : int) -> Anchors: + y, x = numpy.mgrid[:stride_height, :stride_width][::-1] + anchors = 
numpy.stack((y, x), axis = -1) + anchors = (anchors * feature_stride).reshape((-1, 2)) + anchors = numpy.stack([ anchors ] * anchor_total, axis = 1).reshape((-1, 2)) + return anchors + + +def create_rotated_matrix_and_size(angle : Angle, size : Size) -> Tuple[Matrix, Size]: + rotated_matrix = cv2.getRotationMatrix2D((size[0] / 2, size[1] / 2), angle, 1) + rotated_size = numpy.dot(numpy.abs(rotated_matrix[:, :2]), size) + rotated_matrix[:, -1] += (rotated_size - size) * 0.5 #type:ignore[misc] + rotated_size = int(rotated_size[0]), int(rotated_size[1]) + return rotated_matrix, rotated_size + + +def create_bounding_box(face_landmark_68 : FaceLandmark68) -> BoundingBox: + min_x, min_y = numpy.min(face_landmark_68, axis = 0) + max_x, max_y = numpy.max(face_landmark_68, axis = 0) + bounding_box = normalize_bounding_box(numpy.array([ min_x, min_y, max_x, max_y ])) + return bounding_box + + +def normalize_bounding_box(bounding_box : BoundingBox) -> BoundingBox: + x1, y1, x2, y2 = bounding_box + x1, x2 = sorted([ x1, x2 ]) + y1, y2 = sorted([ y1, y2 ]) + return numpy.array([ x1, y1, x2, y2 ]) + + +def transform_points(points : Points, matrix : Matrix) -> Points: + points = points.reshape(-1, 1, 2) + points = cv2.transform(points, matrix) #type:ignore[assignment] + points = points.reshape(-1, 2) + return points + + +def transform_bounding_box(bounding_box : BoundingBox, matrix : Matrix) -> BoundingBox: + points = numpy.array( + [ + [ bounding_box[0], bounding_box[1] ], + [ bounding_box[2], bounding_box[1] ], + [ bounding_box[2], bounding_box[3] ], + [ bounding_box[0], bounding_box[3] ] + ]) + points = transform_points(points, matrix) + x1, y1 = numpy.min(points, axis = 0) + x2, y2 = numpy.max(points, axis = 0) + return normalize_bounding_box(numpy.array([ x1, y1, x2, y2 ])) + + +def distance_to_bounding_box(points : Points, distance : Distance) -> BoundingBox: + x1 = points[:, 0] - distance[:, 0] + y1 = points[:, 1] - distance[:, 1] + x2 = points[:, 0] + distance[:, 2] + y2 = points[:, 1] + distance[:, 3] + bounding_box = numpy.column_stack([ x1, y1, x2, y2 ]) + return bounding_box + + +def distance_to_face_landmark_5(points : Points, distance : Distance) -> FaceLandmark5: + x = points[:, 0::2] + distance[:, 0::2] + y = points[:, 1::2] + distance[:, 1::2] + face_landmark_5 = numpy.stack((x, y), axis = -1) + return face_landmark_5 + + +def scale_face_landmark_5(face_landmark_5 : FaceLandmark5, scale : Scale) -> FaceLandmark5: + face_landmark_5_scale = face_landmark_5 - face_landmark_5[2] + face_landmark_5_scale *= scale + face_landmark_5_scale += face_landmark_5[2] + return face_landmark_5_scale + + +def convert_to_face_landmark_5(face_landmark_68 : FaceLandmark68) -> FaceLandmark5: + face_landmark_5 = numpy.array( + [ + numpy.mean(face_landmark_68[36:42], axis = 0), + numpy.mean(face_landmark_68[42:48], axis = 0), + face_landmark_68[30], + face_landmark_68[48], + face_landmark_68[54] + ]) + return face_landmark_5 + + +def estimate_face_angle(face_landmark_68 : FaceLandmark68) -> Angle: + x1, y1 = face_landmark_68[0] + x2, y2 = face_landmark_68[16] + theta = numpy.arctan2(y2 - y1, x2 - x1) + theta = numpy.degrees(theta) % 360 + angles = numpy.linspace(0, 360, 5) + index = numpy.argmin(numpy.abs(angles - theta)) + face_angle = int(angles[index] % 360) + return face_angle + + +def apply_nms(bounding_boxes : List[BoundingBox], scores : List[Score], score_threshold : float, nms_threshold : float) -> Sequence[int]: + normed_bounding_boxes = [ (x1, y1, x2 - x1, y2 - y1) for (x1, y1, x2, y2) in 
bounding_boxes ] + keep_indices = cv2.dnn.NMSBoxes(normed_bounding_boxes, scores, score_threshold = score_threshold, nms_threshold = nms_threshold) + return keep_indices + + +def get_nms_threshold(face_detector_model : FaceDetectorModel, face_detector_angles : List[Angle]) -> float: + if face_detector_model == 'many': + return 0.1 + if len(face_detector_angles) == 2: + return 0.3 + if len(face_detector_angles) == 3: + return 0.2 + if len(face_detector_angles) == 4: + return 0.1 + return 0.4 + + +def merge_matrix(matrices : List[Matrix]) -> Matrix: + merged_matrix = numpy.vstack([ matrices[0], [ 0, 0, 1 ] ]) + for matrix in matrices[1:]: + matrix = numpy.vstack([ matrix, [ 0, 0, 1 ] ]) + merged_matrix = numpy.dot(merged_matrix, matrix) + return merged_matrix[:2, :] diff --git a/facefusion/face_landmarker.py b/facefusion/face_landmarker.py new file mode 100644 index 0000000000000000000000000000000000000000..cab96277c5ae7e20fab29e097062b86840ddf085 --- /dev/null +++ b/facefusion/face_landmarker.py @@ -0,0 +1,222 @@ +from functools import lru_cache +from typing import Tuple + +import cv2 +import numpy + +from facefusion import inference_manager, state_manager +from facefusion.download import conditional_download_hashes, conditional_download_sources, resolve_download_url +from facefusion.face_helper import create_rotated_matrix_and_size, estimate_matrix_by_face_landmark_5, transform_points, warp_face_by_translation +from facefusion.filesystem import resolve_relative_path +from facefusion.thread_helper import conditional_thread_semaphore +from facefusion.types import Angle, BoundingBox, DownloadScope, DownloadSet, FaceLandmark5, FaceLandmark68, InferencePool, ModelSet, Prediction, Score, VisionFrame + + +@lru_cache(maxsize = None) +def create_static_model_set(download_scope : DownloadScope) -> ModelSet: + return\ + { + '2dfan4': + { + 'hashes': + { + '2dfan4': + { + 'url': resolve_download_url('models-3.0.0', '2dfan4.hash'), + 'path': resolve_relative_path('../.assets/models/2dfan4.hash') + } + }, + 'sources': + { + '2dfan4': + { + 'url': resolve_download_url('models-3.0.0', '2dfan4.onnx'), + 'path': resolve_relative_path('../.assets/models/2dfan4.onnx') + } + }, + 'size': (256, 256) + }, + 'peppa_wutz': + { + 'hashes': + { + 'peppa_wutz': + { + 'url': resolve_download_url('models-3.0.0', 'peppa_wutz.hash'), + 'path': resolve_relative_path('../.assets/models/peppa_wutz.hash') + } + }, + 'sources': + { + 'peppa_wutz': + { + 'url': resolve_download_url('models-3.0.0', 'peppa_wutz.onnx'), + 'path': resolve_relative_path('../.assets/models/peppa_wutz.onnx') + } + }, + 'size': (256, 256) + }, + 'fan_68_5': + { + 'hashes': + { + 'fan_68_5': + { + 'url': resolve_download_url('models-3.0.0', 'fan_68_5.hash'), + 'path': resolve_relative_path('../.assets/models/fan_68_5.hash') + } + }, + 'sources': + { + 'fan_68_5': + { + 'url': resolve_download_url('models-3.0.0', 'fan_68_5.onnx'), + 'path': resolve_relative_path('../.assets/models/fan_68_5.onnx') + } + } + } + } + + +def get_inference_pool() -> InferencePool: + model_names = [ state_manager.get_item('face_landmarker_model'), 'fan_68_5' ] + _, model_source_set = collect_model_downloads() + + return inference_manager.get_inference_pool(__name__, model_names, model_source_set) + + +def clear_inference_pool() -> None: + model_names = [ state_manager.get_item('face_landmarker_model'), 'fan_68_5' ] + inference_manager.clear_inference_pool(__name__, model_names) + + +def collect_model_downloads() -> Tuple[DownloadSet, DownloadSet]: + model_set = 
create_static_model_set('full') + model_hash_set =\ + { + 'fan_68_5': model_set.get('fan_68_5').get('hashes').get('fan_68_5') + } + model_source_set =\ + { + 'fan_68_5': model_set.get('fan_68_5').get('sources').get('fan_68_5') + } + + for face_landmarker_model in [ '2dfan4', 'peppa_wutz' ]: + if state_manager.get_item('face_landmarker_model') in [ 'many', face_landmarker_model ]: + model_hash_set[face_landmarker_model] = model_set.get(face_landmarker_model).get('hashes').get(face_landmarker_model) + model_source_set[face_landmarker_model] = model_set.get(face_landmarker_model).get('sources').get(face_landmarker_model) + + return model_hash_set, model_source_set + + +def pre_check() -> bool: + model_hash_set, model_source_set = collect_model_downloads() + + return conditional_download_hashes(model_hash_set) and conditional_download_sources(model_source_set) + + +def detect_face_landmark(vision_frame : VisionFrame, bounding_box : BoundingBox, face_angle : Angle) -> Tuple[FaceLandmark68, Score]: + face_landmark_2dfan4 = None + face_landmark_peppa_wutz = None + face_landmark_score_2dfan4 = 0.0 + face_landmark_score_peppa_wutz = 0.0 + + if state_manager.get_item('face_landmarker_model') in [ 'many', '2dfan4' ]: + face_landmark_2dfan4, face_landmark_score_2dfan4 = detect_with_2dfan4(vision_frame, bounding_box, face_angle) + + if state_manager.get_item('face_landmarker_model') in [ 'many', 'peppa_wutz' ]: + face_landmark_peppa_wutz, face_landmark_score_peppa_wutz = detect_with_peppa_wutz(vision_frame, bounding_box, face_angle) + + if face_landmark_score_2dfan4 > face_landmark_score_peppa_wutz - 0.2: + return face_landmark_2dfan4, face_landmark_score_2dfan4 + return face_landmark_peppa_wutz, face_landmark_score_peppa_wutz + + +def detect_with_2dfan4(temp_vision_frame: VisionFrame, bounding_box: BoundingBox, face_angle: Angle) -> Tuple[FaceLandmark68, Score]: + model_size = create_static_model_set('full').get('2dfan4').get('size') + scale = 195 / numpy.subtract(bounding_box[2:], bounding_box[:2]).max().clip(1, None) + translation = (model_size[0] - numpy.add(bounding_box[2:], bounding_box[:2]) * scale) * 0.5 + rotated_matrix, rotated_size = create_rotated_matrix_and_size(face_angle, model_size) + crop_vision_frame, affine_matrix = warp_face_by_translation(temp_vision_frame, translation, scale, model_size) + crop_vision_frame = cv2.warpAffine(crop_vision_frame, rotated_matrix, rotated_size) + crop_vision_frame = conditional_optimize_contrast(crop_vision_frame) + crop_vision_frame = crop_vision_frame.transpose(2, 0, 1).astype(numpy.float32) / 255.0 + face_landmark_68, face_heatmap = forward_with_2dfan4(crop_vision_frame) + face_landmark_68 = face_landmark_68[:, :, :2][0] / 64 * 256 + face_landmark_68 = transform_points(face_landmark_68, cv2.invertAffineTransform(rotated_matrix)) + face_landmark_68 = transform_points(face_landmark_68, cv2.invertAffineTransform(affine_matrix)) + face_landmark_score_68 = numpy.amax(face_heatmap, axis = (2, 3)) + face_landmark_score_68 = numpy.mean(face_landmark_score_68) + face_landmark_score_68 = numpy.interp(face_landmark_score_68, [ 0, 0.9 ], [ 0, 1 ]) + return face_landmark_68, face_landmark_score_68 + + +def detect_with_peppa_wutz(temp_vision_frame : VisionFrame, bounding_box : BoundingBox, face_angle : Angle) -> Tuple[FaceLandmark68, Score]: + model_size = create_static_model_set('full').get('peppa_wutz').get('size') + scale = 195 / numpy.subtract(bounding_box[2:], bounding_box[:2]).max().clip(1, None) + translation = (model_size[0] - numpy.add(bounding_box[2:], 
bounding_box[:2]) * scale) * 0.5 + rotated_matrix, rotated_size = create_rotated_matrix_and_size(face_angle, model_size) + crop_vision_frame, affine_matrix = warp_face_by_translation(temp_vision_frame, translation, scale, model_size) + crop_vision_frame = cv2.warpAffine(crop_vision_frame, rotated_matrix, rotated_size) + crop_vision_frame = conditional_optimize_contrast(crop_vision_frame) + crop_vision_frame = crop_vision_frame.transpose(2, 0, 1).astype(numpy.float32) / 255.0 + crop_vision_frame = numpy.expand_dims(crop_vision_frame, axis = 0) + prediction = forward_with_peppa_wutz(crop_vision_frame) + face_landmark_68 = prediction.reshape(-1, 3)[:, :2] / 64 * model_size[0] + face_landmark_68 = transform_points(face_landmark_68, cv2.invertAffineTransform(rotated_matrix)) + face_landmark_68 = transform_points(face_landmark_68, cv2.invertAffineTransform(affine_matrix)) + face_landmark_score_68 = prediction.reshape(-1, 3)[:, 2].mean() + face_landmark_score_68 = numpy.interp(face_landmark_score_68, [ 0, 0.95 ], [ 0, 1 ]) + return face_landmark_68, face_landmark_score_68 + + +def conditional_optimize_contrast(crop_vision_frame : VisionFrame) -> VisionFrame: + crop_vision_frame = cv2.cvtColor(crop_vision_frame, cv2.COLOR_RGB2Lab) + if numpy.mean(crop_vision_frame[:, :, 0]) < 30: #type:ignore[arg-type] + crop_vision_frame[:, :, 0] = cv2.createCLAHE(clipLimit = 2).apply(crop_vision_frame[:, :, 0]) + crop_vision_frame = cv2.cvtColor(crop_vision_frame, cv2.COLOR_Lab2RGB) + return crop_vision_frame + + +def estimate_face_landmark_68_5(face_landmark_5 : FaceLandmark5) -> FaceLandmark68: + affine_matrix = estimate_matrix_by_face_landmark_5(face_landmark_5, 'ffhq_512', (1, 1)) + face_landmark_5 = cv2.transform(face_landmark_5.reshape(1, -1, 2), affine_matrix).reshape(-1, 2) + face_landmark_68_5 = forward_fan_68_5(face_landmark_5) + face_landmark_68_5 = cv2.transform(face_landmark_68_5.reshape(1, -1, 2), cv2.invertAffineTransform(affine_matrix)).reshape(-1, 2) + return face_landmark_68_5 + + +def forward_with_2dfan4(crop_vision_frame : VisionFrame) -> Tuple[Prediction, Prediction]: + face_landmarker = get_inference_pool().get('2dfan4') + + with conditional_thread_semaphore(): + prediction = face_landmarker.run(None, + { + 'input': [ crop_vision_frame ] + }) + + return prediction + + +def forward_with_peppa_wutz(crop_vision_frame : VisionFrame) -> Prediction: + face_landmarker = get_inference_pool().get('peppa_wutz') + + with conditional_thread_semaphore(): + prediction = face_landmarker.run(None, + { + 'input': crop_vision_frame + })[0] + + return prediction + + +def forward_fan_68_5(face_landmark_5 : FaceLandmark5) -> FaceLandmark68: + face_landmarker = get_inference_pool().get('fan_68_5') + + with conditional_thread_semaphore(): + face_landmark_68_5 = face_landmarker.run(None, + { + 'input': [ face_landmark_5 ] + })[0][0] + + return face_landmark_68_5 diff --git a/facefusion/face_masker.py b/facefusion/face_masker.py new file mode 100644 index 0000000000000000000000000000000000000000..400838d1c31122ea9ade919b2bd0e6372b9e1078 --- /dev/null +++ b/facefusion/face_masker.py @@ -0,0 +1,240 @@ +from functools import lru_cache +from typing import List, Tuple + +import cv2 +import numpy + +import facefusion.choices +from facefusion import inference_manager, state_manager +from facefusion.download import conditional_download_hashes, conditional_download_sources, resolve_download_url +from facefusion.filesystem import resolve_relative_path +from facefusion.thread_helper import conditional_thread_semaphore +from 
facefusion.types import DownloadScope, DownloadSet, FaceLandmark68, FaceMaskArea, FaceMaskRegion, InferencePool, Mask, ModelSet, Padding, VisionFrame + + +@lru_cache(maxsize = None) +def create_static_model_set(download_scope : DownloadScope) -> ModelSet: + return\ + { + 'xseg_1': + { + 'hashes': + { + 'face_occluder': + { + 'url': resolve_download_url('models-3.1.0', 'xseg_1.hash'), + 'path': resolve_relative_path('../.assets/models/xseg_1.hash') + } + }, + 'sources': + { + 'face_occluder': + { + 'url': resolve_download_url('models-3.1.0', 'xseg_1.onnx'), + 'path': resolve_relative_path('../.assets/models/xseg_1.onnx') + } + }, + 'size': (256, 256) + }, + 'xseg_2': + { + 'hashes': + { + 'face_occluder': + { + 'url': resolve_download_url('models-3.1.0', 'xseg_2.hash'), + 'path': resolve_relative_path('../.assets/models/xseg_2.hash') + } + }, + 'sources': + { + 'face_occluder': + { + 'url': resolve_download_url('models-3.1.0', 'xseg_2.onnx'), + 'path': resolve_relative_path('../.assets/models/xseg_2.onnx') + } + }, + 'size': (256, 256) + }, + 'xseg_3': + { + 'hashes': + { + 'face_occluder': + { + 'url': resolve_download_url('models-3.2.0', 'xseg_3.hash'), + 'path': resolve_relative_path('../.assets/models/xseg_3.hash') + } + }, + 'sources': + { + 'face_occluder': + { + 'url': resolve_download_url('models-3.2.0', 'xseg_3.onnx'), + 'path': resolve_relative_path('../.assets/models/xseg_3.onnx') + } + }, + 'size': (256, 256) + }, + 'bisenet_resnet_18': + { + 'hashes': + { + 'face_parser': + { + 'url': resolve_download_url('models-3.1.0', 'bisenet_resnet_18.hash'), + 'path': resolve_relative_path('../.assets/models/bisenet_resnet_18.hash') + } + }, + 'sources': + { + 'face_parser': + { + 'url': resolve_download_url('models-3.1.0', 'bisenet_resnet_18.onnx'), + 'path': resolve_relative_path('../.assets/models/bisenet_resnet_18.onnx') + } + }, + 'size': (512, 512) + }, + 'bisenet_resnet_34': + { + 'hashes': + { + 'face_parser': + { + 'url': resolve_download_url('models-3.0.0', 'bisenet_resnet_34.hash'), + 'path': resolve_relative_path('../.assets/models/bisenet_resnet_34.hash') + } + }, + 'sources': + { + 'face_parser': + { + 'url': resolve_download_url('models-3.0.0', 'bisenet_resnet_34.onnx'), + 'path': resolve_relative_path('../.assets/models/bisenet_resnet_34.onnx') + } + }, + 'size': (512, 512) + } + } + + +def get_inference_pool() -> InferencePool: + model_names = [ state_manager.get_item('face_occluder_model'), state_manager.get_item('face_parser_model') ] + _, model_source_set = collect_model_downloads() + + return inference_manager.get_inference_pool(__name__, model_names, model_source_set) + + +def clear_inference_pool() -> None: + model_names = [ state_manager.get_item('face_occluder_model'), state_manager.get_item('face_parser_model') ] + inference_manager.clear_inference_pool(__name__, model_names) + + +def collect_model_downloads() -> Tuple[DownloadSet, DownloadSet]: + model_set = create_static_model_set('full') + model_hash_set = {} + model_source_set = {} + + for face_occluder_model in [ 'xseg_1', 'xseg_2', 'xseg_3' ]: + if state_manager.get_item('face_occluder_model') == face_occluder_model: + model_hash_set[face_occluder_model] = model_set.get(face_occluder_model).get('hashes').get('face_occluder') + model_source_set[face_occluder_model] = model_set.get(face_occluder_model).get('sources').get('face_occluder') + + for face_parser_model in [ 'bisenet_resnet_18', 'bisenet_resnet_34' ]: + if state_manager.get_item('face_parser_model') == face_parser_model: + 
model_hash_set[face_parser_model] = model_set.get(face_parser_model).get('hashes').get('face_parser') + model_source_set[face_parser_model] = model_set.get(face_parser_model).get('sources').get('face_parser') + + return model_hash_set, model_source_set + + +def pre_check() -> bool: + model_hash_set, model_source_set = collect_model_downloads() + + return conditional_download_hashes(model_hash_set) and conditional_download_sources(model_source_set) + + +def create_box_mask(crop_vision_frame : VisionFrame, face_mask_blur : float, face_mask_padding : Padding) -> Mask: + crop_size = crop_vision_frame.shape[:2][::-1] + blur_amount = int(crop_size[0] * 0.5 * face_mask_blur) + blur_area = max(blur_amount // 2, 1) + box_mask : Mask = numpy.ones(crop_size).astype(numpy.float32) + box_mask[:max(blur_area, int(crop_size[1] * face_mask_padding[0] / 100)), :] = 0 + box_mask[-max(blur_area, int(crop_size[1] * face_mask_padding[2] / 100)):, :] = 0 + box_mask[:, :max(blur_area, int(crop_size[0] * face_mask_padding[3] / 100))] = 0 + box_mask[:, -max(blur_area, int(crop_size[0] * face_mask_padding[1] / 100)):] = 0 + + if blur_amount > 0: + box_mask = cv2.GaussianBlur(box_mask, (0, 0), blur_amount * 0.25) + return box_mask + + +def create_occlusion_mask(crop_vision_frame : VisionFrame) -> Mask: + model_name = state_manager.get_item('face_occluder_model') + model_size = create_static_model_set('full').get(model_name).get('size') + prepare_vision_frame = cv2.resize(crop_vision_frame, model_size) + prepare_vision_frame = numpy.expand_dims(prepare_vision_frame, axis = 0).astype(numpy.float32) / 255.0 + prepare_vision_frame = prepare_vision_frame.transpose(0, 1, 2, 3) + occlusion_mask = forward_occlude_face(prepare_vision_frame) + occlusion_mask = occlusion_mask.transpose(0, 1, 2).clip(0, 1).astype(numpy.float32) + occlusion_mask = cv2.resize(occlusion_mask, crop_vision_frame.shape[:2][::-1]) + occlusion_mask = (cv2.GaussianBlur(occlusion_mask.clip(0, 1), (0, 0), 5).clip(0.5, 1) - 0.5) * 2 + return occlusion_mask + + +def create_area_mask(crop_vision_frame : VisionFrame, face_landmark_68 : FaceLandmark68, face_mask_areas : List[FaceMaskArea]) -> Mask: + crop_size = crop_vision_frame.shape[:2][::-1] + landmark_points = [] + + for face_mask_area in face_mask_areas: + if face_mask_area in facefusion.choices.face_mask_area_set: + landmark_points.extend(facefusion.choices.face_mask_area_set.get(face_mask_area)) + + convex_hull = cv2.convexHull(face_landmark_68[landmark_points].astype(numpy.int32)) + area_mask = numpy.zeros(crop_size).astype(numpy.float32) + cv2.fillConvexPoly(area_mask, convex_hull, 1.0) # type: ignore[call-overload] + area_mask = (cv2.GaussianBlur(area_mask.clip(0, 1), (0, 0), 5).clip(0.5, 1) - 0.5) * 2 + return area_mask + + +def create_region_mask(crop_vision_frame : VisionFrame, face_mask_regions : List[FaceMaskRegion]) -> Mask: + model_name = state_manager.get_item('face_parser_model') + model_size = create_static_model_set('full').get(model_name).get('size') + prepare_vision_frame = cv2.resize(crop_vision_frame, model_size) + prepare_vision_frame = prepare_vision_frame[:, :, ::-1].astype(numpy.float32) / 255.0 + prepare_vision_frame = numpy.subtract(prepare_vision_frame, numpy.array([ 0.485, 0.456, 0.406 ]).astype(numpy.float32)) + prepare_vision_frame = numpy.divide(prepare_vision_frame, numpy.array([ 0.229, 0.224, 0.225 ]).astype(numpy.float32)) + prepare_vision_frame = numpy.expand_dims(prepare_vision_frame, axis = 0) + prepare_vision_frame = prepare_vision_frame.transpose(0, 3, 1, 2) + 
region_mask = forward_parse_face(prepare_vision_frame) + region_mask = numpy.isin(region_mask.argmax(0), [ facefusion.choices.face_mask_region_set.get(face_mask_region) for face_mask_region in face_mask_regions ]) + region_mask = cv2.resize(region_mask.astype(numpy.float32), crop_vision_frame.shape[:2][::-1]) + region_mask = (cv2.GaussianBlur(region_mask.clip(0, 1), (0, 0), 5).clip(0.5, 1) - 0.5) * 2 + return region_mask + + +def forward_occlude_face(prepare_vision_frame : VisionFrame) -> Mask: + model_name = state_manager.get_item('face_occluder_model') + face_occluder = get_inference_pool().get(model_name) + + with conditional_thread_semaphore(): + occlusion_mask : Mask = face_occluder.run(None, + { + 'input': prepare_vision_frame + })[0][0] + + return occlusion_mask + + +def forward_parse_face(prepare_vision_frame : VisionFrame) -> Mask: + model_name = state_manager.get_item('face_parser_model') + face_parser = get_inference_pool().get(model_name) + + with conditional_thread_semaphore(): + region_mask : Mask = face_parser.run(None, + { + 'input': prepare_vision_frame + })[0][0] + + return region_mask diff --git a/facefusion/face_recognizer.py b/facefusion/face_recognizer.py new file mode 100644 index 0000000000000000000000000000000000000000..c2890263e12e19f9301ab806e1be4f55e5e7287c --- /dev/null +++ b/facefusion/face_recognizer.py @@ -0,0 +1,87 @@ +from functools import lru_cache +from typing import Tuple + +import numpy + +from facefusion import inference_manager +from facefusion.download import conditional_download_hashes, conditional_download_sources, resolve_download_url +from facefusion.face_helper import warp_face_by_face_landmark_5 +from facefusion.filesystem import resolve_relative_path +from facefusion.thread_helper import conditional_thread_semaphore +from facefusion.types import DownloadScope, Embedding, FaceLandmark5, InferencePool, ModelOptions, ModelSet, VisionFrame + + +@lru_cache(maxsize = None) +def create_static_model_set(download_scope : DownloadScope) -> ModelSet: + return\ + { + 'arcface': + { + 'hashes': + { + 'face_recognizer': + { + 'url': resolve_download_url('models-3.0.0', 'arcface_w600k_r50.hash'), + 'path': resolve_relative_path('../.assets/models/arcface_w600k_r50.hash') + } + }, + 'sources': + { + 'face_recognizer': + { + 'url': resolve_download_url('models-3.0.0', 'arcface_w600k_r50.onnx'), + 'path': resolve_relative_path('../.assets/models/arcface_w600k_r50.onnx') + } + }, + 'template': 'arcface_112_v2', + 'size': (112, 112) + } + } + + +def get_inference_pool() -> InferencePool: + model_names = [ 'arcface' ] + model_source_set = get_model_options().get('sources') + + return inference_manager.get_inference_pool(__name__, model_names, model_source_set) + + +def clear_inference_pool() -> None: + model_names = [ 'arcface' ] + inference_manager.clear_inference_pool(__name__, model_names) + + +def get_model_options() -> ModelOptions: + return create_static_model_set('full').get('arcface') + + +def pre_check() -> bool: + model_hash_set = get_model_options().get('hashes') + model_source_set = get_model_options().get('sources') + + return conditional_download_hashes(model_hash_set) and conditional_download_sources(model_source_set) + + +def calc_embedding(temp_vision_frame : VisionFrame, face_landmark_5 : FaceLandmark5) -> Tuple[Embedding, Embedding]: + model_template = get_model_options().get('template') + model_size = get_model_options().get('size') + crop_vision_frame, matrix = warp_face_by_face_landmark_5(temp_vision_frame, face_landmark_5, model_template, 
model_size) + crop_vision_frame = crop_vision_frame / 127.5 - 1 + crop_vision_frame = crop_vision_frame[:, :, ::-1].transpose(2, 0, 1).astype(numpy.float32) + crop_vision_frame = numpy.expand_dims(crop_vision_frame, axis = 0) + embedding = forward(crop_vision_frame) + embedding = embedding.ravel() + normed_embedding = embedding / numpy.linalg.norm(embedding) + return embedding, normed_embedding + + +def forward(crop_vision_frame : VisionFrame) -> Embedding: + face_recognizer = get_inference_pool().get('face_recognizer') + + with conditional_thread_semaphore(): + embedding = face_recognizer.run(None, + { + 'input': crop_vision_frame + })[0] + + return embedding diff --git a/facefusion/face_selector.py b/facefusion/face_selector.py new file mode 100644 index 0000000000000000000000000000000000000000..0c0341eabc2205d4c67f187eeefdfc865f174b10 --- /dev/null +++ b/facefusion/face_selector.py @@ -0,0 +1,108 @@ +from typing import List + +import numpy + +from facefusion import state_manager +from facefusion.types import Face, FaceSelectorOrder, FaceSet, Gender, Race, Score + + +def find_similar_faces(faces : List[Face], reference_faces : FaceSet, face_distance : float) -> List[Face]: + similar_faces : List[Face] = [] + + if faces and reference_faces: + for reference_set in reference_faces: + if not similar_faces: + for reference_face in reference_faces[reference_set]: + for face in faces: + if compare_faces(face, reference_face, face_distance): + similar_faces.append(face) + return similar_faces + + +def compare_faces(face : Face, reference_face : Face, face_distance : float) -> bool: + current_face_distance = calc_face_distance(face, reference_face) + current_face_distance = float(numpy.interp(current_face_distance, [ 0, 2 ], [ 0, 1 ])) + return current_face_distance < face_distance + + +def calc_face_distance(face : Face, reference_face : Face) -> float: + if hasattr(face, 'normed_embedding') and hasattr(reference_face, 'normed_embedding'): + return 1 - numpy.dot(face.normed_embedding, reference_face.normed_embedding) + return 0 + + +def sort_and_filter_faces(faces : List[Face]) -> List[Face]: + if faces: + if state_manager.get_item('face_selector_order'): + faces = sort_faces_by_order(faces, state_manager.get_item('face_selector_order')) + if state_manager.get_item('face_selector_gender'): + faces = filter_faces_by_gender(faces, state_manager.get_item('face_selector_gender')) + if state_manager.get_item('face_selector_race'): + faces = filter_faces_by_race(faces, state_manager.get_item('face_selector_race')) + if state_manager.get_item('face_selector_age_start') or state_manager.get_item('face_selector_age_end'): + faces = filter_faces_by_age(faces, state_manager.get_item('face_selector_age_start'), state_manager.get_item('face_selector_age_end')) + return faces + + +def sort_faces_by_order(faces : List[Face], order : FaceSelectorOrder) -> List[Face]: + if order == 'left-right': + return sorted(faces, key = get_bounding_box_left) + if order == 'right-left': + return sorted(faces, key = get_bounding_box_left, reverse = True) + if order == 'top-bottom': + return sorted(faces, key = get_bounding_box_top) + if order == 'bottom-top': + return sorted(faces, key = get_bounding_box_top, reverse = True) + if order == 'small-large': + return sorted(faces, key = get_bounding_box_area) + if order == 'large-small': + return sorted(faces, key = get_bounding_box_area, reverse = True) + if order == 'best-worst': + return sorted(faces, key = get_face_detector_score, reverse = True) + if order == 'worst-best': + 
return sorted(faces, key = get_face_detector_score) + return faces + + +def get_bounding_box_left(face : Face) -> float: + return face.bounding_box[0] + + +def get_bounding_box_top(face : Face) -> float: + return face.bounding_box[1] + + +def get_bounding_box_area(face : Face) -> float: + return (face.bounding_box[2] - face.bounding_box[0]) * (face.bounding_box[3] - face.bounding_box[1]) + + +def get_face_detector_score(face : Face) -> Score: + return face.score_set.get('detector') + + +def filter_faces_by_gender(faces : List[Face], gender : Gender) -> List[Face]: + filter_faces = [] + + for face in faces: + if face.gender == gender: + filter_faces.append(face) + return filter_faces + + +def filter_faces_by_age(faces : List[Face], face_selector_age_start : int, face_selector_age_end : int) -> List[Face]: + filter_faces = [] + age = range(face_selector_age_start, face_selector_age_end) + + for face in faces: + if set(face.age) & set(age): + filter_faces.append(face) + return filter_faces + + +def filter_faces_by_race(faces : List[Face], race : Race) -> List[Face]: + filter_faces = [] + + for face in faces: + if face.race == race: + filter_faces.append(face) + return filter_faces diff --git a/facefusion/face_store.py b/facefusion/face_store.py new file mode 100644 index 0000000000000000000000000000000000000000..bd7b2c5e6c07d620945726254052ecb09110ad3d --- /dev/null +++ b/facefusion/face_store.py @@ -0,0 +1,43 @@ +from typing import List, Optional + +from facefusion.hash_helper import create_hash +from facefusion.types import Face, FaceSet, FaceStore, VisionFrame + +FACE_STORE : FaceStore =\ +{ + 'static_faces': {}, + 'reference_faces': {} +} + + +def get_face_store() -> FaceStore: + return FACE_STORE + + +def get_static_faces(vision_frame : VisionFrame) -> Optional[List[Face]]: + vision_hash = create_hash(vision_frame.tobytes()) + return FACE_STORE.get('static_faces').get(vision_hash) + + +def set_static_faces(vision_frame : VisionFrame, faces : List[Face]) -> None: + vision_hash = create_hash(vision_frame.tobytes()) + if vision_hash: + FACE_STORE['static_faces'][vision_hash] = faces + + +def clear_static_faces() -> None: + FACE_STORE['static_faces'].clear() + + +def get_reference_faces() -> Optional[FaceSet]: + return FACE_STORE.get('reference_faces') + + +def append_reference_face(name : str, face : Face) -> None: + if name not in FACE_STORE.get('reference_faces'): + FACE_STORE['reference_faces'][name] = [] + FACE_STORE['reference_faces'][name].append(face) + + +def clear_reference_faces() -> None: + FACE_STORE['reference_faces'].clear() diff --git a/facefusion/ffmpeg.py b/facefusion/ffmpeg.py new file mode 100644 index 0000000000000000000000000000000000000000..bbfd026f263c88983febcf78d010559eff40b7d0 --- /dev/null +++ b/facefusion/ffmpeg.py @@ -0,0 +1,286 @@ +import os +import subprocess +import tempfile +from functools import partial +from typing import List, Optional, cast + +from tqdm import tqdm + +import facefusion.choices +from facefusion import ffmpeg_builder, logger, process_manager, state_manager, wording +from facefusion.filesystem import get_file_format, remove_file +from facefusion.temp_helper import get_temp_file_path, get_temp_frames_pattern +from facefusion.types import AudioBuffer, AudioEncoder, Commands, EncoderSet, Fps, UpdateProgress, VideoEncoder, VideoFormat +from facefusion.vision import detect_video_duration, detect_video_fps, predict_video_frame_total + + +def run_ffmpeg_with_progress(commands : Commands, update_progress : UpdateProgress) -> 
subprocess.Popen[bytes]: + log_level = state_manager.get_item('log_level') + commands.extend(ffmpeg_builder.set_progress()) + commands.extend(ffmpeg_builder.cast_stream()) + commands = ffmpeg_builder.run(commands) + process = subprocess.Popen(commands, stderr = subprocess.PIPE, stdout = subprocess.PIPE) + + while process_manager.is_processing(): + try: + + while __line__ := process.stdout.readline().decode().lower(): + if 'frame=' in __line__: + _, frame_number = __line__.split('frame=') + update_progress(int(frame_number)) + + if log_level == 'debug': + log_debug(process) + process.wait(timeout = 0.5) + except subprocess.TimeoutExpired: + continue + return process + + if process_manager.is_stopping(): + process.terminate() + return process + + +def update_progress(progress : tqdm, frame_number : int) -> None: + progress.update(frame_number - progress.n) + + +def run_ffmpeg(commands : Commands) -> subprocess.Popen[bytes]: + log_level = state_manager.get_item('log_level') + commands = ffmpeg_builder.run(commands) + process = subprocess.Popen(commands, stderr = subprocess.PIPE, stdout = subprocess.PIPE) + + while process_manager.is_processing(): + try: + if log_level == 'debug': + log_debug(process) + process.wait(timeout = 0.5) + except subprocess.TimeoutExpired: + continue + return process + + if process_manager.is_stopping(): + process.terminate() + return process + + +def open_ffmpeg(commands : Commands) -> subprocess.Popen[bytes]: + commands = ffmpeg_builder.run(commands) + return subprocess.Popen(commands, stdin = subprocess.PIPE, stdout = subprocess.PIPE) + + +def log_debug(process : subprocess.Popen[bytes]) -> None: + _, stderr = process.communicate() + errors = stderr.decode().split(os.linesep) + + for error in errors: + if error.strip(): + logger.debug(error.strip(), __name__) + + +def get_available_encoder_set() -> EncoderSet: + available_encoder_set : EncoderSet =\ + { + 'audio': [], + 'video': [] + } + commands = ffmpeg_builder.chain( + ffmpeg_builder.get_encoders() + ) + process = run_ffmpeg(commands) + + while line := process.stdout.readline().decode().lower(): + if line.startswith(' a'): + audio_encoder = line.split()[1] + + if audio_encoder in facefusion.choices.output_audio_encoders: + index = facefusion.choices.output_audio_encoders.index(audio_encoder) #type:ignore[arg-type] + available_encoder_set['audio'].insert(index, audio_encoder) #type:ignore[arg-type] + if line.startswith(' v'): + video_encoder = line.split()[1] + + if video_encoder in facefusion.choices.output_video_encoders: + index = facefusion.choices.output_video_encoders.index(video_encoder) #type:ignore[arg-type] + available_encoder_set['video'].insert(index, video_encoder) #type:ignore[arg-type] + + return available_encoder_set + + +def extract_frames(target_path : str, temp_video_resolution : str, temp_video_fps : Fps, trim_frame_start : int, trim_frame_end : int) -> bool: + extract_frame_total = predict_video_frame_total(target_path, temp_video_fps, trim_frame_start, trim_frame_end) + temp_frames_pattern = get_temp_frames_pattern(target_path, '%08d') + commands = ffmpeg_builder.chain( + ffmpeg_builder.set_input(target_path), + ffmpeg_builder.set_media_resolution(temp_video_resolution), + ffmpeg_builder.set_frame_quality(0), + ffmpeg_builder.select_frame_range(trim_frame_start, trim_frame_end, temp_video_fps), + ffmpeg_builder.prevent_frame_drop(), + ffmpeg_builder.set_output(temp_frames_pattern) + ) + + with tqdm(total = extract_frame_total, desc = wording.get('extracting'), unit = 'frame', ascii = ' =', 
disable = state_manager.get_item('log_level') in [ 'warn', 'error' ]) as progress: + process = run_ffmpeg_with_progress(commands, partial(update_progress, progress)) + return process.returncode == 0 + + +def copy_image(target_path : str, temp_image_resolution : str) -> bool: + temp_image_path = get_temp_file_path(target_path) + commands = ffmpeg_builder.chain( + ffmpeg_builder.set_input(target_path), + ffmpeg_builder.set_media_resolution(temp_image_resolution), + ffmpeg_builder.set_image_quality(target_path, 100), + ffmpeg_builder.force_output(temp_image_path) + ) + return run_ffmpeg(commands).returncode == 0 + + +def finalize_image(target_path : str, output_path : str, output_image_resolution : str) -> bool: + output_image_quality = state_manager.get_item('output_image_quality') + temp_image_path = get_temp_file_path(target_path) + commands = ffmpeg_builder.chain( + ffmpeg_builder.set_input(temp_image_path), + ffmpeg_builder.set_media_resolution(output_image_resolution), + ffmpeg_builder.set_image_quality(target_path, output_image_quality), + ffmpeg_builder.force_output(output_path) + ) + return run_ffmpeg(commands).returncode == 0 + + +def read_audio_buffer(target_path : str, audio_sample_rate : int, audio_sample_size : int, audio_channel_total : int) -> Optional[AudioBuffer]: + commands = ffmpeg_builder.chain( + ffmpeg_builder.set_input(target_path), + ffmpeg_builder.ignore_video_stream(), + ffmpeg_builder.set_audio_sample_rate(audio_sample_rate), + ffmpeg_builder.set_audio_sample_size(audio_sample_size), + ffmpeg_builder.set_audio_channel_total(audio_channel_total), + ffmpeg_builder.cast_stream() + ) + + process = open_ffmpeg(commands) + audio_buffer, _ = process.communicate() + if process.returncode == 0: + return audio_buffer + return None + + +def restore_audio(target_path : str, output_path : str, trim_frame_start : int, trim_frame_end : int) -> bool: + output_audio_encoder = state_manager.get_item('output_audio_encoder') + output_audio_quality = state_manager.get_item('output_audio_quality') + output_audio_volume = state_manager.get_item('output_audio_volume') + target_video_fps = detect_video_fps(target_path) + temp_video_path = get_temp_file_path(target_path) + temp_video_format = cast(VideoFormat, get_file_format(temp_video_path)) + temp_video_duration = detect_video_duration(temp_video_path) + + output_audio_encoder = fix_audio_encoder(temp_video_format, output_audio_encoder) + commands = ffmpeg_builder.chain( + ffmpeg_builder.set_input(temp_video_path), + ffmpeg_builder.select_media_range(trim_frame_start, trim_frame_end, target_video_fps), + ffmpeg_builder.set_input(target_path), + ffmpeg_builder.copy_video_encoder(), + ffmpeg_builder.set_audio_encoder(output_audio_encoder), + ffmpeg_builder.set_audio_quality(output_audio_encoder, output_audio_quality), + ffmpeg_builder.set_audio_volume(output_audio_volume), + ffmpeg_builder.select_media_stream('0:v:0'), + ffmpeg_builder.select_media_stream('1:a:0'), + ffmpeg_builder.set_video_duration(temp_video_duration), + ffmpeg_builder.force_output(output_path) + ) + return run_ffmpeg(commands).returncode == 0 + + +def replace_audio(target_path : str, audio_path : str, output_path : str) -> bool: + output_audio_encoder = state_manager.get_item('output_audio_encoder') + output_audio_quality = state_manager.get_item('output_audio_quality') + output_audio_volume = state_manager.get_item('output_audio_volume') + temp_video_path = get_temp_file_path(target_path) + temp_video_format = cast(VideoFormat, get_file_format(temp_video_path)) + 
temp_video_duration = detect_video_duration(temp_video_path) + + output_audio_encoder = fix_audio_encoder(temp_video_format, output_audio_encoder) + commands = ffmpeg_builder.chain( + ffmpeg_builder.set_input(temp_video_path), + ffmpeg_builder.set_input(audio_path), + ffmpeg_builder.copy_video_encoder(), + ffmpeg_builder.set_audio_encoder(output_audio_encoder), + ffmpeg_builder.set_audio_quality(output_audio_encoder, output_audio_quality), + ffmpeg_builder.set_audio_volume(output_audio_volume), + ffmpeg_builder.set_video_duration(temp_video_duration), + ffmpeg_builder.force_output(output_path) + ) + return run_ffmpeg(commands).returncode == 0 + + +def merge_video(target_path : str, temp_video_fps : Fps, output_video_resolution : str, output_video_fps : Fps, trim_frame_start : int, trim_frame_end : int) -> bool: + output_video_encoder = state_manager.get_item('output_video_encoder') + output_video_quality = state_manager.get_item('output_video_quality') + output_video_preset = state_manager.get_item('output_video_preset') + merge_frame_total = predict_video_frame_total(target_path, output_video_fps, trim_frame_start, trim_frame_end) + temp_video_path = get_temp_file_path(target_path) + temp_video_format = cast(VideoFormat, get_file_format(temp_video_path)) + temp_frames_pattern = get_temp_frames_pattern(target_path, '%08d') + + output_video_encoder = fix_video_encoder(temp_video_format, output_video_encoder) + commands = ffmpeg_builder.chain( + ffmpeg_builder.set_input_fps(temp_video_fps), + ffmpeg_builder.set_input(temp_frames_pattern), + ffmpeg_builder.set_media_resolution(output_video_resolution), + ffmpeg_builder.set_video_encoder(output_video_encoder), + ffmpeg_builder.set_video_quality(output_video_encoder, output_video_quality), + ffmpeg_builder.set_video_preset(output_video_encoder, output_video_preset), + ffmpeg_builder.set_video_fps(output_video_fps), + ffmpeg_builder.set_pixel_format(output_video_encoder), + ffmpeg_builder.set_video_colorspace('bt709'), + ffmpeg_builder.force_output(temp_video_path) + ) + + with tqdm(total = merge_frame_total, desc = wording.get('merging'), unit = 'frame', ascii = ' =', disable = state_manager.get_item('log_level') in [ 'warn', 'error' ]) as progress: + process = run_ffmpeg_with_progress(commands, partial(update_progress, progress)) + return process.returncode == 0 + + +def concat_video(output_path : str, temp_output_paths : List[str]) -> bool: + concat_video_path = tempfile.mktemp() + + with open(concat_video_path, 'w') as concat_video_file: + for temp_output_path in temp_output_paths: + concat_video_file.write('file \'' + os.path.abspath(temp_output_path) + '\'' + os.linesep) + concat_video_file.flush() + concat_video_file.close() + + output_path = os.path.abspath(output_path) + commands = ffmpeg_builder.chain( + ffmpeg_builder.unsafe_concat(), + ffmpeg_builder.set_input(concat_video_file.name), + ffmpeg_builder.copy_video_encoder(), + ffmpeg_builder.copy_audio_encoder(), + ffmpeg_builder.force_output(output_path) + ) + process = run_ffmpeg(commands) + process.communicate() + remove_file(concat_video_path) + return process.returncode == 0 + + +def fix_audio_encoder(video_format : VideoFormat, audio_encoder : AudioEncoder) -> AudioEncoder: + if video_format == 'avi' and audio_encoder == 'libopus': + return 'aac' + if video_format == 'm4v': + return 'aac' + if video_format == 'mov' and audio_encoder in [ 'flac', 'libopus' ]: + return 'aac' + if video_format == 'webm': + return 'libopus' + return audio_encoder + + +def 
fix_video_encoder(video_format : VideoFormat, video_encoder : VideoEncoder) -> VideoEncoder: + if video_format == 'm4v': + return 'libx264' + if video_format in [ 'mkv', 'mp4' ] and video_encoder == 'rawvideo': + return 'libx264' + if video_format == 'mov' and video_encoder == 'libvpx-vp9': + return 'libx264' + if video_format == 'webm': + return 'libvpx-vp9' + return video_encoder diff --git a/facefusion/ffmpeg_builder.py b/facefusion/ffmpeg_builder.py new file mode 100644 index 0000000000000000000000000000000000000000..8eb8e91f77b553d02628c173f090a7fb7d37bd94 --- /dev/null +++ b/facefusion/ffmpeg_builder.py @@ -0,0 +1,248 @@ +import itertools +import shutil +from typing import Optional + +import numpy + +from facefusion.filesystem import get_file_format +from facefusion.types import AudioEncoder, Commands, Duration, Fps, StreamMode, VideoEncoder, VideoPreset + + +def run(commands : Commands) -> Commands: + return [ shutil.which('ffmpeg'), '-loglevel', 'error' ] + commands + + +def chain(*commands : Commands) -> Commands: + return list(itertools.chain(*commands)) + + +def get_encoders() -> Commands: + return [ '-encoders' ] + + +def set_hardware_accelerator(value : str) -> Commands: + return [ '-hwaccel', value ] + + +def set_progress() -> Commands: + return [ '-progress' ] + + +def set_input(input_path : str) -> Commands: + return [ '-i', input_path ] + + +def set_input_fps(input_fps : Fps) -> Commands: + return [ '-r', str(input_fps)] + + +def set_output(output_path : str) -> Commands: + return [ output_path ] + + +def force_output(output_path : str) -> Commands: + return [ '-y', output_path ] + + +def cast_stream() -> Commands: + return [ '-' ] + + +def set_stream_mode(stream_mode : StreamMode) -> Commands: + if stream_mode == 'udp': + return [ '-f', 'mpegts' ] + if stream_mode == 'v4l2': + return [ '-f', 'v4l2' ] + return [] + + +def set_stream_quality(stream_quality : int) -> Commands: + return [ '-b:v', str(stream_quality) + 'k' ] + + +def unsafe_concat() -> Commands: + return [ '-f', 'concat', '-safe', '0' ] + + +def set_pixel_format(video_encoder : VideoEncoder) -> Commands: + if video_encoder == 'rawvideo': + return [ '-pix_fmt', 'rgb24' ] + return [ '-pix_fmt', 'yuv420p' ] + + +def set_frame_quality(frame_quality : int) -> Commands: + return [ '-q:v', str(frame_quality) ] + + +def select_frame_range(frame_start : int, frame_end : int, video_fps : Fps) -> Commands: + if isinstance(frame_start, int) and isinstance(frame_end, int): + return [ '-vf', 'trim=start_frame=' + str(frame_start) + ':end_frame=' + str(frame_end) + ',fps=' + str(video_fps) ] + if isinstance(frame_start, int): + return [ '-vf', 'trim=start_frame=' + str(frame_start) + ',fps=' + str(video_fps) ] + if isinstance(frame_end, int): + return [ '-vf', 'trim=end_frame=' + str(frame_end) + ',fps=' + str(video_fps) ] + return [ '-vf', 'fps=' + str(video_fps) ] + + +def prevent_frame_drop() -> Commands: + return [ '-vsync', '0' ] + + +def select_media_range(frame_start : int, frame_end : int, media_fps : Fps) -> Commands: + commands = [] + + if isinstance(frame_start, int): + commands.extend([ '-ss', str(frame_start / media_fps) ]) + if isinstance(frame_end, int): + commands.extend([ '-to', str(frame_end / media_fps) ]) + return commands + + +def select_media_stream(media_stream : str) -> Commands: + return [ '-map', media_stream ] + + +def set_media_resolution(video_resolution : str) -> Commands: + return [ '-s', video_resolution ] + + +def set_image_quality(image_path : str, image_quality : int) -> Commands: + if 
get_file_format(image_path) == 'webp': + image_compression = image_quality + else: + image_compression = round(31 - (image_quality * 0.31)) + return [ '-q:v', str(image_compression) ] + + +def set_audio_encoder(audio_codec : str) -> Commands: + return [ '-c:a', audio_codec ] + + +def copy_audio_encoder() -> Commands: + return set_audio_encoder('copy') + + +def set_audio_sample_rate(audio_sample_rate : int) -> Commands: + return [ '-ar', str(audio_sample_rate) ] + + +def set_audio_sample_size(audio_sample_size : int) -> Commands: + if audio_sample_size == 16: + return [ '-f', 's16le' ] + if audio_sample_size == 32: + return [ '-f', 's32le' ] + return [] + + +def set_audio_channel_total(audio_channel_total : int) -> Commands: + return [ '-ac', str(audio_channel_total) ] + + +def set_audio_quality(audio_encoder : AudioEncoder, audio_quality : int) -> Commands: + if audio_encoder == 'aac': + audio_compression = round(numpy.interp(audio_quality, [ 0, 100 ], [ 0.1, 2.0 ]), 1) + return [ '-q:a', str(audio_compression) ] + if audio_encoder == 'libmp3lame': + audio_compression = round(numpy.interp(audio_quality, [ 0, 100 ], [ 9, 0 ])) + return [ '-q:a', str(audio_compression) ] + if audio_encoder == 'libopus': + audio_bit_rate = round(numpy.interp(audio_quality, [ 0, 100 ], [ 64, 256 ])) + return [ '-b:a', str(audio_bit_rate) + 'k' ] + if audio_encoder == 'libvorbis': + audio_compression = round(numpy.interp(audio_quality, [ 0, 100 ], [ -1, 10 ]), 1) + return [ '-q:a', str(audio_compression) ] + return [] + + +def set_audio_volume(audio_volume : int) -> Commands: + return [ '-filter:a', 'volume=' + str(audio_volume / 100) ] + + +def set_video_encoder(video_encoder : str) -> Commands: + return [ '-c:v', video_encoder ] + + +def copy_video_encoder() -> Commands: + return set_video_encoder('copy') + + +def set_video_quality(video_encoder : VideoEncoder, video_quality : int) -> Commands: + if video_encoder in [ 'libx264', 'libx265' ]: + video_compression = round(numpy.interp(video_quality, [ 0, 100 ], [ 51, 0 ])) + return [ '-crf', str(video_compression) ] + if video_encoder == 'libvpx-vp9': + video_compression = round(numpy.interp(video_quality, [ 0, 100 ], [ 63, 0 ])) + return [ '-crf', str(video_compression) ] + if video_encoder in [ 'h264_nvenc', 'hevc_nvenc' ]: + video_compression = round(numpy.interp(video_quality, [ 0, 100 ], [ 51, 0 ])) + return [ '-cq', str(video_compression) ] + if video_encoder in [ 'h264_amf', 'hevc_amf' ]: + video_compression = round(numpy.interp(video_quality, [ 0, 100 ], [ 51, 0 ])) + return [ '-qp_i', str(video_compression), '-qp_p', str(video_compression), '-qp_b', str(video_compression) ] + if video_encoder in [ 'h264_qsv', 'hevc_qsv' ]: + video_compression = round(numpy.interp(video_quality, [ 0, 100 ], [ 51, 0 ])) + return [ '-qp', str(video_compression) ] + if video_encoder in [ 'h264_videotoolbox', 'hevc_videotoolbox' ]: + video_bit_rate = round(numpy.interp(video_quality, [ 0, 100 ], [ 1024, 50512 ])) + return [ '-b:v', str(video_bit_rate) + 'k' ] + return [] + + +def set_video_preset(video_encoder : VideoEncoder, video_preset : VideoPreset) -> Commands: + if video_encoder in [ 'libx264', 'libx265' ]: + return [ '-preset', video_preset ] + if video_encoder in [ 'h264_nvenc', 'hevc_nvenc' ]: + return [ '-preset', map_nvenc_preset(video_preset) ] + if video_encoder in [ 'h264_amf', 'hevc_amf' ]: + return [ '-quality', map_amf_preset(video_preset) ] + if video_encoder in [ 'h264_qsv', 'hevc_qsv' ]: + return [ '-preset', map_qsv_preset(video_preset) ] + return [] + + 
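# Illustrative usage sketch (not part of the original patch): the helpers in this module
# each return a small Commands list, and callers such as facefusion/ffmpeg.py combine them
# with chain() before handing the result to run(). File names below are placeholders and
# the ffmpeg path depends on what shutil.which('ffmpeg') resolves to on the host.
#
#   commands = chain(
#       set_input('target.mp4'),
#       set_video_encoder('libx264'),
#       set_video_quality('libx264', 80),    # quality 80 maps to CRF 10 via numpy.interp
#       set_video_preset('libx264', 'fast'),
#       force_output('output.mp4')
#   )
#   run(commands)
#   # -> [<ffmpeg path>, '-loglevel', 'error', '-i', 'target.mp4',
#   #     '-c:v', 'libx264', '-crf', '10', '-preset', 'fast', '-y', 'output.mp4']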
+def set_video_colorspace(video_colorspace : str) -> Commands: + return [ '-colorspace', video_colorspace ] + + +def set_video_fps(video_fps : Fps) -> Commands: + return [ '-vf', 'framerate=fps=' + str(video_fps) ] + + +def set_video_duration(video_duration : Duration) -> Commands: + return [ '-t', str(video_duration) ] + + +def capture_video() -> Commands: + return [ '-f', 'rawvideo', '-pix_fmt', 'rgb24' ] + + +def ignore_video_stream() -> Commands: + return [ '-vn' ] + + +def map_nvenc_preset(video_preset : VideoPreset) -> Optional[str]: + if video_preset in [ 'ultrafast', 'superfast', 'veryfast', 'faster', 'fast' ]: + return 'fast' + if video_preset == 'medium': + return 'medium' + if video_preset in [ 'slow', 'slower', 'veryslow' ]: + return 'slow' + return None + + +def map_amf_preset(video_preset : VideoPreset) -> Optional[str]: + if video_preset in [ 'ultrafast', 'superfast', 'veryfast' ]: + return 'speed' + if video_preset in [ 'faster', 'fast', 'medium' ]: + return 'balanced' + if video_preset in [ 'slow', 'slower', 'veryslow' ]: + return 'quality' + return None + + +def map_qsv_preset(video_preset : VideoPreset) -> Optional[str]: + if video_preset in [ 'ultrafast', 'superfast', 'veryfast' ]: + return 'veryfast' + if video_preset in [ 'faster', 'fast', 'medium', 'slow', 'slower', 'veryslow' ]: + return video_preset + return None diff --git a/facefusion/filesystem.py b/facefusion/filesystem.py new file mode 100644 index 0000000000000000000000000000000000000000..6556e4456bb481c3ced026a0254c2c6be225389c --- /dev/null +++ b/facefusion/filesystem.py @@ -0,0 +1,188 @@ +import glob +import os +import shutil +from typing import List, Optional + +import facefusion.choices + + +def get_file_size(file_path : str) -> int: + if is_file(file_path): + return os.path.getsize(file_path) + return 0 + + +def get_file_name(file_path : str) -> Optional[str]: + file_name, _ = os.path.splitext(os.path.basename(file_path)) + + if file_name: + return file_name + return None + + +def get_file_extension(file_path : str) -> Optional[str]: + _, file_extension = os.path.splitext(file_path) + + if file_extension: + return file_extension.lower() + return None + + +def get_file_format(file_path : str) -> Optional[str]: + file_extension = get_file_extension(file_path) + + if file_extension: + if file_extension == '.jpg': + return 'jpeg' + if file_extension == '.tif': + return 'tiff' + return file_extension.lstrip('.') + return None + + +def same_file_extension(first_file_path : str, second_file_path : str) -> bool: + first_file_extension = get_file_extension(first_file_path) + second_file_extension = get_file_extension(second_file_path) + + if first_file_extension and second_file_extension: + return get_file_extension(first_file_path) == get_file_extension(second_file_path) + return False + + +def is_file(file_path : str) -> bool: + if file_path: + return os.path.isfile(file_path) + return False + + +def is_audio(audio_path : str) -> bool: + return is_file(audio_path) and get_file_format(audio_path) in facefusion.choices.audio_formats + + +def has_audio(audio_paths : List[str]) -> bool: + if audio_paths: + return any(map(is_audio, audio_paths)) + return False + + +def are_audios(audio_paths : List[str]) -> bool: + if audio_paths: + return all(map(is_audio, audio_paths)) + return False + + +def is_image(image_path : str) -> bool: + return is_file(image_path) and get_file_format(image_path) in facefusion.choices.image_formats + + +def has_image(image_paths : List[str]) -> bool: + if image_paths: + return 
any(is_image(image_path) for image_path in image_paths) + return False + + +def are_images(image_paths : List[str]) -> bool: + if image_paths: + return all(map(is_image, image_paths)) + return False + + +def is_video(video_path : str) -> bool: + return is_file(video_path) and get_file_format(video_path) in facefusion.choices.video_formats + + +def has_video(video_paths : List[str]) -> bool: + if video_paths: + return any(map(is_video, video_paths)) + return False + + +def are_videos(video_paths : List[str]) -> bool: + if video_paths: + return any(map(is_video, video_paths)) + return False + + +def filter_audio_paths(paths : List[str]) -> List[str]: + if paths: + return [ path for path in paths if is_audio(path) ] + return [] + + +def filter_image_paths(paths : List[str]) -> List[str]: + if paths: + return [ path for path in paths if is_image(path) ] + return [] + + +def copy_file(file_path : str, move_path : str) -> bool: + if is_file(file_path): + shutil.copy(file_path, move_path) + return is_file(move_path) + return False + + +def move_file(file_path : str, move_path : str) -> bool: + if is_file(file_path): + shutil.move(file_path, move_path) + return not is_file(file_path) and is_file(move_path) + return False + + +def remove_file(file_path : str) -> bool: + if is_file(file_path): + os.remove(file_path) + return not is_file(file_path) + return False + + +def resolve_file_paths(directory_path : str) -> List[str]: + file_paths : List[str] = [] + + if is_directory(directory_path): + file_names_and_extensions = sorted(os.listdir(directory_path)) + + for file_name_and_extension in file_names_and_extensions: + if not file_name_and_extension.startswith(('.', '__')): + file_path = os.path.join(directory_path, file_name_and_extension) + file_paths.append(file_path) + + return file_paths + + +def resolve_file_pattern(file_pattern : str) -> List[str]: + if in_directory(file_pattern): + return sorted(glob.glob(file_pattern)) + return [] + + +def is_directory(directory_path : str) -> bool: + if directory_path: + return os.path.isdir(directory_path) + return False + + +def in_directory(file_path : str) -> bool: + if file_path: + directory_path = os.path.dirname(file_path) + if directory_path: + return not is_directory(file_path) and is_directory(directory_path) + return False + + +def create_directory(directory_path : str) -> bool: + if directory_path and not is_file(directory_path): + os.makedirs(directory_path, exist_ok = True) + return is_directory(directory_path) + return False + + +def remove_directory(directory_path : str) -> bool: + if is_directory(directory_path): + shutil.rmtree(directory_path, ignore_errors = True) + return not is_directory(directory_path) + return False + + +def resolve_relative_path(path : str) -> str: + return os.path.abspath(os.path.join(os.path.dirname(__file__), path)) diff --git a/facefusion/hash_helper.py b/facefusion/hash_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..d3d84d8102a1d03a60a32c32038f4376b7c4c717 --- /dev/null +++ b/facefusion/hash_helper.py @@ -0,0 +1,32 @@ +import os +import zlib +from typing import Optional + +from facefusion.filesystem import get_file_name, is_file + + +def create_hash(content : bytes) -> str: + return format(zlib.crc32(content), '08x') + + +def validate_hash(validate_path : str) -> bool: + hash_path = get_hash_path(validate_path) + + if is_file(hash_path): + with open(hash_path) as hash_file: + hash_content = hash_file.read() + + with open(validate_path, 'rb') as validate_file: + validate_content = 
validate_file.read() + + return create_hash(validate_content) == hash_content + return False + + +def get_hash_path(validate_path : str) -> Optional[str]: + if is_file(validate_path): + validate_directory_path, file_name_and_extension = os.path.split(validate_path) + validate_file_name = get_file_name(file_name_and_extension) + + return os.path.join(validate_directory_path, validate_file_name + '.hash') + return None diff --git a/facefusion/inference_manager.py b/facefusion/inference_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..f93d7e854894ac81dd5451b6ef41b6f8489662fd --- /dev/null +++ b/facefusion/inference_manager.py @@ -0,0 +1,74 @@ +import importlib +from time import sleep +from typing import List + +from onnxruntime import InferenceSession + +from facefusion import process_manager, state_manager +from facefusion.app_context import detect_app_context +from facefusion.execution import create_inference_session_providers +from facefusion.filesystem import is_file +from facefusion.types import DownloadSet, ExecutionProvider, InferencePool, InferencePoolSet + +INFERENCE_POOL_SET : InferencePoolSet =\ +{ + 'cli': {}, + 'ui': {} +} + + +def get_inference_pool(module_name : str, model_names : List[str], model_source_set : DownloadSet) -> InferencePool: + while process_manager.is_checking(): + sleep(0.5) + execution_device_id = state_manager.get_item('execution_device_id') + execution_providers = resolve_execution_providers(module_name) + app_context = detect_app_context() + inference_context = get_inference_context(module_name, model_names, execution_device_id, execution_providers) + + if app_context == 'cli' and INFERENCE_POOL_SET.get('ui').get(inference_context): + INFERENCE_POOL_SET['cli'][inference_context] = INFERENCE_POOL_SET.get('ui').get(inference_context) + if app_context == 'ui' and INFERENCE_POOL_SET.get('cli').get(inference_context): + INFERENCE_POOL_SET['ui'][inference_context] = INFERENCE_POOL_SET.get('cli').get(inference_context) + if not INFERENCE_POOL_SET.get(app_context).get(inference_context): + INFERENCE_POOL_SET[app_context][inference_context] = create_inference_pool(model_source_set, execution_device_id, execution_providers) + + return INFERENCE_POOL_SET.get(app_context).get(inference_context) + + +def create_inference_pool(model_source_set : DownloadSet, execution_device_id : str, execution_providers : List[ExecutionProvider]) -> InferencePool: + inference_pool : InferencePool = {} + + for model_name in model_source_set.keys(): + model_path = model_source_set.get(model_name).get('path') + if is_file(model_path): + inference_pool[model_name] = create_inference_session(model_path, execution_device_id, execution_providers) + + return inference_pool + + +def clear_inference_pool(module_name : str, model_names : List[str]) -> None: + execution_device_id = state_manager.get_item('execution_device_id') + execution_providers = resolve_execution_providers(module_name) + app_context = detect_app_context() + inference_context = get_inference_context(module_name, model_names, execution_device_id, execution_providers) + + if INFERENCE_POOL_SET.get(app_context).get(inference_context): + del INFERENCE_POOL_SET[app_context][inference_context] + + +def create_inference_session(model_path : str, execution_device_id : str, execution_providers : List[ExecutionProvider]) -> InferenceSession: + inference_session_providers = create_inference_session_providers(execution_device_id, execution_providers) + return InferenceSession(model_path, providers = 
inference_session_providers) + + +def get_inference_context(module_name : str, model_names : List[str], execution_device_id : str, execution_providers : List[ExecutionProvider]) -> str: + inference_context = '.'.join([ module_name ] + model_names + [ execution_device_id ] + list(execution_providers)) + return inference_context + + +def resolve_execution_providers(module_name : str) -> List[ExecutionProvider]: + module = importlib.import_module(module_name) + + if hasattr(module, 'resolve_execution_providers'): + return getattr(module, 'resolve_execution_providers')() + return state_manager.get_item('execution_providers') diff --git a/facefusion/installer.py b/facefusion/installer.py new file mode 100644 index 0000000000000000000000000000000000000000..a363b0639dd62f1bf14ec82f23b78e82e17cd7f4 --- /dev/null +++ b/facefusion/installer.py @@ -0,0 +1,96 @@ +import os +import shutil +import signal +import subprocess +import sys +from argparse import ArgumentParser, HelpFormatter +from functools import partial +from types import FrameType + +from facefusion import metadata, wording +from facefusion.common_helper import is_linux, is_windows + +ONNXRUNTIME_SET =\ +{ + 'default': ('onnxruntime', '1.22.0') +} +if is_windows() or is_linux(): + ONNXRUNTIME_SET['cuda'] = ('onnxruntime-gpu', '1.22.0') + ONNXRUNTIME_SET['openvino'] = ('onnxruntime-openvino', '1.22.0') +if is_windows(): + ONNXRUNTIME_SET['directml'] = ('onnxruntime-directml', '1.17.3') +if is_linux(): + ONNXRUNTIME_SET['rocm'] = ('onnxruntime-rocm', '1.21.0') + + +def cli() -> None: + signal.signal(signal.SIGINT, signal_exit) + program = ArgumentParser(formatter_class = partial(HelpFormatter, max_help_position = 50)) + program.add_argument('--onnxruntime', help = wording.get('help.install_dependency').format(dependency = 'onnxruntime'), choices = ONNXRUNTIME_SET.keys(), required = True) + program.add_argument('--skip-conda', help = wording.get('help.skip_conda'), action = 'store_true') + program.add_argument('-v', '--version', version = metadata.get('name') + ' ' + metadata.get('version'), action = 'version') + run(program) + + +def signal_exit(signum : int, frame : FrameType) -> None: + sys.exit(0) + + +def run(program : ArgumentParser) -> None: + args = program.parse_args() + has_conda = 'CONDA_PREFIX' in os.environ + onnxruntime_name, onnxruntime_version = ONNXRUNTIME_SET.get(args.onnxruntime) + + if not args.skip_conda and not has_conda: + sys.stdout.write(wording.get('conda_not_activated') + os.linesep) + sys.exit(1) + + with open('requirements.txt') as file: + + for line in file.readlines(): + __line__ = line.strip() + if not __line__.startswith('onnxruntime'): + subprocess.call([ shutil.which('pip'), 'install', line, '--force-reinstall' ]) + + if args.onnxruntime == 'rocm': + python_id = 'cp' + str(sys.version_info.major) + str(sys.version_info.minor) + + if python_id in [ 'cp310', 'cp312' ]: + wheel_name = 'onnxruntime_rocm-' + onnxruntime_version + '-' + python_id + '-' + python_id + '-linux_x86_64.whl' + wheel_url = 'https://repo.radeon.com/rocm/manylinux/rocm-rel-6.4/' + wheel_name + subprocess.call([ shutil.which('pip'), 'install', wheel_url, '--force-reinstall' ]) + else: + subprocess.call([ shutil.which('pip'), 'install', onnxruntime_name + '==' + onnxruntime_version, '--force-reinstall' ]) + + if args.onnxruntime == 'cuda' and has_conda: + library_paths = [] + + if is_linux(): + if os.getenv('LD_LIBRARY_PATH'): + library_paths = os.getenv('LD_LIBRARY_PATH').split(os.pathsep) + + python_id = 'python' + 
str(sys.version_info.major) + '.' + str(sys.version_info.minor) + library_paths.extend( + [ + os.path.join(os.getenv('CONDA_PREFIX'), 'lib'), + os.path.join(os.getenv('CONDA_PREFIX'), 'lib', python_id, 'site-packages', 'tensorrt_libs') + ]) + library_paths = list(dict.fromkeys([ library_path for library_path in library_paths if os.path.exists(library_path) ])) + + subprocess.call([ shutil.which('conda'), 'env', 'config', 'vars', 'set', 'LD_LIBRARY_PATH=' + os.pathsep.join(library_paths) ]) + + if is_windows(): + if os.getenv('PATH'): + library_paths = os.getenv('PATH').split(os.pathsep) + + library_paths.extend( + [ + os.path.join(os.getenv('CONDA_PREFIX'), 'Lib'), + os.path.join(os.getenv('CONDA_PREFIX'), 'Lib', 'site-packages', 'tensorrt_libs') + ]) + library_paths = list(dict.fromkeys([ library_path for library_path in library_paths if os.path.exists(library_path) ])) + + subprocess.call([ shutil.which('conda'), 'env', 'config', 'vars', 'set', 'PATH=' + os.pathsep.join(library_paths) ]) + + if args.onnxruntime == 'directml': + subprocess.call([ shutil.which('pip'), 'install', 'numpy==1.26.4', '--force-reinstall' ]) diff --git a/facefusion/jobs/__init__.py b/facefusion/jobs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/facefusion/jobs/job_helper.py b/facefusion/jobs/job_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..6e3139b152f500fbccc1d56a60f776879f031321 --- /dev/null +++ b/facefusion/jobs/job_helper.py @@ -0,0 +1,18 @@ +import os +from datetime import datetime +from typing import Optional + +from facefusion.filesystem import get_file_extension, get_file_name + + +def get_step_output_path(job_id : str, step_index : int, output_path : str) -> Optional[str]: + if output_path: + output_directory_path, _ = os.path.split(output_path) + output_file_name = get_file_name(_) + output_file_extension = get_file_extension(_) + return os.path.join(output_directory_path, output_file_name + '-' + job_id + '-' + str(step_index) + output_file_extension) + return None + + +def suggest_job_id(job_prefix : str = 'job') -> str: + return job_prefix + '-' + datetime.now().strftime('%Y-%m-%d-%H-%M-%S') diff --git a/facefusion/jobs/job_list.py b/facefusion/jobs/job_list.py new file mode 100644 index 0000000000000000000000000000000000000000..2003b96e176528956cc4c2f2aa84082a38b4e01f --- /dev/null +++ b/facefusion/jobs/job_list.py @@ -0,0 +1,34 @@ +from datetime import datetime +from typing import Optional, Tuple + +from facefusion.date_helper import describe_time_ago +from facefusion.jobs import job_manager +from facefusion.types import JobStatus, TableContents, TableHeaders + + +def compose_job_list(job_status : JobStatus) -> Tuple[TableHeaders, TableContents]: + jobs = job_manager.find_jobs(job_status) + job_headers : TableHeaders = [ 'job id', 'steps', 'date created', 'date updated', 'job status' ] + job_contents : TableContents = [] + + for index, job_id in enumerate(jobs): + if job_manager.validate_job(job_id): + job = jobs[job_id] + step_total = job_manager.count_step_total(job_id) + date_created = prepare_describe_datetime(job.get('date_created')) + date_updated = prepare_describe_datetime(job.get('date_updated')) + job_contents.append( + [ + job_id, + step_total, + date_created, + date_updated, + job_status + ]) + return job_headers, job_contents + + +def prepare_describe_datetime(date_time : Optional[str]) -> Optional[str]: + if date_time: + return 
describe_time_ago(datetime.fromisoformat(date_time)) + return None diff --git a/facefusion/jobs/job_manager.py b/facefusion/jobs/job_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..58f46e5b993f43afc8a6a500d67b6e2876617f8e --- /dev/null +++ b/facefusion/jobs/job_manager.py @@ -0,0 +1,265 @@ +import os +from copy import copy +from typing import List, Optional + +import facefusion.choices +from facefusion.date_helper import get_current_date_time +from facefusion.filesystem import create_directory, get_file_name, is_directory, is_file, move_file, remove_directory, remove_file, resolve_file_pattern +from facefusion.jobs.job_helper import get_step_output_path +from facefusion.json import read_json, write_json +from facefusion.types import Args, Job, JobSet, JobStatus, JobStep, JobStepStatus + +JOBS_PATH : Optional[str] = None + + +def init_jobs(jobs_path : str) -> bool: + global JOBS_PATH + + JOBS_PATH = jobs_path + job_status_paths = [ os.path.join(JOBS_PATH, job_status) for job_status in facefusion.choices.job_statuses ] + + for job_status_path in job_status_paths: + create_directory(job_status_path) + return all(is_directory(status_path) for status_path in job_status_paths) + + +def clear_jobs(jobs_path : str) -> bool: + return remove_directory(jobs_path) + + +def create_job(job_id : str) -> bool: + job : Job =\ + { + 'version': '1', + 'date_created': get_current_date_time().isoformat(), + 'date_updated': None, + 'steps': [] + } + + return create_job_file(job_id, job) + + +def submit_job(job_id : str) -> bool: + drafted_job_ids = find_job_ids('drafted') + steps = get_steps(job_id) + + if job_id in drafted_job_ids and steps: + return set_steps_status(job_id, 'queued') and move_job_file(job_id, 'queued') + return False + + +def submit_jobs(halt_on_error : bool) -> bool: + drafted_job_ids = find_job_ids('drafted') + has_error = False + + if drafted_job_ids: + for job_id in drafted_job_ids: + if not submit_job(job_id): + has_error = True + if halt_on_error: + return False + return not has_error + return False + + +def delete_job(job_id : str) -> bool: + return delete_job_file(job_id) + + +def delete_jobs(halt_on_error : bool) -> bool: + job_ids = find_job_ids('drafted') + find_job_ids('queued') + find_job_ids('failed') + find_job_ids('completed') + has_error = False + + if job_ids: + for job_id in job_ids: + if not delete_job(job_id): + has_error = True + if halt_on_error: + return False + return not has_error + return False + + +def find_jobs(job_status : JobStatus) -> JobSet: + job_ids = find_job_ids(job_status) + job_set : JobSet = {} + + for job_id in job_ids: + job_set[job_id] = read_job_file(job_id) + return job_set + + +def find_job_ids(job_status : JobStatus) -> List[str]: + job_pattern = os.path.join(JOBS_PATH, job_status, '*.json') + job_paths = resolve_file_pattern(job_pattern) + job_paths.sort(key = os.path.getmtime) + job_ids = [] + + for job_path in job_paths: + job_id = get_file_name(job_path) + job_ids.append(job_id) + return job_ids + + +def validate_job(job_id : str) -> bool: + job = read_job_file(job_id) + return bool(job and 'version' in job and 'date_created' in job and 'date_updated' in job and 'steps' in job) + + +def has_step(job_id : str, step_index : int) -> bool: + step_total = count_step_total(job_id) + return step_index in range(step_total) + + +def add_step(job_id : str, step_args : Args) -> bool: + job = read_job_file(job_id) + + if job: + job.get('steps').append( + { + 'args': step_args, + 'status': 'drafted' + }) + return 
update_job_file(job_id, job) + return False + + +def remix_step(job_id : str, step_index : int, step_args : Args) -> bool: + steps = get_steps(job_id) + step_args = copy(step_args) + + if step_index and step_index < 0: + step_index = count_step_total(job_id) - 1 + + if has_step(job_id, step_index): + output_path = steps[step_index].get('args').get('output_path') + step_args['target_path'] = get_step_output_path(job_id, step_index, output_path) + return add_step(job_id, step_args) + return False + + +def insert_step(job_id : str, step_index : int, step_args : Args) -> bool: + job = read_job_file(job_id) + step_args = copy(step_args) + + if step_index and step_index < 0: + step_index = count_step_total(job_id) - 1 + + if job and has_step(job_id, step_index): + job.get('steps').insert(step_index, + { + 'args': step_args, + 'status': 'drafted' + }) + return update_job_file(job_id, job) + return False + + +def remove_step(job_id : str, step_index : int) -> bool: + job = read_job_file(job_id) + + if step_index and step_index < 0: + step_index = count_step_total(job_id) - 1 + + if job and has_step(job_id, step_index): + job.get('steps').pop(step_index) + return update_job_file(job_id, job) + return False + + +def get_steps(job_id : str) -> List[JobStep]: + job = read_job_file(job_id) + + if job: + return job.get('steps') + return [] + + +def count_step_total(job_id : str) -> int: + steps = get_steps(job_id) + + if steps: + return len(steps) + return 0 + + +def set_step_status(job_id : str, step_index : int, step_status : JobStepStatus) -> bool: + job = read_job_file(job_id) + + if job: + steps = job.get('steps') + if has_step(job_id, step_index): + steps[step_index]['status'] = step_status + return update_job_file(job_id, job) + return False + + +def set_steps_status(job_id : str, step_status : JobStepStatus) -> bool: + job = read_job_file(job_id) + + if job: + for step in job.get('steps'): + step['status'] = step_status + return update_job_file(job_id, job) + return False + + +def read_job_file(job_id : str) -> Optional[Job]: + job_path = find_job_path(job_id) + return read_json(job_path) #type:ignore[return-value] + + +def create_job_file(job_id : str, job : Job) -> bool: + job_path = find_job_path(job_id) + + if not is_file(job_path): + job_create_path = suggest_job_path(job_id, 'drafted') + return write_json(job_create_path, job) #type:ignore[arg-type] + return False + + +def update_job_file(job_id : str, job : Job) -> bool: + job_path = find_job_path(job_id) + + if is_file(job_path): + job['date_updated'] = get_current_date_time().isoformat() + return write_json(job_path, job) #type:ignore[arg-type] + return False + + +def move_job_file(job_id : str, job_status : JobStatus) -> bool: + job_path = find_job_path(job_id) + job_move_path = suggest_job_path(job_id, job_status) + return move_file(job_path, job_move_path) + + +def delete_job_file(job_id : str) -> bool: + job_path = find_job_path(job_id) + return remove_file(job_path) + + +def suggest_job_path(job_id : str, job_status : JobStatus) -> Optional[str]: + job_file_name = get_job_file_name(job_id) + + if job_file_name: + return os.path.join(JOBS_PATH, job_status, job_file_name) + return None + + +def find_job_path(job_id : str) -> Optional[str]: + job_file_name = get_job_file_name(job_id) + + if job_file_name: + for job_status in facefusion.choices.job_statuses: + job_pattern = os.path.join(JOBS_PATH, job_status, job_file_name) + job_paths = resolve_file_pattern(job_pattern) + + for job_path in job_paths: + return job_path + return None + 
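+# find_job_path() above scans each job status directory for '<job_id>.json' and
+# returns the first match; since create_job_file() refuses to create a duplicate,
+# a job id stays unique across drafted, queued, failed and completed jobs.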
+ +def get_job_file_name(job_id : str) -> Optional[str]: + if job_id: + return job_id + '.json' + return None diff --git a/facefusion/jobs/job_runner.py b/facefusion/jobs/job_runner.py new file mode 100644 index 0000000000000000000000000000000000000000..23a0e38b9aff4eece580e8d256ce2c3416fca0f4 --- /dev/null +++ b/facefusion/jobs/job_runner.py @@ -0,0 +1,112 @@ +from facefusion.ffmpeg import concat_video +from facefusion.filesystem import are_images, are_videos, move_file, remove_file +from facefusion.jobs import job_helper, job_manager +from facefusion.types import JobOutputSet, JobStep, ProcessStep + + +def run_job(job_id : str, process_step : ProcessStep) -> bool: + queued_job_ids = job_manager.find_job_ids('queued') + + if job_id in queued_job_ids: + if run_steps(job_id, process_step) and finalize_steps(job_id): + clean_steps(job_id) + return job_manager.move_job_file(job_id, 'completed') + clean_steps(job_id) + job_manager.move_job_file(job_id, 'failed') + return False + + +def run_jobs(process_step : ProcessStep, halt_on_error : bool) -> bool: + queued_job_ids = job_manager.find_job_ids('queued') + has_error = False + + if queued_job_ids: + for job_id in queued_job_ids: + if not run_job(job_id, process_step): + has_error = True + if halt_on_error: + return False + return not has_error + return False + + +def retry_job(job_id : str, process_step : ProcessStep) -> bool: + failed_job_ids = job_manager.find_job_ids('failed') + + if job_id in failed_job_ids: + return job_manager.set_steps_status(job_id, 'queued') and job_manager.move_job_file(job_id, 'queued') and run_job(job_id, process_step) + return False + + +def retry_jobs(process_step : ProcessStep, halt_on_error : bool) -> bool: + failed_job_ids = job_manager.find_job_ids('failed') + has_error = False + + if failed_job_ids: + for job_id in failed_job_ids: + if not retry_job(job_id, process_step): + has_error = True + if halt_on_error: + return False + return not has_error + return False + + +def run_step(job_id : str, step_index : int, step : JobStep, process_step : ProcessStep) -> bool: + step_args = step.get('args') + + if job_manager.set_step_status(job_id, step_index, 'started') and process_step(job_id, step_index, step_args): + output_path = step_args.get('output_path') + step_output_path = job_helper.get_step_output_path(job_id, step_index, output_path) + + return move_file(output_path, step_output_path) and job_manager.set_step_status(job_id, step_index, 'completed') + job_manager.set_step_status(job_id, step_index, 'failed') + return False + + +def run_steps(job_id : str, process_step : ProcessStep) -> bool: + steps = job_manager.get_steps(job_id) + + if steps: + for index, step in enumerate(steps): + if not run_step(job_id, index, step, process_step): + return False + return True + return False + + +def finalize_steps(job_id : str) -> bool: + output_set = collect_output_set(job_id) + + for output_path, temp_output_paths in output_set.items(): + if are_videos(temp_output_paths): + if not concat_video(output_path, temp_output_paths): + return False + if are_images(temp_output_paths): + for temp_output_path in temp_output_paths: + if not move_file(temp_output_path, output_path): + return False + return True + + +def clean_steps(job_id: str) -> bool: + output_set = collect_output_set(job_id) + + for temp_output_paths in output_set.values(): + for temp_output_path in temp_output_paths: + if not remove_file(temp_output_path): + return False + return True + + +def collect_output_set(job_id : str) -> JobOutputSet: + steps = 
job_manager.get_steps(job_id) + job_output_set : JobOutputSet = {} + + for index, step in enumerate(steps): + output_path = step.get('args').get('output_path') + + if output_path: + step_output_path = job_manager.get_step_output_path(job_id, index, output_path) + job_output_set.setdefault(output_path, []).append(step_output_path) + return job_output_set diff --git a/facefusion/jobs/job_store.py b/facefusion/jobs/job_store.py new file mode 100644 index 0000000000000000000000000000000000000000..5a13ef124372622d43b6971597a39a57868d5698 --- /dev/null +++ b/facefusion/jobs/job_store.py @@ -0,0 +1,27 @@ +from typing import List + +from facefusion.types import JobStore + +JOB_STORE : JobStore =\ +{ + 'job_keys': [], + 'step_keys': [] +} + + +def get_job_keys() -> List[str]: + return JOB_STORE.get('job_keys') + + +def get_step_keys() -> List[str]: + return JOB_STORE.get('step_keys') + + +def register_job_keys(step_keys : List[str]) -> None: + for step_key in step_keys: + JOB_STORE['job_keys'].append(step_key) + + +def register_step_keys(job_keys : List[str]) -> None: + for job_key in job_keys: + JOB_STORE['step_keys'].append(job_key) diff --git a/facefusion/json.py b/facefusion/json.py new file mode 100644 index 0000000000000000000000000000000000000000..d688683fdf6c352fd544c1bbb4106a4123869854 --- /dev/null +++ b/facefusion/json.py @@ -0,0 +1,22 @@ +import json +from json import JSONDecodeError +from typing import Optional + +from facefusion.filesystem import is_file +from facefusion.types import Content + + +def read_json(json_path : str) -> Optional[Content]: + if is_file(json_path): + try: + with open(json_path) as json_file: + return json.load(json_file) + except JSONDecodeError: + pass + return None + + +def write_json(json_path : str, content : Content) -> bool: + with open(json_path, 'w') as json_file: + json.dump(content, json_file, indent = 4) + return is_file(json_path) diff --git a/facefusion/logger.py b/facefusion/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..d7fa779da7a38d1f2cb2a8322dd7eff6d2be877d --- /dev/null +++ b/facefusion/logger.py @@ -0,0 +1,48 @@ +from logging import Logger, basicConfig, getLogger + +import facefusion.choices +from facefusion.common_helper import get_first, get_last +from facefusion.types import LogLevel + + +def init(log_level : LogLevel) -> None: + basicConfig(format = '%(message)s') + get_package_logger().setLevel(facefusion.choices.log_level_set.get(log_level)) + + +def get_package_logger() -> Logger: + return getLogger('facefusion') + + +def debug(message : str, module_name : str) -> None: + get_package_logger().debug(create_message(message, module_name)) + + +def info(message : str, module_name : str) -> None: + get_package_logger().info(create_message(message, module_name)) + + +def warn(message : str, module_name : str) -> None: + get_package_logger().warning(create_message(message, module_name)) + + +def error(message : str, module_name : str) -> None: + get_package_logger().error(create_message(message, module_name)) + + +def create_message(message : str, module_name : str) -> str: + module_names = module_name.split('.') + first_module_name = get_first(module_names) + last_module_name = get_last(module_names) + + if first_module_name and last_module_name: + return '[' + first_module_name.upper() + '.' 
+ last_module_name.upper() + '] ' + message + return message + + +def enable() -> None: + get_package_logger().disabled = False + + +def disable() -> None: + get_package_logger().disabled = True diff --git a/facefusion/memory.py b/facefusion/memory.py new file mode 100644 index 0000000000000000000000000000000000000000..f4161ac060f7cd696df37d372ecdb2d1588fed1b --- /dev/null +++ b/facefusion/memory.py @@ -0,0 +1,21 @@ +from facefusion.common_helper import is_macos, is_windows + +if is_windows(): + import ctypes +else: + import resource + + +def limit_system_memory(system_memory_limit : int = 1) -> bool: + if is_macos(): + system_memory_limit = system_memory_limit * (1024 ** 6) + else: + system_memory_limit = system_memory_limit * (1024 ** 3) + try: + if is_windows(): + ctypes.windll.kernel32.SetProcessWorkingSetSize(-1, ctypes.c_size_t(system_memory_limit), ctypes.c_size_t(system_memory_limit)) #type:ignore[attr-defined] + else: + resource.setrlimit(resource.RLIMIT_DATA, (system_memory_limit, system_memory_limit)) + return True + except Exception: + return False diff --git a/facefusion/metadata.py b/facefusion/metadata.py new file mode 100644 index 0000000000000000000000000000000000000000..baacd55ca8c400553de22d55f724230ef0f3abfb --- /dev/null +++ b/facefusion/metadata.py @@ -0,0 +1,17 @@ +from typing import Optional + +METADATA =\ +{ + 'name': 'FaceFusion', + 'description': 'Industry leading face manipulation platform', + 'version': '3.3.2', + 'license': 'OpenRAIL-AS', + 'author': 'Henry Ruhs', + 'url': 'https://facefusion.io' +} + + +def get(key : str) -> Optional[str]: + if key in METADATA: + return METADATA.get(key) + return None diff --git a/facefusion/model_helper.py b/facefusion/model_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..0646cda77e9901084c378858ed1e5cc0d7c0ae98 --- /dev/null +++ b/facefusion/model_helper.py @@ -0,0 +1,11 @@ +from functools import lru_cache + +import onnx + +from facefusion.types import ModelInitializer + + +@lru_cache(maxsize = None) +def get_static_model_initializer(model_path : str) -> ModelInitializer: + model = onnx.load(model_path) + return onnx.numpy_helper.to_array(model.graph.initializer[-1]) diff --git a/facefusion/normalizer.py b/facefusion/normalizer.py new file mode 100644 index 0000000000000000000000000000000000000000..2c03dc9783b8285cd681fc8165a63bde8b278d69 --- /dev/null +++ b/facefusion/normalizer.py @@ -0,0 +1,21 @@ +from typing import List, Optional + +from facefusion.types import Fps, Padding + + +def normalize_padding(padding : Optional[List[int]]) -> Optional[Padding]: + if padding and len(padding) == 1: + return tuple([ padding[0] ] * 4) #type:ignore[return-value] + if padding and len(padding) == 2: + return tuple([ padding[0], padding[1], padding[0], padding[1] ]) #type:ignore[return-value] + if padding and len(padding) == 3: + return tuple([ padding[0], padding[1], padding[2], padding[1] ]) #type:ignore[return-value] + if padding and len(padding) == 4: + return tuple(padding) #type:ignore[return-value] + return None + + +def normalize_fps(fps : Optional[float]) -> Optional[Fps]: + if isinstance(fps, (int, float)): + return max(1.0, min(fps, 60.0)) + return None diff --git a/facefusion/process_manager.py b/facefusion/process_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..ce15014a4278ac50bf72934b4f82af42954618fc --- /dev/null +++ b/facefusion/process_manager.py @@ -0,0 +1,53 @@ +from typing import Generator, List + +from facefusion.types import ProcessState, QueuePayload + 
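+# Module-level process state shared by the UI and the job runner; check(), start(),
+# stop() and end() below move it through 'checking', 'processing', 'stopping' and
+# back to 'pending'.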
+PROCESS_STATE : ProcessState = 'pending' + + +def get_process_state() -> ProcessState: + return PROCESS_STATE + + +def set_process_state(process_state : ProcessState) -> None: + global PROCESS_STATE + + PROCESS_STATE = process_state + + +def is_checking() -> bool: + return get_process_state() == 'checking' + + +def is_processing() -> bool: + return get_process_state() == 'processing' + + +def is_stopping() -> bool: + return get_process_state() == 'stopping' + + +def is_pending() -> bool: + return get_process_state() == 'pending' + + +def check() -> None: + set_process_state('checking') + + +def start() -> None: + set_process_state('processing') + + +def stop() -> None: + set_process_state('stopping') + + +def end() -> None: + set_process_state('pending') + + +def manage(queue_payloads : List[QueuePayload]) -> Generator[QueuePayload, None, None]: + for query_payload in queue_payloads: + if is_processing(): + yield query_payload diff --git a/facefusion/processors/__init__.py b/facefusion/processors/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/facefusion/processors/choices.py b/facefusion/processors/choices.py new file mode 100644 index 0000000000000000000000000000000000000000..e33a2b1cb3cc426e79b445f59ba2bd749800a6ba --- /dev/null +++ b/facefusion/processors/choices.py @@ -0,0 +1,224 @@ +from typing import List, Sequence + +from facefusion.common_helper import create_float_range, create_int_range +from facefusion.filesystem import get_file_name, resolve_file_paths, resolve_relative_path +from facefusion.processors.types import AgeModifierModel, DeepSwapperModel, ExpressionRestorerModel, FaceDebuggerItem, FaceEditorModel, FaceEnhancerModel, FaceSwapperModel, FaceSwapperSet, FrameColorizerModel, FrameEnhancerModel, LipSyncerModel + +age_modifier_models : List[AgeModifierModel] = [ 'styleganex_age' ] +deep_swapper_models : List[DeepSwapperModel] =\ +[ + 'druuzil/adam_levine_320', + 'druuzil/adrianne_palicki_384', + 'druuzil/agnetha_falskog_224', + 'druuzil/alan_ritchson_320', + 'druuzil/alicia_vikander_320', + 'druuzil/amber_midthunder_320', + 'druuzil/andras_arato_384', + 'druuzil/andrew_tate_320', + 'druuzil/angelina_jolie_384', + 'druuzil/anne_hathaway_320', + 'druuzil/anya_chalotra_320', + 'druuzil/arnold_schwarzenegger_320', + 'druuzil/benjamin_affleck_320', + 'druuzil/benjamin_stiller_384', + 'druuzil/bradley_pitt_224', + 'druuzil/brie_larson_384', + 'druuzil/bruce_campbell_384', + 'druuzil/bryan_cranston_320', + 'druuzil/catherine_blanchett_352', + 'druuzil/christian_bale_320', + 'druuzil/christopher_hemsworth_320', + 'druuzil/christoph_waltz_384', + 'druuzil/cillian_murphy_320', + 'druuzil/cobie_smulders_256', + 'druuzil/dwayne_johnson_384', + 'druuzil/edward_norton_320', + 'druuzil/elisabeth_shue_320', + 'druuzil/elizabeth_olsen_384', + 'druuzil/elon_musk_320', + 'druuzil/emily_blunt_320', + 'druuzil/emma_stone_384', + 'druuzil/emma_watson_320', + 'druuzil/erin_moriarty_384', + 'druuzil/eva_green_320', + 'druuzil/ewan_mcgregor_320', + 'druuzil/florence_pugh_320', + 'druuzil/freya_allan_320', + 'druuzil/gary_cole_224', + 'druuzil/gigi_hadid_224', + 'druuzil/harrison_ford_384', + 'druuzil/hayden_christensen_320', + 'druuzil/heath_ledger_320', + 'druuzil/henry_cavill_448', + 'druuzil/hugh_jackman_384', + 'druuzil/idris_elba_320', + 'druuzil/jack_nicholson_320', + 'druuzil/james_carrey_384', + 'druuzil/james_mcavoy_320', + 'druuzil/james_varney_320', + 'druuzil/jason_momoa_320', + 
'druuzil/jason_statham_320', + 'druuzil/jennifer_connelly_384', + 'druuzil/jimmy_donaldson_320', + 'druuzil/jordan_peterson_384', + 'druuzil/karl_urban_224', + 'druuzil/kate_beckinsale_384', + 'druuzil/laurence_fishburne_384', + 'druuzil/lili_reinhart_320', + 'druuzil/luke_evans_384', + 'druuzil/mads_mikkelsen_384', + 'druuzil/mary_winstead_320', + 'druuzil/margaret_qualley_384', + 'druuzil/melina_juergens_320', + 'druuzil/michael_fassbender_320', + 'druuzil/michael_fox_320', + 'druuzil/millie_bobby_brown_320', + 'druuzil/morgan_freeman_320', + 'druuzil/patrick_stewart_224', + 'druuzil/rachel_weisz_384', + 'druuzil/rebecca_ferguson_320', + 'druuzil/scarlett_johansson_320', + 'druuzil/shannen_doherty_384', + 'druuzil/seth_macfarlane_384', + 'druuzil/thomas_cruise_320', + 'druuzil/thomas_hanks_384', + 'druuzil/william_murray_384', + 'druuzil/zoe_saldana_384', + 'edel/emma_roberts_224', + 'edel/ivanka_trump_224', + 'edel/lize_dzjabrailova_224', + 'edel/sidney_sweeney_224', + 'edel/winona_ryder_224', + 'iperov/alexandra_daddario_224', + 'iperov/alexei_navalny_224', + 'iperov/amber_heard_224', + 'iperov/dilraba_dilmurat_224', + 'iperov/elon_musk_224', + 'iperov/emilia_clarke_224', + 'iperov/emma_watson_224', + 'iperov/erin_moriarty_224', + 'iperov/jackie_chan_224', + 'iperov/james_carrey_224', + 'iperov/jason_statham_320', + 'iperov/keanu_reeves_320', + 'iperov/margot_robbie_224', + 'iperov/natalie_dormer_224', + 'iperov/nicolas_coppola_224', + 'iperov/robert_downey_224', + 'iperov/rowan_atkinson_224', + 'iperov/ryan_reynolds_224', + 'iperov/scarlett_johansson_224', + 'iperov/sylvester_stallone_224', + 'iperov/thomas_cruise_224', + 'iperov/thomas_holland_224', + 'iperov/vin_diesel_224', + 'iperov/vladimir_putin_224', + 'jen/angelica_trae_288', + 'jen/ella_freya_224', + 'jen/emma_myers_320', + 'jen/evie_pickerill_224', + 'jen/kang_hyewon_320', + 'jen/maddie_mead_224', + 'jen/nicole_turnbull_288', + 'mats/alica_schmidt_320', + 'mats/ashley_alexiss_224', + 'mats/billie_eilish_224', + 'mats/brie_larson_224', + 'mats/cara_delevingne_224', + 'mats/carolin_kebekus_224', + 'mats/chelsea_clinton_224', + 'mats/claire_boucher_224', + 'mats/corinna_kopf_224', + 'mats/florence_pugh_224', + 'mats/hillary_clinton_224', + 'mats/jenna_fischer_224', + 'mats/kim_jisoo_320', + 'mats/mica_suarez_320', + 'mats/shailene_woodley_224', + 'mats/shraddha_kapoor_320', + 'mats/yu_jimin_352', + 'rumateus/alison_brie_224', + 'rumateus/amber_heard_224', + 'rumateus/angelina_jolie_224', + 'rumateus/aubrey_plaza_224', + 'rumateus/bridget_regan_224', + 'rumateus/cobie_smulders_224', + 'rumateus/deborah_woll_224', + 'rumateus/dua_lipa_224', + 'rumateus/emma_stone_224', + 'rumateus/hailee_steinfeld_224', + 'rumateus/hilary_duff_224', + 'rumateus/jessica_alba_224', + 'rumateus/jessica_biel_224', + 'rumateus/john_cena_224', + 'rumateus/kim_kardashian_224', + 'rumateus/kristen_bell_224', + 'rumateus/lucy_liu_224', + 'rumateus/margot_robbie_224', + 'rumateus/megan_fox_224', + 'rumateus/meghan_markle_224', + 'rumateus/millie_bobby_brown_224', + 'rumateus/natalie_portman_224', + 'rumateus/nicki_minaj_224', + 'rumateus/olivia_wilde_224', + 'rumateus/shay_mitchell_224', + 'rumateus/sophie_turner_224', + 'rumateus/taylor_swift_224' +] + +custom_model_file_paths = resolve_file_paths(resolve_relative_path('../.assets/models/custom')) + +if custom_model_file_paths: + + for model_file_path in custom_model_file_paths: + model_id = '/'.join([ 'custom', get_file_name(model_file_path) ]) + deep_swapper_models.append(model_id) + 
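+# Any model file placed under .assets/models/custom is appended above as a deep
+# swapper model id of the form 'custom/<file name>', e.g. 'custom/my_face_256'
+# (hypothetical file name used for illustration).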
+expression_restorer_models : List[ExpressionRestorerModel] = [ 'live_portrait' ] +face_debugger_items : List[FaceDebuggerItem] = [ 'bounding-box', 'face-landmark-5', 'face-landmark-5/68', 'face-landmark-68', 'face-landmark-68/5', 'face-mask', 'face-detector-score', 'face-landmarker-score', 'age', 'gender', 'race' ] +face_editor_models : List[FaceEditorModel] = [ 'live_portrait' ] +face_enhancer_models : List[FaceEnhancerModel] = [ 'codeformer', 'gfpgan_1.2', 'gfpgan_1.3', 'gfpgan_1.4', 'gpen_bfr_256', 'gpen_bfr_512', 'gpen_bfr_1024', 'gpen_bfr_2048', 'restoreformer_plus_plus' ] +face_swapper_set : FaceSwapperSet =\ +{ + 'blendswap_256': [ '256x256', '384x384', '512x512', '768x768', '1024x1024' ], + 'ghost_1_256': [ '256x256', '512x512', '768x768', '1024x1024' ], + 'ghost_2_256': [ '256x256', '512x512', '768x768', '1024x1024' ], + 'ghost_3_256': [ '256x256', '512x512', '768x768', '1024x1024' ], + 'hififace_unofficial_256': [ '256x256', '512x512', '768x768', '1024x1024' ], + 'hyperswap_1a_256': [ '256x256', '512x512', '768x768', '1024x1024' ], + 'hyperswap_1b_256': [ '256x256', '512x512', '768x768', '1024x1024' ], + 'hyperswap_1c_256': [ '256x256', '512x512', '768x768', '1024x1024' ], + 'inswapper_128': [ '128x128', '256x256', '384x384', '512x512', '768x768', '1024x1024' ], + 'inswapper_128_fp16': [ '128x128', '256x256', '384x384', '512x512', '768x768', '1024x1024' ], + 'simswap_256': [ '256x256', '512x512', '768x768', '1024x1024' ], + 'simswap_unofficial_512': [ '512x512', '768x768', '1024x1024' ], + 'uniface_256': [ '256x256', '512x512', '768x768', '1024x1024' ] +} +face_swapper_models : List[FaceSwapperModel] = list(face_swapper_set.keys()) +frame_colorizer_models : List[FrameColorizerModel] = [ 'ddcolor', 'ddcolor_artistic', 'deoldify', 'deoldify_artistic', 'deoldify_stable' ] +frame_colorizer_sizes : List[str] = [ '192x192', '256x256', '384x384', '512x512' ] +frame_enhancer_models : List[FrameEnhancerModel] = [ 'clear_reality_x4', 'lsdir_x4', 'nomos8k_sc_x4', 'real_esrgan_x2', 'real_esrgan_x2_fp16', 'real_esrgan_x4', 'real_esrgan_x4_fp16', 'real_esrgan_x8', 'real_esrgan_x8_fp16', 'real_hatgan_x4', 'real_web_photo_x4', 'realistic_rescaler_x4', 'remacri_x4', 'siax_x4', 'span_kendata_x4', 'swin2_sr_x4', 'ultra_sharp_x4', 'ultra_sharp_2_x4' ] +lip_syncer_models : List[LipSyncerModel] = [ 'edtalk_256', 'wav2lip_96', 'wav2lip_gan_96' ] + +age_modifier_direction_range : Sequence[int] = create_int_range(-100, 100, 1) +deep_swapper_morph_range : Sequence[int] = create_int_range(0, 100, 1) +expression_restorer_factor_range : Sequence[int] = create_int_range(0, 100, 1) +face_editor_eyebrow_direction_range : Sequence[float] = create_float_range(-1.0, 1.0, 0.05) +face_editor_eye_gaze_horizontal_range : Sequence[float] = create_float_range(-1.0, 1.0, 0.05) +face_editor_eye_gaze_vertical_range : Sequence[float] = create_float_range(-1.0, 1.0, 0.05) +face_editor_eye_open_ratio_range : Sequence[float] = create_float_range(-1.0, 1.0, 0.05) +face_editor_lip_open_ratio_range : Sequence[float] = create_float_range(-1.0, 1.0, 0.05) +face_editor_mouth_grim_range : Sequence[float] = create_float_range(-1.0, 1.0, 0.05) +face_editor_mouth_pout_range : Sequence[float] = create_float_range(-1.0, 1.0, 0.05) +face_editor_mouth_purse_range : Sequence[float] = create_float_range(-1.0, 1.0, 0.05) +face_editor_mouth_smile_range : Sequence[float] = create_float_range(-1.0, 1.0, 0.05) +face_editor_mouth_position_horizontal_range : Sequence[float] = create_float_range(-1.0, 1.0, 0.05) 
+face_editor_mouth_position_vertical_range : Sequence[float] = create_float_range(-1.0, 1.0, 0.05) +face_editor_head_pitch_range : Sequence[float] = create_float_range(-1.0, 1.0, 0.05) +face_editor_head_yaw_range : Sequence[float] = create_float_range(-1.0, 1.0, 0.05) +face_editor_head_roll_range : Sequence[float] = create_float_range(-1.0, 1.0, 0.05) +face_enhancer_blend_range : Sequence[int] = create_int_range(0, 100, 1) +face_enhancer_weight_range : Sequence[float] = create_float_range(0.0, 1.0, 0.05) +frame_colorizer_blend_range : Sequence[int] = create_int_range(0, 100, 1) +frame_enhancer_blend_range : Sequence[int] = create_int_range(0, 100, 1) +lip_syncer_weight_range : Sequence[float] = create_float_range(0.0, 1.0, 0.05) diff --git a/facefusion/processors/core.py b/facefusion/processors/core.py new file mode 100644 index 0000000000000000000000000000000000000000..545370f0429054347d74f9b2b7b3bb7a90a1646b --- /dev/null +++ b/facefusion/processors/core.py @@ -0,0 +1,99 @@ +import importlib +import os +from concurrent.futures import ThreadPoolExecutor, as_completed +from queue import Queue +from types import ModuleType +from typing import Any, List + +from tqdm import tqdm + +from facefusion import logger, state_manager, wording +from facefusion.exit_helper import hard_exit +from facefusion.types import ProcessFrames, QueuePayload + +PROCESSORS_METHODS =\ +[ + 'get_inference_pool', + 'clear_inference_pool', + 'register_args', + 'apply_args', + 'pre_check', + 'pre_process', + 'post_process', + 'get_reference_frame', + 'process_frame', + 'process_frames', + 'process_image', + 'process_video' +] + + +def load_processor_module(processor : str) -> Any: + try: + processor_module = importlib.import_module('facefusion.processors.modules.' + processor) + for method_name in PROCESSORS_METHODS: + if not hasattr(processor_module, method_name): + raise NotImplementedError + except ModuleNotFoundError as exception: + logger.error(wording.get('processor_not_loaded').format(processor = processor), __name__) + logger.debug(exception.msg, __name__) + hard_exit(1) + except NotImplementedError: + logger.error(wording.get('processor_not_implemented').format(processor = processor), __name__) + hard_exit(1) + return processor_module + + +def get_processors_modules(processors : List[str]) -> List[ModuleType]: + processor_modules = [] + + for processor in processors: + processor_module = load_processor_module(processor) + processor_modules.append(processor_module) + return processor_modules + + +def multi_process_frames(source_paths : List[str], temp_frame_paths : List[str], process_frames : ProcessFrames) -> None: + queue_payloads = create_queue_payloads(temp_frame_paths) + with tqdm(total = len(queue_payloads), desc = wording.get('processing'), unit = 'frame', ascii = ' =', disable = state_manager.get_item('log_level') in [ 'warn', 'error' ]) as progress: + progress.set_postfix(execution_providers = state_manager.get_item('execution_providers')) + with ThreadPoolExecutor(max_workers = state_manager.get_item('execution_thread_count')) as executor: + futures = [] + queue : Queue[QueuePayload] = create_queue(queue_payloads) + queue_per_future = max(len(queue_payloads) // state_manager.get_item('execution_thread_count') * state_manager.get_item('execution_queue_count'), 1) + + while not queue.empty(): + future = executor.submit(process_frames, source_paths, pick_queue(queue, queue_per_future), progress.update) + futures.append(future) + + for future_done in as_completed(futures): + future_done.result() + + +def 
create_queue(queue_payloads : List[QueuePayload]) -> Queue[QueuePayload]: + queue : Queue[QueuePayload] = Queue() + for queue_payload in queue_payloads: + queue.put(queue_payload) + return queue + + +def pick_queue(queue : Queue[QueuePayload], queue_per_future : int) -> List[QueuePayload]: + queues = [] + for _ in range(queue_per_future): + if not queue.empty(): + queues.append(queue.get()) + return queues + + +def create_queue_payloads(temp_frame_paths : List[str]) -> List[QueuePayload]: + queue_payloads = [] + temp_frame_paths = sorted(temp_frame_paths, key = os.path.basename) + + for frame_number, frame_path in enumerate(temp_frame_paths): + frame_payload : QueuePayload =\ + { + 'frame_number': frame_number, + 'frame_path': frame_path + } + queue_payloads.append(frame_payload) + return queue_payloads diff --git a/facefusion/processors/live_portrait.py b/facefusion/processors/live_portrait.py new file mode 100644 index 0000000000000000000000000000000000000000..5805bc58ccc37bdc5da0127117ee6f06e43f2065 --- /dev/null +++ b/facefusion/processors/live_portrait.py @@ -0,0 +1,101 @@ +from typing import Tuple + +import numpy +import scipy + +from facefusion.processors.types import LivePortraitExpression, LivePortraitPitch, LivePortraitRoll, LivePortraitRotation, LivePortraitYaw + +EXPRESSION_MIN = numpy.array( +[ + [ + [ -2.88067125e-02, -8.12731311e-02, -1.70541159e-03 ], + [ -4.88598682e-02, -3.32196616e-02, -1.67431499e-04 ], + [ -6.75425082e-02, -4.28681746e-02, -1.98950816e-04 ], + [ -7.23103955e-02, -3.28503326e-02, -7.31324719e-04 ], + [ -3.87073644e-02, -6.01546466e-02, -5.50269964e-04 ], + [ -6.38048723e-02, -2.23840728e-01, -7.13261834e-04 ], + [ -3.02710701e-02, -3.93195450e-02, -8.24086510e-06 ], + [ -2.95799859e-02, -5.39318882e-02, -1.74219604e-04 ], + [ -2.92359516e-02, -1.53050944e-02, -6.30460854e-05 ], + [ -5.56493877e-03, -2.34344602e-02, -1.26858242e-04 ], + [ -4.37593013e-02, -2.77768299e-02, -2.70503685e-02 ], + [ -1.76926646e-02, -1.91676542e-02, -1.15090821e-04 ], + [ -8.34268332e-03, -3.99775570e-03, -3.27481248e-05 ], + [ -3.40162888e-02, -2.81868968e-02, -1.96679524e-04 ], + [ -2.91855410e-02, -3.97511162e-02, -2.81230678e-05 ], + [ -1.50395725e-02, -2.49494594e-02, -9.42573533e-05 ], + [ -1.67938769e-02, -2.00953931e-02, -4.00750607e-04 ], + [ -1.86435618e-02, -2.48535164e-02, -2.74416432e-02 ], + [ -4.61211195e-03, -1.21660791e-02, -2.93173041e-04 ], + [ -4.10017073e-02, -7.43824020e-02, -4.42762971e-02 ], + [ -1.90370996e-02, -3.74363363e-02, -1.34740388e-02 ] + ] +]).astype(numpy.float32) +EXPRESSION_MAX = numpy.array( +[ + [ + [ 4.46682945e-02, 7.08772913e-02, 4.08344204e-04 ], + [ 2.14308221e-02, 6.15894832e-02, 4.85319615e-05 ], + [ 3.02363783e-02, 4.45043296e-02, 1.28298725e-05 ], + [ 3.05869691e-02, 3.79812494e-02, 6.57040102e-04 ], + [ 4.45670523e-02, 3.97259220e-02, 7.10966764e-04 ], + [ 9.43699256e-02, 9.85926315e-02, 2.02551950e-04 ], + [ 1.61131397e-02, 2.92906128e-02, 3.44733417e-06 ], + [ 5.23825921e-02, 1.07065082e-01, 6.61510974e-04 ], + [ 2.85718683e-03, 8.32320191e-03, 2.39314613e-04 ], + [ 2.57947259e-02, 1.60935968e-02, 2.41853559e-05 ], + [ 4.90833223e-02, 3.43903080e-02, 3.22353356e-02 ], + [ 1.44766076e-02, 3.39248963e-02, 1.42291479e-04 ], + [ 8.75749043e-04, 6.82212645e-03, 2.76097053e-05 ], + [ 1.86958015e-02, 3.84016186e-02, 7.33085908e-05 ], + [ 2.01714113e-02, 4.90544215e-02, 2.34028921e-05 ], + [ 2.46518422e-02, 3.29151377e-02, 3.48571630e-05 ], + [ 2.22457591e-02, 1.21796541e-02, 1.56396593e-04 ], + [ 1.72109623e-02, 3.01626958e-02, 
1.36556877e-02 ], + [ 1.83460284e-02, 1.61141958e-02, 2.87440169e-04 ], + [ 3.57594155e-02, 1.80554688e-01, 2.75554154e-02 ], + [ 2.17450950e-02, 8.66811201e-02, 3.34241726e-02 ] + ] +]).astype(numpy.float32) + + +def limit_expression(expression : LivePortraitExpression) -> LivePortraitExpression: + return numpy.clip(expression, EXPRESSION_MIN, EXPRESSION_MAX) + + +def limit_euler_angles(target_pitch : LivePortraitPitch, target_yaw : LivePortraitYaw, target_roll : LivePortraitRoll, output_pitch : LivePortraitPitch, output_yaw : LivePortraitYaw, output_roll : LivePortraitRoll) -> Tuple[LivePortraitPitch, LivePortraitYaw, LivePortraitRoll]: + pitch_min, pitch_max, yaw_min, yaw_max, roll_min, roll_max = calc_euler_limits(target_pitch, target_yaw, target_roll) + output_pitch = numpy.clip(output_pitch, pitch_min, pitch_max) + output_yaw = numpy.clip(output_yaw, yaw_min, yaw_max) + output_roll = numpy.clip(output_roll, roll_min, roll_max) + return output_pitch, output_yaw, output_roll + + +def calc_euler_limits(pitch : LivePortraitPitch, yaw : LivePortraitYaw, roll : LivePortraitRoll) -> Tuple[float, float, float, float, float, float]: + pitch_min = -30.0 + pitch_max = 30.0 + yaw_min = -60.0 + yaw_max = 60.0 + roll_min = -20.0 + roll_max = 20.0 + + if pitch < 0: + pitch_min = min(pitch, pitch_min) + else: + pitch_max = max(pitch, pitch_max) + if yaw < 0: + yaw_min = min(yaw, yaw_min) + else: + yaw_max = max(yaw, yaw_max) + if roll < 0: + roll_min = min(roll, roll_min) + else: + roll_max = max(roll, roll_max) + + return pitch_min, pitch_max, yaw_min, yaw_max, roll_min, roll_max + + +def create_rotation(pitch : LivePortraitPitch, yaw : LivePortraitYaw, roll : LivePortraitRoll) -> LivePortraitRotation: + rotation = scipy.spatial.transform.Rotation.from_euler('xyz', [ pitch, yaw, roll ], degrees = True).as_matrix() + rotation = rotation.astype(numpy.float32) + return rotation diff --git a/facefusion/processors/modules/__init__.py b/facefusion/processors/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/facefusion/processors/modules/age_modifier.py b/facefusion/processors/modules/age_modifier.py new file mode 100644 index 0000000000000000000000000000000000000000..0ed6725968edf4568f941e5be2808e349c981251 --- /dev/null +++ b/facefusion/processors/modules/age_modifier.py @@ -0,0 +1,254 @@ +from argparse import ArgumentParser +from functools import lru_cache +from typing import List + +import cv2 +import numpy + +import facefusion.choices +import facefusion.jobs.job_manager +import facefusion.jobs.job_store +import facefusion.processors.core as processors +from facefusion import config, content_analyser, face_classifier, face_detector, face_landmarker, face_masker, face_recognizer, inference_manager, logger, process_manager, state_manager, video_manager, wording +from facefusion.common_helper import create_int_metavar +from facefusion.download import conditional_download_hashes, conditional_download_sources, resolve_download_url +from facefusion.execution import has_execution_provider +from facefusion.face_analyser import get_many_faces, get_one_face +from facefusion.face_helper import merge_matrix, paste_back, scale_face_landmark_5, warp_face_by_face_landmark_5 +from facefusion.face_masker import create_box_mask, create_occlusion_mask +from facefusion.face_selector import find_similar_faces, sort_and_filter_faces +from facefusion.face_store import get_reference_faces +from facefusion.filesystem import 
in_directory, is_image, is_video, resolve_relative_path, same_file_extension +from facefusion.processors import choices as processors_choices +from facefusion.processors.types import AgeModifierDirection, AgeModifierInputs +from facefusion.program_helper import find_argument_group +from facefusion.thread_helper import thread_semaphore +from facefusion.types import ApplyStateItem, Args, DownloadScope, Face, InferencePool, ModelOptions, ModelSet, ProcessMode, QueuePayload, UpdateProgress, VisionFrame +from facefusion.vision import match_frame_color, read_image, read_static_image, write_image + + +@lru_cache(maxsize = None) +def create_static_model_set(download_scope : DownloadScope) -> ModelSet: + return\ + { + 'styleganex_age': + { + 'hashes': + { + 'age_modifier': + { + 'url': resolve_download_url('models-3.1.0', 'styleganex_age.hash'), + 'path': resolve_relative_path('../.assets/models/styleganex_age.hash') + } + }, + 'sources': + { + 'age_modifier': + { + 'url': resolve_download_url('models-3.1.0', 'styleganex_age.onnx'), + 'path': resolve_relative_path('../.assets/models/styleganex_age.onnx') + } + }, + 'templates': + { + 'target': 'ffhq_512', + 'target_with_background': 'styleganex_384' + }, + 'sizes': + { + 'target': (256, 256), + 'target_with_background': (384, 384) + } + } + } + + +def get_inference_pool() -> InferencePool: + model_names = [ state_manager.get_item('age_modifier_model') ] + model_source_set = get_model_options().get('sources') + + return inference_manager.get_inference_pool(__name__, model_names, model_source_set) + + +def clear_inference_pool() -> None: + model_names = [ state_manager.get_item('age_modifier_model') ] + inference_manager.clear_inference_pool(__name__, model_names) + + +def get_model_options() -> ModelOptions: + model_name = state_manager.get_item('age_modifier_model') + return create_static_model_set('full').get(model_name) + + +def register_args(program : ArgumentParser) -> None: + group_processors = find_argument_group(program, 'processors') + if group_processors: + group_processors.add_argument('--age-modifier-model', help = wording.get('help.age_modifier_model'), default = config.get_str_value('processors', 'age_modifier_model', 'styleganex_age'), choices = processors_choices.age_modifier_models) + group_processors.add_argument('--age-modifier-direction', help = wording.get('help.age_modifier_direction'), type = int, default = config.get_int_value('processors', 'age_modifier_direction', '0'), choices = processors_choices.age_modifier_direction_range, metavar = create_int_metavar(processors_choices.age_modifier_direction_range)) + facefusion.jobs.job_store.register_step_keys([ 'age_modifier_model', 'age_modifier_direction' ]) + + +def apply_args(args : Args, apply_state_item : ApplyStateItem) -> None: + apply_state_item('age_modifier_model', args.get('age_modifier_model')) + apply_state_item('age_modifier_direction', args.get('age_modifier_direction')) + + +def pre_check() -> bool: + model_hash_set = get_model_options().get('hashes') + model_source_set = get_model_options().get('sources') + + return conditional_download_hashes(model_hash_set) and conditional_download_sources(model_source_set) + + +def pre_process(mode : ProcessMode) -> bool: + if mode in [ 'output', 'preview' ] and not is_image(state_manager.get_item('target_path')) and not is_video(state_manager.get_item('target_path')): + logger.error(wording.get('choose_image_or_video_target') + wording.get('exclamation_mark'), __name__) + return False + if mode == 'output' and not 
in_directory(state_manager.get_item('output_path')): + logger.error(wording.get('specify_image_or_video_output') + wording.get('exclamation_mark'), __name__) + return False + if mode == 'output' and not same_file_extension(state_manager.get_item('target_path'), state_manager.get_item('output_path')): + logger.error(wording.get('match_target_and_output_extension') + wording.get('exclamation_mark'), __name__) + return False + return True + + +def post_process() -> None: + read_static_image.cache_clear() + video_manager.clear_video_pool() + if state_manager.get_item('video_memory_strategy') in [ 'strict', 'moderate' ]: + clear_inference_pool() + if state_manager.get_item('video_memory_strategy') == 'strict': + content_analyser.clear_inference_pool() + face_classifier.clear_inference_pool() + face_detector.clear_inference_pool() + face_landmarker.clear_inference_pool() + face_masker.clear_inference_pool() + face_recognizer.clear_inference_pool() + + +def modify_age(target_face : Face, temp_vision_frame : VisionFrame) -> VisionFrame: + model_templates = get_model_options().get('templates') + model_sizes = get_model_options().get('sizes') + face_landmark_5 = target_face.landmark_set.get('5/68').copy() + crop_vision_frame, affine_matrix = warp_face_by_face_landmark_5(temp_vision_frame, face_landmark_5, model_templates.get('target'), model_sizes.get('target')) + extend_face_landmark_5 = scale_face_landmark_5(face_landmark_5, 0.875) + extend_vision_frame, extend_affine_matrix = warp_face_by_face_landmark_5(temp_vision_frame, extend_face_landmark_5, model_templates.get('target_with_background'), model_sizes.get('target_with_background')) + extend_vision_frame_raw = extend_vision_frame.copy() + box_mask = create_box_mask(extend_vision_frame, state_manager.get_item('face_mask_blur'), (0, 0, 0, 0)) + crop_masks =\ + [ + box_mask + ] + + if 'occlusion' in state_manager.get_item('face_mask_types'): + occlusion_mask = create_occlusion_mask(crop_vision_frame) + combined_matrix = merge_matrix([ extend_affine_matrix, cv2.invertAffineTransform(affine_matrix) ]) + occlusion_mask = cv2.warpAffine(occlusion_mask, combined_matrix, model_sizes.get('target_with_background')) + crop_masks.append(occlusion_mask) + + crop_vision_frame = prepare_vision_frame(crop_vision_frame) + extend_vision_frame = prepare_vision_frame(extend_vision_frame) + age_modifier_direction = numpy.array(numpy.interp(state_manager.get_item('age_modifier_direction'), [ -100, 100 ], [ 2.5, -2.5 ])).astype(numpy.float32) + extend_vision_frame = forward(crop_vision_frame, extend_vision_frame, age_modifier_direction) + extend_vision_frame = normalize_extend_frame(extend_vision_frame) + extend_vision_frame = match_frame_color(extend_vision_frame_raw, extend_vision_frame) + extend_affine_matrix *= (model_sizes.get('target')[0] * 4) / model_sizes.get('target_with_background')[0] + crop_mask = numpy.minimum.reduce(crop_masks).clip(0, 1) + crop_mask = cv2.resize(crop_mask, (model_sizes.get('target')[0] * 4, model_sizes.get('target')[1] * 4)) + paste_vision_frame = paste_back(temp_vision_frame, extend_vision_frame, crop_mask, extend_affine_matrix) + return paste_vision_frame + + +def forward(crop_vision_frame : VisionFrame, extend_vision_frame : VisionFrame, age_modifier_direction : AgeModifierDirection) -> VisionFrame: + age_modifier = get_inference_pool().get('age_modifier') + age_modifier_inputs = {} + + if has_execution_provider('coreml'): + age_modifier.set_providers([ facefusion.choices.execution_provider_set.get('cpu') ]) + + for 
age_modifier_input in age_modifier.get_inputs(): + if age_modifier_input.name == 'target': + age_modifier_inputs[age_modifier_input.name] = crop_vision_frame + if age_modifier_input.name == 'target_with_background': + age_modifier_inputs[age_modifier_input.name] = extend_vision_frame + if age_modifier_input.name == 'direction': + age_modifier_inputs[age_modifier_input.name] = age_modifier_direction + + with thread_semaphore(): + crop_vision_frame = age_modifier.run(None, age_modifier_inputs)[0][0] + + return crop_vision_frame + + +def prepare_vision_frame(vision_frame : VisionFrame) -> VisionFrame: + vision_frame = vision_frame[:, :, ::-1] / 255.0 + vision_frame = (vision_frame - 0.5) / 0.5 + vision_frame = numpy.expand_dims(vision_frame.transpose(2, 0, 1), axis = 0).astype(numpy.float32) + return vision_frame + + +def normalize_extend_frame(extend_vision_frame : VisionFrame) -> VisionFrame: + model_sizes = get_model_options().get('sizes') + extend_vision_frame = numpy.clip(extend_vision_frame, -1, 1) + extend_vision_frame = (extend_vision_frame + 1) / 2 + extend_vision_frame = extend_vision_frame.transpose(1, 2, 0).clip(0, 255) + extend_vision_frame = (extend_vision_frame * 255.0) + extend_vision_frame = extend_vision_frame.astype(numpy.uint8)[:, :, ::-1] + extend_vision_frame = cv2.resize(extend_vision_frame, (model_sizes.get('target')[0] * 4, model_sizes.get('target')[1] * 4), interpolation = cv2.INTER_AREA) + return extend_vision_frame + + +def get_reference_frame(source_face : Face, target_face : Face, temp_vision_frame : VisionFrame) -> VisionFrame: + return modify_age(target_face, temp_vision_frame) + + +def process_frame(inputs : AgeModifierInputs) -> VisionFrame: + reference_faces = inputs.get('reference_faces') + target_vision_frame = inputs.get('target_vision_frame') + many_faces = sort_and_filter_faces(get_many_faces([ target_vision_frame ])) + + if state_manager.get_item('face_selector_mode') == 'many': + if many_faces: + for target_face in many_faces: + target_vision_frame = modify_age(target_face, target_vision_frame) + if state_manager.get_item('face_selector_mode') == 'one': + target_face = get_one_face(many_faces) + if target_face: + target_vision_frame = modify_age(target_face, target_vision_frame) + if state_manager.get_item('face_selector_mode') == 'reference': + similar_faces = find_similar_faces(many_faces, reference_faces, state_manager.get_item('reference_face_distance')) + if similar_faces: + for similar_face in similar_faces: + target_vision_frame = modify_age(similar_face, target_vision_frame) + return target_vision_frame + + +def process_frames(source_path : List[str], queue_payloads : List[QueuePayload], update_progress : UpdateProgress) -> None: + reference_faces = get_reference_faces() if 'reference' in state_manager.get_item('face_selector_mode') else None + + for queue_payload in process_manager.manage(queue_payloads): + target_vision_path = queue_payload['frame_path'] + target_vision_frame = read_image(target_vision_path) + output_vision_frame = process_frame( + { + 'reference_faces': reference_faces, + 'target_vision_frame': target_vision_frame + }) + write_image(target_vision_path, output_vision_frame) + update_progress(1) + + +def process_image(source_path : str, target_path : str, output_path : str) -> None: + reference_faces = get_reference_faces() if 'reference' in state_manager.get_item('face_selector_mode') else None + target_vision_frame = read_static_image(target_path) + output_vision_frame = process_frame( + { + 'reference_faces': 
reference_faces, + 'target_vision_frame': target_vision_frame + }) + write_image(output_path, output_vision_frame) + + +def process_video(source_paths : List[str], temp_frame_paths : List[str]) -> None: + processors.multi_process_frames(None, temp_frame_paths, process_frames) diff --git a/facefusion/processors/modules/deep_swapper.py b/facefusion/processors/modules/deep_swapper.py new file mode 100644 index 0000000000000000000000000000000000000000..1b1b35efc8c4433aa95c8bee23fe4176efa1a524 --- /dev/null +++ b/facefusion/processors/modules/deep_swapper.py @@ -0,0 +1,464 @@ +from argparse import ArgumentParser +from functools import lru_cache +from typing import List, Tuple + +import cv2 +import numpy +from cv2.typing import Size + +import facefusion.jobs.job_manager +import facefusion.jobs.job_store +import facefusion.processors.core as processors +from facefusion import config, content_analyser, face_classifier, face_detector, face_landmarker, face_masker, face_recognizer, inference_manager, logger, process_manager, state_manager, video_manager, wording +from facefusion.common_helper import create_int_metavar +from facefusion.download import conditional_download_hashes, conditional_download_sources, resolve_download_url_by_provider +from facefusion.face_analyser import get_many_faces, get_one_face +from facefusion.face_helper import paste_back, warp_face_by_face_landmark_5 +from facefusion.face_masker import create_area_mask, create_box_mask, create_occlusion_mask, create_region_mask +from facefusion.face_selector import find_similar_faces, sort_and_filter_faces +from facefusion.face_store import get_reference_faces +from facefusion.filesystem import get_file_name, in_directory, is_image, is_video, resolve_file_paths, resolve_relative_path, same_file_extension +from facefusion.processors import choices as processors_choices +from facefusion.processors.types import DeepSwapperInputs, DeepSwapperMorph +from facefusion.program_helper import find_argument_group +from facefusion.thread_helper import thread_semaphore +from facefusion.types import ApplyStateItem, Args, DownloadScope, Face, InferencePool, Mask, ModelOptions, ModelSet, ProcessMode, QueuePayload, UpdateProgress, VisionFrame +from facefusion.vision import conditional_match_frame_color, read_image, read_static_image, write_image + + +@lru_cache(maxsize = None) +def create_static_model_set(download_scope : DownloadScope) -> ModelSet: + model_config = [] + + if download_scope == 'full': + model_config.extend( + [ + ('druuzil', 'adam_levine_320'), + ('druuzil', 'adrianne_palicki_384'), + ('druuzil', 'agnetha_falskog_224'), + ('druuzil', 'alan_ritchson_320'), + ('druuzil', 'alicia_vikander_320'), + ('druuzil', 'amber_midthunder_320'), + ('druuzil', 'andras_arato_384'), + ('druuzil', 'andrew_tate_320'), + ('druuzil', 'angelina_jolie_384'), + ('druuzil', 'anne_hathaway_320'), + ('druuzil', 'anya_chalotra_320'), + ('druuzil', 'arnold_schwarzenegger_320'), + ('druuzil', 'benjamin_affleck_320'), + ('druuzil', 'benjamin_stiller_384'), + ('druuzil', 'bradley_pitt_224'), + ('druuzil', 'brie_larson_384'), + ('druuzil', 'bruce_campbell_384'), + ('druuzil', 'bryan_cranston_320'), + ('druuzil', 'catherine_blanchett_352'), + ('druuzil', 'christian_bale_320'), + ('druuzil', 'christopher_hemsworth_320'), + ('druuzil', 'christoph_waltz_384'), + ('druuzil', 'cillian_murphy_320'), + ('druuzil', 'cobie_smulders_256'), + ('druuzil', 'dwayne_johnson_384'), + ('druuzil', 'edward_norton_320'), + ('druuzil', 'elisabeth_shue_320'), + ('druuzil', 
'elizabeth_olsen_384'), + ('druuzil', 'elon_musk_320'), + ('druuzil', 'emily_blunt_320'), + ('druuzil', 'emma_stone_384'), + ('druuzil', 'emma_watson_320'), + ('druuzil', 'erin_moriarty_384'), + ('druuzil', 'eva_green_320'), + ('druuzil', 'ewan_mcgregor_320'), + ('druuzil', 'florence_pugh_320'), + ('druuzil', 'freya_allan_320'), + ('druuzil', 'gary_cole_224'), + ('druuzil', 'gigi_hadid_224'), + ('druuzil', 'harrison_ford_384'), + ('druuzil', 'hayden_christensen_320'), + ('druuzil', 'heath_ledger_320'), + ('druuzil', 'henry_cavill_448'), + ('druuzil', 'hugh_jackman_384'), + ('druuzil', 'idris_elba_320'), + ('druuzil', 'jack_nicholson_320'), + ('druuzil', 'james_carrey_384'), + ('druuzil', 'james_mcavoy_320'), + ('druuzil', 'james_varney_320'), + ('druuzil', 'jason_momoa_320'), + ('druuzil', 'jason_statham_320'), + ('druuzil', 'jennifer_connelly_384'), + ('druuzil', 'jimmy_donaldson_320'), + ('druuzil', 'jordan_peterson_384'), + ('druuzil', 'karl_urban_224'), + ('druuzil', 'kate_beckinsale_384'), + ('druuzil', 'laurence_fishburne_384'), + ('druuzil', 'lili_reinhart_320'), + ('druuzil', 'luke_evans_384'), + ('druuzil', 'mads_mikkelsen_384'), + ('druuzil', 'mary_winstead_320'), + ('druuzil', 'margaret_qualley_384'), + ('druuzil', 'melina_juergens_320'), + ('druuzil', 'michael_fassbender_320'), + ('druuzil', 'michael_fox_320'), + ('druuzil', 'millie_bobby_brown_320'), + ('druuzil', 'morgan_freeman_320'), + ('druuzil', 'patrick_stewart_224'), + ('druuzil', 'rachel_weisz_384'), + ('druuzil', 'rebecca_ferguson_320'), + ('druuzil', 'scarlett_johansson_320'), + ('druuzil', 'shannen_doherty_384'), + ('druuzil', 'seth_macfarlane_384'), + ('druuzil', 'thomas_cruise_320'), + ('druuzil', 'thomas_hanks_384'), + ('druuzil', 'william_murray_384'), + ('druuzil', 'zoe_saldana_384'), + ('edel', 'emma_roberts_224'), + ('edel', 'ivanka_trump_224'), + ('edel', 'lize_dzjabrailova_224'), + ('edel', 'sidney_sweeney_224'), + ('edel', 'winona_ryder_224') + ]) + if download_scope in [ 'lite', 'full' ]: + model_config.extend( + [ + ('iperov', 'alexandra_daddario_224'), + ('iperov', 'alexei_navalny_224'), + ('iperov', 'amber_heard_224'), + ('iperov', 'dilraba_dilmurat_224'), + ('iperov', 'elon_musk_224'), + ('iperov', 'emilia_clarke_224'), + ('iperov', 'emma_watson_224'), + ('iperov', 'erin_moriarty_224'), + ('iperov', 'jackie_chan_224'), + ('iperov', 'james_carrey_224'), + ('iperov', 'jason_statham_320'), + ('iperov', 'keanu_reeves_320'), + ('iperov', 'margot_robbie_224'), + ('iperov', 'natalie_dormer_224'), + ('iperov', 'nicolas_coppola_224'), + ('iperov', 'robert_downey_224'), + ('iperov', 'rowan_atkinson_224'), + ('iperov', 'ryan_reynolds_224'), + ('iperov', 'scarlett_johansson_224'), + ('iperov', 'sylvester_stallone_224'), + ('iperov', 'thomas_cruise_224'), + ('iperov', 'thomas_holland_224'), + ('iperov', 'vin_diesel_224'), + ('iperov', 'vladimir_putin_224') + ]) + if download_scope == 'full': + model_config.extend( + [ + ('jen', 'angelica_trae_288'), + ('jen', 'ella_freya_224'), + ('jen', 'emma_myers_320'), + ('jen', 'evie_pickerill_224'), + ('jen', 'kang_hyewon_320'), + ('jen', 'maddie_mead_224'), + ('jen', 'nicole_turnbull_288'), + ('mats', 'alica_schmidt_320'), + ('mats', 'ashley_alexiss_224'), + ('mats', 'billie_eilish_224'), + ('mats', 'brie_larson_224'), + ('mats', 'cara_delevingne_224'), + ('mats', 'carolin_kebekus_224'), + ('mats', 'chelsea_clinton_224'), + ('mats', 'claire_boucher_224'), + ('mats', 'corinna_kopf_224'), + ('mats', 'florence_pugh_224'), + ('mats', 'hillary_clinton_224'), + ('mats', 
'jenna_fischer_224'), + ('mats', 'kim_jisoo_320'), + ('mats', 'mica_suarez_320'), + ('mats', 'shailene_woodley_224'), + ('mats', 'shraddha_kapoor_320'), + ('mats', 'yu_jimin_352'), + ('rumateus', 'alison_brie_224'), + ('rumateus', 'amber_heard_224'), + ('rumateus', 'angelina_jolie_224'), + ('rumateus', 'aubrey_plaza_224'), + ('rumateus', 'bridget_regan_224'), + ('rumateus', 'cobie_smulders_224'), + ('rumateus', 'deborah_woll_224'), + ('rumateus', 'dua_lipa_224'), + ('rumateus', 'emma_stone_224'), + ('rumateus', 'hailee_steinfeld_224'), + ('rumateus', 'hilary_duff_224'), + ('rumateus', 'jessica_alba_224'), + ('rumateus', 'jessica_biel_224'), + ('rumateus', 'john_cena_224'), + ('rumateus', 'kim_kardashian_224'), + ('rumateus', 'kristen_bell_224'), + ('rumateus', 'lucy_liu_224'), + ('rumateus', 'margot_robbie_224'), + ('rumateus', 'megan_fox_224'), + ('rumateus', 'meghan_markle_224'), + ('rumateus', 'millie_bobby_brown_224'), + ('rumateus', 'natalie_portman_224'), + ('rumateus', 'nicki_minaj_224'), + ('rumateus', 'olivia_wilde_224'), + ('rumateus', 'shay_mitchell_224'), + ('rumateus', 'sophie_turner_224'), + ('rumateus', 'taylor_swift_224') + ]) + model_set : ModelSet = {} + + for model_scope, model_name in model_config: + model_id = '/'.join([ model_scope, model_name ]) + + model_set[model_id] =\ + { + 'hashes': + { + 'deep_swapper': + { + 'url': resolve_download_url_by_provider('huggingface', 'deepfacelive-models-' + model_scope, model_name + '.hash'), + 'path': resolve_relative_path('../.assets/models/' + model_scope + '/' + model_name + '.hash') + } + }, + 'sources': + { + 'deep_swapper': + { + 'url': resolve_download_url_by_provider('huggingface', 'deepfacelive-models-' + model_scope, model_name + '.dfm'), + 'path': resolve_relative_path('../.assets/models/' + model_scope + '/' + model_name + '.dfm') + } + }, + 'template': 'dfl_whole_face' + } + + custom_model_file_paths = resolve_file_paths(resolve_relative_path('../.assets/models/custom')) + + if custom_model_file_paths: + + for model_file_path in custom_model_file_paths: + model_id = '/'.join([ 'custom', get_file_name(model_file_path) ]) + + model_set[model_id] =\ + { + 'sources': + { + 'deep_swapper': + { + 'path': resolve_relative_path(model_file_path) + } + }, + 'template': 'dfl_whole_face' + } + + return model_set + + +def get_inference_pool() -> InferencePool: + model_names = [ state_manager.get_item('deep_swapper_model') ] + model_source_set = get_model_options().get('sources') + + return inference_manager.get_inference_pool(__name__, model_names, model_source_set) + + +def clear_inference_pool() -> None: + model_names = [ state_manager.get_item('deep_swapper_model') ] + inference_manager.clear_inference_pool(__name__, model_names) + + +def get_model_options() -> ModelOptions: + model_name = state_manager.get_item('deep_swapper_model') + return create_static_model_set('full').get(model_name) + + +def get_model_size() -> Size: + deep_swapper = get_inference_pool().get('deep_swapper') + + for deep_swapper_input in deep_swapper.get_inputs(): + if deep_swapper_input.name == 'in_face:0': + return deep_swapper_input.shape[1:3] + + return 0, 0 + + +def register_args(program : ArgumentParser) -> None: + group_processors = find_argument_group(program, 'processors') + if group_processors: + group_processors.add_argument('--deep-swapper-model', help = wording.get('help.deep_swapper_model'), default = config.get_str_value('processors', 'deep_swapper_model', 'iperov/elon_musk_224'), choices = processors_choices.deep_swapper_models) + 
group_processors.add_argument('--deep-swapper-morph', help = wording.get('help.deep_swapper_morph'), type = int, default = config.get_int_value('processors', 'deep_swapper_morph', '100'), choices = processors_choices.deep_swapper_morph_range, metavar = create_int_metavar(processors_choices.deep_swapper_morph_range)) + facefusion.jobs.job_store.register_step_keys([ 'deep_swapper_model', 'deep_swapper_morph' ]) + + +def apply_args(args : Args, apply_state_item : ApplyStateItem) -> None: + apply_state_item('deep_swapper_model', args.get('deep_swapper_model')) + apply_state_item('deep_swapper_morph', args.get('deep_swapper_morph')) + + +def pre_check() -> bool: + model_hash_set = get_model_options().get('hashes') + model_source_set = get_model_options().get('sources') + + if model_hash_set and model_source_set: + return conditional_download_hashes(model_hash_set) and conditional_download_sources(model_source_set) + return True + + +def pre_process(mode : ProcessMode) -> bool: + if mode in [ 'output', 'preview' ] and not is_image(state_manager.get_item('target_path')) and not is_video(state_manager.get_item('target_path')): + logger.error(wording.get('choose_image_or_video_target') + wording.get('exclamation_mark'), __name__) + return False + if mode == 'output' and not in_directory(state_manager.get_item('output_path')): + logger.error(wording.get('specify_image_or_video_output') + wording.get('exclamation_mark'), __name__) + return False + if mode == 'output' and not same_file_extension(state_manager.get_item('target_path'), state_manager.get_item('output_path')): + logger.error(wording.get('match_target_and_output_extension') + wording.get('exclamation_mark'), __name__) + return False + return True + + +def post_process() -> None: + read_static_image.cache_clear() + video_manager.clear_video_pool() + if state_manager.get_item('video_memory_strategy') in [ 'strict', 'moderate' ]: + clear_inference_pool() + if state_manager.get_item('video_memory_strategy') == 'strict': + content_analyser.clear_inference_pool() + face_classifier.clear_inference_pool() + face_detector.clear_inference_pool() + face_landmarker.clear_inference_pool() + face_masker.clear_inference_pool() + face_recognizer.clear_inference_pool() + + +def swap_face(target_face : Face, temp_vision_frame : VisionFrame) -> VisionFrame: + model_template = get_model_options().get('template') + model_size = get_model_size() + crop_vision_frame, affine_matrix = warp_face_by_face_landmark_5(temp_vision_frame, target_face.landmark_set.get('5/68'), model_template, model_size) + crop_vision_frame_raw = crop_vision_frame.copy() + box_mask = create_box_mask(crop_vision_frame, state_manager.get_item('face_mask_blur'), state_manager.get_item('face_mask_padding')) + crop_masks =\ + [ + box_mask + ] + + if 'occlusion' in state_manager.get_item('face_mask_types'): + occlusion_mask = create_occlusion_mask(crop_vision_frame) + crop_masks.append(occlusion_mask) + + crop_vision_frame = prepare_crop_frame(crop_vision_frame) + deep_swapper_morph = numpy.array([ numpy.interp(state_manager.get_item('deep_swapper_morph'), [ 0, 100 ], [ 0, 1 ]) ]).astype(numpy.float32) + crop_vision_frame, crop_source_mask, crop_target_mask = forward(crop_vision_frame, deep_swapper_morph) + crop_vision_frame = normalize_crop_frame(crop_vision_frame) + crop_vision_frame = conditional_match_frame_color(crop_vision_frame_raw, crop_vision_frame) + crop_masks.append(prepare_crop_mask(crop_source_mask, crop_target_mask)) + + if 'area' in state_manager.get_item('face_mask_types'): + 
face_landmark_68 = cv2.transform(target_face.landmark_set.get('68').reshape(1, -1, 2), affine_matrix).reshape(-1, 2) + area_mask = create_area_mask(crop_vision_frame, face_landmark_68, state_manager.get_item('face_mask_areas')) + crop_masks.append(area_mask) + + if 'region' in state_manager.get_item('face_mask_types'): + region_mask = create_region_mask(crop_vision_frame, state_manager.get_item('face_mask_regions')) + crop_masks.append(region_mask) + + crop_mask = numpy.minimum.reduce(crop_masks).clip(0, 1) + paste_vision_frame = paste_back(temp_vision_frame, crop_vision_frame, crop_mask, affine_matrix) + return paste_vision_frame + + +def forward(crop_vision_frame : VisionFrame, deep_swapper_morph : DeepSwapperMorph) -> Tuple[VisionFrame, Mask, Mask]: + deep_swapper = get_inference_pool().get('deep_swapper') + deep_swapper_inputs = {} + + for deep_swapper_input in deep_swapper.get_inputs(): + if deep_swapper_input.name == 'in_face:0': + deep_swapper_inputs[deep_swapper_input.name] = crop_vision_frame + if deep_swapper_input.name == 'morph_value:0': + deep_swapper_inputs[deep_swapper_input.name] = deep_swapper_morph + + with thread_semaphore(): + crop_target_mask, crop_vision_frame, crop_source_mask = deep_swapper.run(None, deep_swapper_inputs) + + return crop_vision_frame[0], crop_source_mask[0], crop_target_mask[0] + + +def has_morph_input() -> bool: + deep_swapper = get_inference_pool().get('deep_swapper') + + for deep_swapper_input in deep_swapper.get_inputs(): + if deep_swapper_input.name == 'morph_value:0': + return True + + return False + + +def prepare_crop_frame(crop_vision_frame : VisionFrame) -> VisionFrame: + crop_vision_frame = cv2.addWeighted(crop_vision_frame, 1.75, cv2.GaussianBlur(crop_vision_frame, (0, 0), 2), -0.75, 0) + crop_vision_frame = crop_vision_frame / 255.0 + crop_vision_frame = numpy.expand_dims(crop_vision_frame, axis = 0).astype(numpy.float32) + return crop_vision_frame + + +def normalize_crop_frame(crop_vision_frame : VisionFrame) -> VisionFrame: + crop_vision_frame = (crop_vision_frame * 255.0).clip(0, 255) + crop_vision_frame = crop_vision_frame.astype(numpy.uint8) + return crop_vision_frame + + +def prepare_crop_mask(crop_source_mask : Mask, crop_target_mask : Mask) -> Mask: + model_size = get_model_size() + blur_size = 6.25 + kernel_size = 3 + crop_mask = numpy.minimum.reduce([ crop_source_mask, crop_target_mask ]) + crop_mask = crop_mask.reshape(model_size).clip(0, 1) + crop_mask = cv2.erode(crop_mask, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size)), iterations = 2) + crop_mask = cv2.GaussianBlur(crop_mask, (0, 0), blur_size) + return crop_mask + + +def get_reference_frame(source_face : Face, target_face : Face, temp_vision_frame : VisionFrame) -> VisionFrame: + return swap_face(target_face, temp_vision_frame) + + +def process_frame(inputs : DeepSwapperInputs) -> VisionFrame: + reference_faces = inputs.get('reference_faces') + target_vision_frame = inputs.get('target_vision_frame') + many_faces = sort_and_filter_faces(get_many_faces([ target_vision_frame ])) + + if state_manager.get_item('face_selector_mode') == 'many': + if many_faces: + for target_face in many_faces: + target_vision_frame = swap_face(target_face, target_vision_frame) + if state_manager.get_item('face_selector_mode') == 'one': + target_face = get_one_face(many_faces) + if target_face: + target_vision_frame = swap_face(target_face, target_vision_frame) + if state_manager.get_item('face_selector_mode') == 'reference': + similar_faces = 
find_similar_faces(many_faces, reference_faces, state_manager.get_item('reference_face_distance')) + if similar_faces: + for similar_face in similar_faces: + target_vision_frame = swap_face(similar_face, target_vision_frame) + return target_vision_frame + + +def process_frames(source_path : List[str], queue_payloads : List[QueuePayload], update_progress : UpdateProgress) -> None: + reference_faces = get_reference_faces() if 'reference' in state_manager.get_item('face_selector_mode') else None + + for queue_payload in process_manager.manage(queue_payloads): + target_vision_path = queue_payload['frame_path'] + target_vision_frame = read_image(target_vision_path) + output_vision_frame = process_frame( + { + 'reference_faces': reference_faces, + 'target_vision_frame': target_vision_frame + }) + write_image(target_vision_path, output_vision_frame) + update_progress(1) + + +def process_image(source_path : str, target_path : str, output_path : str) -> None: + reference_faces = get_reference_faces() if 'reference' in state_manager.get_item('face_selector_mode') else None + target_vision_frame = read_static_image(target_path) + output_vision_frame = process_frame( + { + 'reference_faces': reference_faces, + 'target_vision_frame': target_vision_frame + }) + write_image(output_path, output_vision_frame) + + +def process_video(source_paths : List[str], temp_frame_paths : List[str]) -> None: + processors.multi_process_frames(None, temp_frame_paths, process_frames) diff --git a/facefusion/processors/modules/expression_restorer.py b/facefusion/processors/modules/expression_restorer.py new file mode 100644 index 0000000000000000000000000000000000000000..12ebc87e6dec221e86097f0be9b4ed0588024db8 --- /dev/null +++ b/facefusion/processors/modules/expression_restorer.py @@ -0,0 +1,298 @@ +from argparse import ArgumentParser +from functools import lru_cache +from typing import List, Tuple + +import cv2 +import numpy + +import facefusion.jobs.job_manager +import facefusion.jobs.job_store +import facefusion.processors.core as processors +from facefusion import config, content_analyser, face_classifier, face_detector, face_landmarker, face_masker, face_recognizer, inference_manager, logger, process_manager, state_manager, video_manager, wording +from facefusion.common_helper import create_int_metavar +from facefusion.download import conditional_download_hashes, conditional_download_sources, resolve_download_url +from facefusion.face_analyser import get_many_faces, get_one_face +from facefusion.face_helper import paste_back, warp_face_by_face_landmark_5 +from facefusion.face_masker import create_box_mask, create_occlusion_mask +from facefusion.face_selector import find_similar_faces, sort_and_filter_faces +from facefusion.face_store import get_reference_faces +from facefusion.filesystem import in_directory, is_image, is_video, resolve_relative_path, same_file_extension +from facefusion.processors import choices as processors_choices +from facefusion.processors.live_portrait import create_rotation, limit_expression +from facefusion.processors.types import ExpressionRestorerInputs, LivePortraitExpression, LivePortraitFeatureVolume, LivePortraitMotionPoints, LivePortraitPitch, LivePortraitRoll, LivePortraitScale, LivePortraitTranslation, LivePortraitYaw +from facefusion.program_helper import find_argument_group +from facefusion.thread_helper import conditional_thread_semaphore, thread_semaphore +from facefusion.types import ApplyStateItem, Args, DownloadScope, Face, InferencePool, ModelOptions, ModelSet, ProcessMode, 
QueuePayload, UpdateProgress, VisionFrame +from facefusion.vision import read_image, read_static_image, read_video_frame, write_image + + +@lru_cache(maxsize = None) +def create_static_model_set(download_scope : DownloadScope) -> ModelSet: + return\ + { + 'live_portrait': + { + 'hashes': + { + 'feature_extractor': + { + 'url': resolve_download_url('models-3.0.0', 'live_portrait_feature_extractor.hash'), + 'path': resolve_relative_path('../.assets/models/live_portrait_feature_extractor.hash') + }, + 'motion_extractor': + { + 'url': resolve_download_url('models-3.0.0', 'live_portrait_motion_extractor.hash'), + 'path': resolve_relative_path('../.assets/models/live_portrait_motion_extractor.hash') + }, + 'generator': + { + 'url': resolve_download_url('models-3.0.0', 'live_portrait_generator.hash'), + 'path': resolve_relative_path('../.assets/models/live_portrait_generator.hash') + } + }, + 'sources': + { + 'feature_extractor': + { + 'url': resolve_download_url('models-3.0.0', 'live_portrait_feature_extractor.onnx'), + 'path': resolve_relative_path('../.assets/models/live_portrait_feature_extractor.onnx') + }, + 'motion_extractor': + { + 'url': resolve_download_url('models-3.0.0', 'live_portrait_motion_extractor.onnx'), + 'path': resolve_relative_path('../.assets/models/live_portrait_motion_extractor.onnx') + }, + 'generator': + { + 'url': resolve_download_url('models-3.0.0', 'live_portrait_generator.onnx'), + 'path': resolve_relative_path('../.assets/models/live_portrait_generator.onnx') + } + }, + 'template': 'arcface_128', + 'size': (512, 512) + } + } + + +def get_inference_pool() -> InferencePool: + model_names = [ state_manager.get_item('expression_restorer_model') ] + model_source_set = get_model_options().get('sources') + + return inference_manager.get_inference_pool(__name__, model_names, model_source_set) + + +def clear_inference_pool() -> None: + model_names = [ state_manager.get_item('expression_restorer_model') ] + inference_manager.clear_inference_pool(__name__, model_names) + + +def get_model_options() -> ModelOptions: + model_name = state_manager.get_item('expression_restorer_model') + return create_static_model_set('full').get(model_name) + + +def register_args(program : ArgumentParser) -> None: + group_processors = find_argument_group(program, 'processors') + if group_processors: + group_processors.add_argument('--expression-restorer-model', help = wording.get('help.expression_restorer_model'), default = config.get_str_value('processors', 'expression_restorer_model', 'live_portrait'), choices = processors_choices.expression_restorer_models) + group_processors.add_argument('--expression-restorer-factor', help = wording.get('help.expression_restorer_factor'), type = int, default = config.get_int_value('processors', 'expression_restorer_factor', '80'), choices = processors_choices.expression_restorer_factor_range, metavar = create_int_metavar(processors_choices.expression_restorer_factor_range)) + facefusion.jobs.job_store.register_step_keys([ 'expression_restorer_model', 'expression_restorer_factor' ]) + + +def apply_args(args : Args, apply_state_item : ApplyStateItem) -> None: + apply_state_item('expression_restorer_model', args.get('expression_restorer_model')) + apply_state_item('expression_restorer_factor', args.get('expression_restorer_factor')) + + +def pre_check() -> bool: + model_hash_set = get_model_options().get('hashes') + model_source_set = get_model_options().get('sources') + + return conditional_download_hashes(model_hash_set) and 
conditional_download_sources(model_source_set) + + +def pre_process(mode : ProcessMode) -> bool: + if mode == 'stream': + logger.error(wording.get('stream_not_supported') + wording.get('exclamation_mark'), __name__) + return False + if mode in [ 'output', 'preview' ] and not is_image(state_manager.get_item('target_path')) and not is_video(state_manager.get_item('target_path')): + logger.error(wording.get('choose_image_or_video_target') + wording.get('exclamation_mark'), __name__) + return False + if mode == 'output' and not in_directory(state_manager.get_item('output_path')): + logger.error(wording.get('specify_image_or_video_output') + wording.get('exclamation_mark'), __name__) + return False + if mode == 'output' and not same_file_extension(state_manager.get_item('target_path'), state_manager.get_item('output_path')): + logger.error(wording.get('match_target_and_output_extension') + wording.get('exclamation_mark'), __name__) + return False + return True + + +def post_process() -> None: + read_static_image.cache_clear() + video_manager.clear_video_pool() + if state_manager.get_item('video_memory_strategy') in [ 'strict', 'moderate' ]: + clear_inference_pool() + if state_manager.get_item('video_memory_strategy') == 'strict': + content_analyser.clear_inference_pool() + face_classifier.clear_inference_pool() + face_detector.clear_inference_pool() + face_landmarker.clear_inference_pool() + face_masker.clear_inference_pool() + face_recognizer.clear_inference_pool() + + +def restore_expression(source_vision_frame : VisionFrame, target_face : Face, temp_vision_frame : VisionFrame) -> VisionFrame: + model_template = get_model_options().get('template') + model_size = get_model_options().get('size') + expression_restorer_factor = float(numpy.interp(float(state_manager.get_item('expression_restorer_factor')), [ 0, 100 ], [ 0, 1.2 ])) + source_vision_frame = cv2.resize(source_vision_frame, temp_vision_frame.shape[:2][::-1]) + source_crop_vision_frame, _ = warp_face_by_face_landmark_5(source_vision_frame, target_face.landmark_set.get('5/68'), model_template, model_size) + target_crop_vision_frame, affine_matrix = warp_face_by_face_landmark_5(temp_vision_frame, target_face.landmark_set.get('5/68'), model_template, model_size) + box_mask = create_box_mask(target_crop_vision_frame, state_manager.get_item('face_mask_blur'), (0, 0, 0, 0)) + crop_masks =\ + [ + box_mask + ] + + if 'occlusion' in state_manager.get_item('face_mask_types'): + occlusion_mask = create_occlusion_mask(target_crop_vision_frame) + crop_masks.append(occlusion_mask) + + source_crop_vision_frame = prepare_crop_frame(source_crop_vision_frame) + target_crop_vision_frame = prepare_crop_frame(target_crop_vision_frame) + target_crop_vision_frame = apply_restore(source_crop_vision_frame, target_crop_vision_frame, expression_restorer_factor) + target_crop_vision_frame = normalize_crop_frame(target_crop_vision_frame) + crop_mask = numpy.minimum.reduce(crop_masks).clip(0, 1) + temp_vision_frame = paste_back(temp_vision_frame, target_crop_vision_frame, crop_mask, affine_matrix) + return temp_vision_frame + + +def apply_restore(source_crop_vision_frame : VisionFrame, target_crop_vision_frame : VisionFrame, expression_restorer_factor : float) -> VisionFrame: + feature_volume = forward_extract_feature(target_crop_vision_frame) + source_expression = forward_extract_motion(source_crop_vision_frame)[5] + pitch, yaw, roll, scale, translation, target_expression, motion_points = forward_extract_motion(target_crop_vision_frame) + rotation = 
create_rotation(pitch, yaw, roll) + source_expression[:, [ 0, 4, 5, 8, 9 ]] = target_expression[:, [ 0, 4, 5, 8, 9 ]] + source_expression = source_expression * expression_restorer_factor + target_expression * (1 - expression_restorer_factor) + source_expression = limit_expression(source_expression) + source_motion_points = scale * (motion_points @ rotation.T + source_expression) + translation + target_motion_points = scale * (motion_points @ rotation.T + target_expression) + translation + crop_vision_frame = forward_generate_frame(feature_volume, source_motion_points, target_motion_points) + return crop_vision_frame + + +def forward_extract_feature(crop_vision_frame : VisionFrame) -> LivePortraitFeatureVolume: + feature_extractor = get_inference_pool().get('feature_extractor') + + with conditional_thread_semaphore(): + feature_volume = feature_extractor.run(None, + { + 'input': crop_vision_frame + })[0] + + return feature_volume + + +def forward_extract_motion(crop_vision_frame : VisionFrame) -> Tuple[LivePortraitPitch, LivePortraitYaw, LivePortraitRoll, LivePortraitScale, LivePortraitTranslation, LivePortraitExpression, LivePortraitMotionPoints]: + motion_extractor = get_inference_pool().get('motion_extractor') + + with conditional_thread_semaphore(): + pitch, yaw, roll, scale, translation, expression, motion_points = motion_extractor.run(None, + { + 'input': crop_vision_frame + }) + + return pitch, yaw, roll, scale, translation, expression, motion_points + + +def forward_generate_frame(feature_volume : LivePortraitFeatureVolume, source_motion_points : LivePortraitMotionPoints, target_motion_points : LivePortraitMotionPoints) -> VisionFrame: + generator = get_inference_pool().get('generator') + + with thread_semaphore(): + crop_vision_frame = generator.run(None, + { + 'feature_volume': feature_volume, + 'source': source_motion_points, + 'target': target_motion_points + })[0][0] + + return crop_vision_frame + + +def prepare_crop_frame(crop_vision_frame : VisionFrame) -> VisionFrame: + model_size = get_model_options().get('size') + prepare_size = (model_size[0] // 2, model_size[1] // 2) + crop_vision_frame = cv2.resize(crop_vision_frame, prepare_size, interpolation = cv2.INTER_AREA) + crop_vision_frame = crop_vision_frame[:, :, ::-1] / 255.0 + crop_vision_frame = numpy.expand_dims(crop_vision_frame.transpose(2, 0, 1), axis = 0).astype(numpy.float32) + return crop_vision_frame + + +def normalize_crop_frame(crop_vision_frame : VisionFrame) -> VisionFrame: + crop_vision_frame = crop_vision_frame.transpose(1, 2, 0).clip(0, 1) + crop_vision_frame = crop_vision_frame * 255.0 + crop_vision_frame = crop_vision_frame.astype(numpy.uint8)[:, :, ::-1] + return crop_vision_frame + + +def get_reference_frame(source_face : Face, target_face : Face, temp_vision_frame : VisionFrame) -> VisionFrame: + pass + + +def process_frame(inputs : ExpressionRestorerInputs) -> VisionFrame: + reference_faces = inputs.get('reference_faces') + source_vision_frame = inputs.get('source_vision_frame') + target_vision_frame = inputs.get('target_vision_frame') + many_faces = sort_and_filter_faces(get_many_faces([ target_vision_frame ])) + + if state_manager.get_item('face_selector_mode') == 'many': + if many_faces: + for target_face in many_faces: + target_vision_frame = restore_expression(source_vision_frame, target_face, target_vision_frame) + if state_manager.get_item('face_selector_mode') == 'one': + target_face = get_one_face(many_faces) + if target_face: + target_vision_frame = restore_expression(source_vision_frame, 
target_face, target_vision_frame) + if state_manager.get_item('face_selector_mode') == 'reference': + similar_faces = find_similar_faces(many_faces, reference_faces, state_manager.get_item('reference_face_distance')) + if similar_faces: + for similar_face in similar_faces: + target_vision_frame = restore_expression(source_vision_frame, similar_face, target_vision_frame) + return target_vision_frame + + +def process_frames(source_path : List[str], queue_payloads : List[QueuePayload], update_progress : UpdateProgress) -> None: + reference_faces = get_reference_faces() if 'reference' in state_manager.get_item('face_selector_mode') else None + + for queue_payload in process_manager.manage(queue_payloads): + frame_number = queue_payload.get('frame_number') + if state_manager.get_item('trim_frame_start'): + frame_number += state_manager.get_item('trim_frame_start') + source_vision_frame = read_video_frame(state_manager.get_item('target_path'), frame_number) + target_vision_path = queue_payload.get('frame_path') + target_vision_frame = read_image(target_vision_path) + output_vision_frame = process_frame( + { + 'reference_faces': reference_faces, + 'source_vision_frame': source_vision_frame, + 'target_vision_frame': target_vision_frame + }) + write_image(target_vision_path, output_vision_frame) + update_progress(1) + + +def process_image(source_path : str, target_path : str, output_path : str) -> None: + reference_faces = get_reference_faces() if 'reference' in state_manager.get_item('face_selector_mode') else None + source_vision_frame = read_static_image(state_manager.get_item('target_path')) + target_vision_frame = read_static_image(target_path) + output_vision_frame = process_frame( + { + 'reference_faces': reference_faces, + 'source_vision_frame': source_vision_frame, + 'target_vision_frame': target_vision_frame + }) + write_image(output_path, output_vision_frame) + + +def process_video(source_paths : List[str], temp_frame_paths : List[str]) -> None: + processors.multi_process_frames(None, temp_frame_paths, process_frames) diff --git a/facefusion/processors/modules/face_debugger.py b/facefusion/processors/modules/face_debugger.py new file mode 100644 index 0000000000000000000000000000000000000000..2402182fdac5c40ea9ba947a7e810b58b6443a38 --- /dev/null +++ b/facefusion/processors/modules/face_debugger.py @@ -0,0 +1,228 @@ +from argparse import ArgumentParser +from typing import List + +import cv2 +import numpy + +import facefusion.jobs.job_manager +import facefusion.jobs.job_store +import facefusion.processors.core as processors +from facefusion import config, content_analyser, face_classifier, face_detector, face_landmarker, face_masker, face_recognizer, logger, process_manager, state_manager, video_manager, wording +from facefusion.face_analyser import get_many_faces, get_one_face +from facefusion.face_helper import warp_face_by_face_landmark_5 +from facefusion.face_masker import create_area_mask, create_box_mask, create_occlusion_mask, create_region_mask +from facefusion.face_selector import find_similar_faces, sort_and_filter_faces +from facefusion.face_store import get_reference_faces +from facefusion.filesystem import in_directory, same_file_extension +from facefusion.processors import choices as processors_choices +from facefusion.processors.types import FaceDebuggerInputs +from facefusion.program_helper import find_argument_group +from facefusion.types import ApplyStateItem, Args, Face, InferencePool, ProcessMode, QueuePayload, UpdateProgress, VisionFrame +from facefusion.vision import 
read_image, read_static_image, write_image + + +def get_inference_pool() -> InferencePool: + pass + + +def clear_inference_pool() -> None: + pass + + +def register_args(program : ArgumentParser) -> None: + group_processors = find_argument_group(program, 'processors') + if group_processors: + group_processors.add_argument('--face-debugger-items', help = wording.get('help.face_debugger_items').format(choices = ', '.join(processors_choices.face_debugger_items)), default = config.get_str_list('processors', 'face_debugger_items', 'face-landmark-5/68 face-mask'), choices = processors_choices.face_debugger_items, nargs = '+', metavar = 'FACE_DEBUGGER_ITEMS') + facefusion.jobs.job_store.register_step_keys([ 'face_debugger_items' ]) + + +def apply_args(args : Args, apply_state_item : ApplyStateItem) -> None: + apply_state_item('face_debugger_items', args.get('face_debugger_items')) + + +def pre_check() -> bool: + return True + + +def pre_process(mode : ProcessMode) -> bool: + if mode == 'output' and not in_directory(state_manager.get_item('output_path')): + logger.error(wording.get('specify_image_or_video_output') + wording.get('exclamation_mark'), __name__) + return False + if mode == 'output' and not same_file_extension(state_manager.get_item('target_path'), state_manager.get_item('output_path')): + logger.error(wording.get('match_target_and_output_extension') + wording.get('exclamation_mark'), __name__) + return False + return True + + +def post_process() -> None: + read_static_image.cache_clear() + video_manager.clear_video_pool() + if state_manager.get_item('video_memory_strategy') == 'strict': + content_analyser.clear_inference_pool() + face_classifier.clear_inference_pool() + face_detector.clear_inference_pool() + face_landmarker.clear_inference_pool() + face_masker.clear_inference_pool() + face_recognizer.clear_inference_pool() + + +def debug_face(target_face : Face, temp_vision_frame : VisionFrame) -> VisionFrame: + primary_color = (0, 0, 255) + primary_light_color = (100, 100, 255) + secondary_color = (0, 255, 0) + tertiary_color = (255, 255, 0) + bounding_box = target_face.bounding_box.astype(numpy.int32) + temp_vision_frame = temp_vision_frame.copy() + has_face_landmark_5_fallback = numpy.array_equal(target_face.landmark_set.get('5'), target_face.landmark_set.get('5/68')) + has_face_landmark_68_fallback = numpy.array_equal(target_face.landmark_set.get('68'), target_face.landmark_set.get('68/5')) + face_debugger_items = state_manager.get_item('face_debugger_items') + + if 'bounding-box' in face_debugger_items: + x1, y1, x2, y2 = bounding_box + cv2.rectangle(temp_vision_frame, (x1, y1), (x2, y2), primary_color, 2) + + if target_face.angle == 0: + cv2.line(temp_vision_frame, (x1, y1), (x2, y1), primary_light_color, 3) + if target_face.angle == 180: + cv2.line(temp_vision_frame, (x1, y2), (x2, y2), primary_light_color, 3) + if target_face.angle == 90: + cv2.line(temp_vision_frame, (x2, y1), (x2, y2), primary_light_color, 3) + if target_face.angle == 270: + cv2.line(temp_vision_frame, (x1, y1), (x1, y2), primary_light_color, 3) + + if 'face-mask' in face_debugger_items: + crop_vision_frame, affine_matrix = warp_face_by_face_landmark_5(temp_vision_frame, target_face.landmark_set.get('5/68'), 'arcface_128', (512, 512)) + inverse_matrix = cv2.invertAffineTransform(affine_matrix) + temp_size = temp_vision_frame.shape[:2][::-1] + crop_masks = [] + + if 'box' in state_manager.get_item('face_mask_types'): + box_mask = create_box_mask(crop_vision_frame, 0, 
state_manager.get_item('face_mask_padding')) + crop_masks.append(box_mask) + + if 'occlusion' in state_manager.get_item('face_mask_types'): + occlusion_mask = create_occlusion_mask(crop_vision_frame) + crop_masks.append(occlusion_mask) + + if 'area' in state_manager.get_item('face_mask_types'): + face_landmark_68 = cv2.transform(target_face.landmark_set.get('68').reshape(1, -1, 2), affine_matrix).reshape(-1, 2) + area_mask = create_area_mask(crop_vision_frame, face_landmark_68, state_manager.get_item('face_mask_areas')) + crop_masks.append(area_mask) + + if 'region' in state_manager.get_item('face_mask_types'): + region_mask = create_region_mask(crop_vision_frame, state_manager.get_item('face_mask_regions')) + crop_masks.append(region_mask) + + crop_mask = numpy.minimum.reduce(crop_masks).clip(0, 1) + crop_mask = (crop_mask * 255).astype(numpy.uint8) + inverse_vision_frame = cv2.warpAffine(crop_mask, inverse_matrix, temp_size) + inverse_vision_frame = cv2.threshold(inverse_vision_frame, 100, 255, cv2.THRESH_BINARY)[1] + inverse_vision_frame[inverse_vision_frame > 0] = 255 #type:ignore[operator] + inverse_contours = cv2.findContours(inverse_vision_frame, cv2.RETR_LIST, cv2.CHAIN_APPROX_NONE)[0] + cv2.drawContours(temp_vision_frame, inverse_contours, -1, tertiary_color if has_face_landmark_5_fallback else secondary_color, 2) + + if 'face-landmark-5' in face_debugger_items and numpy.any(target_face.landmark_set.get('5')): + face_landmark_5 = target_face.landmark_set.get('5').astype(numpy.int32) + for index in range(face_landmark_5.shape[0]): + cv2.circle(temp_vision_frame, (face_landmark_5[index][0], face_landmark_5[index][1]), 3, primary_color, -1) + + if 'face-landmark-5/68' in face_debugger_items and numpy.any(target_face.landmark_set.get('5/68')): + face_landmark_5_68 = target_face.landmark_set.get('5/68').astype(numpy.int32) + for index in range(face_landmark_5_68.shape[0]): + cv2.circle(temp_vision_frame, (face_landmark_5_68[index][0], face_landmark_5_68[index][1]), 3, tertiary_color if has_face_landmark_5_fallback else secondary_color, -1) + + if 'face-landmark-68' in face_debugger_items and numpy.any(target_face.landmark_set.get('68')): + face_landmark_68 = target_face.landmark_set.get('68').astype(numpy.int32) + for index in range(face_landmark_68.shape[0]): + cv2.circle(temp_vision_frame, (face_landmark_68[index][0], face_landmark_68[index][1]), 3, tertiary_color if has_face_landmark_68_fallback else secondary_color, -1) + + if 'face-landmark-68/5' in face_debugger_items and numpy.any(target_face.landmark_set.get('68')): + face_landmark_68 = target_face.landmark_set.get('68/5').astype(numpy.int32) + for index in range(face_landmark_68.shape[0]): + cv2.circle(temp_vision_frame, (face_landmark_68[index][0], face_landmark_68[index][1]), 3, tertiary_color, -1) + + if bounding_box[3] - bounding_box[1] > 50 and bounding_box[2] - bounding_box[0] > 50: + top = bounding_box[1] + left = bounding_box[0] - 20 + + if 'face-detector-score' in face_debugger_items: + face_score_text = str(round(target_face.score_set.get('detector'), 2)) + top = top + 20 + cv2.putText(temp_vision_frame, face_score_text, (left, top), cv2.FONT_HERSHEY_SIMPLEX, 0.5, primary_color, 2) + + if 'face-landmarker-score' in face_debugger_items: + face_score_text = str(round(target_face.score_set.get('landmarker'), 2)) + top = top + 20 + cv2.putText(temp_vision_frame, face_score_text, (left, top), cv2.FONT_HERSHEY_SIMPLEX, 0.5, tertiary_color if has_face_landmark_5_fallback else secondary_color, 2) + + if 'age' in 
face_debugger_items: + face_age_text = str(target_face.age.start) + '-' + str(target_face.age.stop) + top = top + 20 + cv2.putText(temp_vision_frame, face_age_text, (left, top), cv2.FONT_HERSHEY_SIMPLEX, 0.5, primary_color, 2) + + if 'gender' in face_debugger_items: + face_gender_text = target_face.gender + top = top + 20 + cv2.putText(temp_vision_frame, face_gender_text, (left, top), cv2.FONT_HERSHEY_SIMPLEX, 0.5, primary_color, 2) + + if 'race' in face_debugger_items: + face_race_text = target_face.race + top = top + 20 + cv2.putText(temp_vision_frame, face_race_text, (left, top), cv2.FONT_HERSHEY_SIMPLEX, 0.5, primary_color, 2) + + return temp_vision_frame + + +def get_reference_frame(source_face : Face, target_face : Face, temp_vision_frame : VisionFrame) -> VisionFrame: + pass + + +def process_frame(inputs : FaceDebuggerInputs) -> VisionFrame: + reference_faces = inputs.get('reference_faces') + target_vision_frame = inputs.get('target_vision_frame') + many_faces = sort_and_filter_faces(get_many_faces([ target_vision_frame ])) + + if state_manager.get_item('face_selector_mode') == 'many': + if many_faces: + for target_face in many_faces: + target_vision_frame = debug_face(target_face, target_vision_frame) + if state_manager.get_item('face_selector_mode') == 'one': + target_face = get_one_face(many_faces) + if target_face: + target_vision_frame = debug_face(target_face, target_vision_frame) + if state_manager.get_item('face_selector_mode') == 'reference': + similar_faces = find_similar_faces(many_faces, reference_faces, state_manager.get_item('reference_face_distance')) + if similar_faces: + for similar_face in similar_faces: + target_vision_frame = debug_face(similar_face, target_vision_frame) + return target_vision_frame + + +def process_frames(source_paths : List[str], queue_payloads : List[QueuePayload], update_progress : UpdateProgress) -> None: + reference_faces = get_reference_faces() if 'reference' in state_manager.get_item('face_selector_mode') else None + + for queue_payload in process_manager.manage(queue_payloads): + target_vision_path = queue_payload['frame_path'] + target_vision_frame = read_image(target_vision_path) + output_vision_frame = process_frame( + { + 'reference_faces': reference_faces, + 'target_vision_frame': target_vision_frame + }) + write_image(target_vision_path, output_vision_frame) + update_progress(1) + + +def process_image(source_paths : List[str], target_path : str, output_path : str) -> None: + reference_faces = get_reference_faces() if 'reference' in state_manager.get_item('face_selector_mode') else None + target_vision_frame = read_static_image(target_path) + output_vision_frame = process_frame( + { + 'reference_faces': reference_faces, + 'target_vision_frame': target_vision_frame + }) + write_image(output_path, output_vision_frame) + + +def process_video(source_paths : List[str], temp_frame_paths : List[str]) -> None: + processors.multi_process_frames(source_paths, temp_frame_paths, process_frames) diff --git a/facefusion/processors/modules/face_editor.py b/facefusion/processors/modules/face_editor.py new file mode 100644 index 0000000000000000000000000000000000000000..f567b8759c3681bb594e77e0d090d47fc04149d0 --- /dev/null +++ b/facefusion/processors/modules/face_editor.py @@ -0,0 +1,533 @@ +from argparse import ArgumentParser +from functools import lru_cache +from typing import List, Tuple + +import cv2 +import numpy + +import facefusion.jobs.job_manager +import facefusion.jobs.job_store +import facefusion.processors.core as processors +from 
facefusion import config, content_analyser, face_classifier, face_detector, face_landmarker, face_masker, face_recognizer, inference_manager, logger, process_manager, state_manager, video_manager, wording +from facefusion.common_helper import create_float_metavar +from facefusion.download import conditional_download_hashes, conditional_download_sources, resolve_download_url +from facefusion.face_analyser import get_many_faces, get_one_face +from facefusion.face_helper import paste_back, scale_face_landmark_5, warp_face_by_face_landmark_5 +from facefusion.face_masker import create_box_mask +from facefusion.face_selector import find_similar_faces, sort_and_filter_faces +from facefusion.face_store import get_reference_faces +from facefusion.filesystem import in_directory, is_image, is_video, resolve_relative_path, same_file_extension +from facefusion.processors import choices as processors_choices +from facefusion.processors.live_portrait import create_rotation, limit_euler_angles, limit_expression +from facefusion.processors.types import FaceEditorInputs, LivePortraitExpression, LivePortraitFeatureVolume, LivePortraitMotionPoints, LivePortraitPitch, LivePortraitRoll, LivePortraitRotation, LivePortraitScale, LivePortraitTranslation, LivePortraitYaw +from facefusion.program_helper import find_argument_group +from facefusion.thread_helper import conditional_thread_semaphore, thread_semaphore +from facefusion.types import ApplyStateItem, Args, DownloadScope, Face, FaceLandmark68, InferencePool, ModelOptions, ModelSet, ProcessMode, QueuePayload, UpdateProgress, VisionFrame +from facefusion.vision import read_image, read_static_image, write_image + + +@lru_cache(maxsize = None) +def create_static_model_set(download_scope : DownloadScope) -> ModelSet: + return\ + { + 'live_portrait': + { + 'hashes': + { + 'feature_extractor': + { + 'url': resolve_download_url('models-3.0.0', 'live_portrait_feature_extractor.hash'), + 'path': resolve_relative_path('../.assets/models/live_portrait_feature_extractor.hash') + }, + 'motion_extractor': + { + 'url': resolve_download_url('models-3.0.0', 'live_portrait_motion_extractor.hash'), + 'path': resolve_relative_path('../.assets/models/live_portrait_motion_extractor.hash') + }, + 'eye_retargeter': + { + 'url': resolve_download_url('models-3.0.0', 'live_portrait_eye_retargeter.hash'), + 'path': resolve_relative_path('../.assets/models/live_portrait_eye_retargeter.hash') + }, + 'lip_retargeter': + { + 'url': resolve_download_url('models-3.0.0', 'live_portrait_lip_retargeter.hash'), + 'path': resolve_relative_path('../.assets/models/live_portrait_lip_retargeter.hash') + }, + 'stitcher': + { + 'url': resolve_download_url('models-3.0.0', 'live_portrait_stitcher.hash'), + 'path': resolve_relative_path('../.assets/models/live_portrait_stitcher.hash') + }, + 'generator': + { + 'url': resolve_download_url('models-3.0.0', 'live_portrait_generator.hash'), + 'path': resolve_relative_path('../.assets/models/live_portrait_generator.hash') + } + }, + 'sources': + { + 'feature_extractor': + { + 'url': resolve_download_url('models-3.0.0', 'live_portrait_feature_extractor.onnx'), + 'path': resolve_relative_path('../.assets/models/live_portrait_feature_extractor.onnx') + }, + 'motion_extractor': + { + 'url': resolve_download_url('models-3.0.0', 'live_portrait_motion_extractor.onnx'), + 'path': resolve_relative_path('../.assets/models/live_portrait_motion_extractor.onnx') + }, + 'eye_retargeter': + { + 'url': resolve_download_url('models-3.0.0', 'live_portrait_eye_retargeter.onnx'), + 
'path': resolve_relative_path('../.assets/models/live_portrait_eye_retargeter.onnx') + }, + 'lip_retargeter': + { + 'url': resolve_download_url('models-3.0.0', 'live_portrait_lip_retargeter.onnx'), + 'path': resolve_relative_path('../.assets/models/live_portrait_lip_retargeter.onnx') + }, + 'stitcher': + { + 'url': resolve_download_url('models-3.0.0', 'live_portrait_stitcher.onnx'), + 'path': resolve_relative_path('../.assets/models/live_portrait_stitcher.onnx') + }, + 'generator': + { + 'url': resolve_download_url('models-3.0.0', 'live_portrait_generator.onnx'), + 'path': resolve_relative_path('../.assets/models/live_portrait_generator.onnx') + } + }, + 'template': 'ffhq_512', + 'size': (512, 512) + } + } + + +def get_inference_pool() -> InferencePool: + model_names = [ state_manager.get_item('face_editor_model') ] + model_source_set = get_model_options().get('sources') + + return inference_manager.get_inference_pool(__name__, model_names, model_source_set) + + +def clear_inference_pool() -> None: + model_names = [ state_manager.get_item('face_editor_model') ] + inference_manager.clear_inference_pool(__name__, model_names) + + +def get_model_options() -> ModelOptions: + model_name = state_manager.get_item('face_editor_model') + return create_static_model_set('full').get(model_name) + + +def register_args(program : ArgumentParser) -> None: + group_processors = find_argument_group(program, 'processors') + if group_processors: + group_processors.add_argument('--face-editor-model', help = wording.get('help.face_editor_model'), default = config.get_str_value('processors', 'face_editor_model', 'live_portrait'), choices = processors_choices.face_editor_models) + group_processors.add_argument('--face-editor-eyebrow-direction', help = wording.get('help.face_editor_eyebrow_direction'), type = float, default = config.get_float_value('processors', 'face_editor_eyebrow_direction', '0'), choices = processors_choices.face_editor_eyebrow_direction_range, metavar = create_float_metavar(processors_choices.face_editor_eyebrow_direction_range)) + group_processors.add_argument('--face-editor-eye-gaze-horizontal', help = wording.get('help.face_editor_eye_gaze_horizontal'), type = float, default = config.get_float_value('processors', 'face_editor_eye_gaze_horizontal', '0'), choices = processors_choices.face_editor_eye_gaze_horizontal_range, metavar = create_float_metavar(processors_choices.face_editor_eye_gaze_horizontal_range)) + group_processors.add_argument('--face-editor-eye-gaze-vertical', help = wording.get('help.face_editor_eye_gaze_vertical'), type = float, default = config.get_float_value('processors', 'face_editor_eye_gaze_vertical', '0'), choices = processors_choices.face_editor_eye_gaze_vertical_range, metavar = create_float_metavar(processors_choices.face_editor_eye_gaze_vertical_range)) + group_processors.add_argument('--face-editor-eye-open-ratio', help = wording.get('help.face_editor_eye_open_ratio'), type = float, default = config.get_float_value('processors', 'face_editor_eye_open_ratio', '0'), choices = processors_choices.face_editor_eye_open_ratio_range, metavar = create_float_metavar(processors_choices.face_editor_eye_open_ratio_range)) + group_processors.add_argument('--face-editor-lip-open-ratio', help = wording.get('help.face_editor_lip_open_ratio'), type = float, default = config.get_float_value('processors', 'face_editor_lip_open_ratio', '0'), choices = processors_choices.face_editor_lip_open_ratio_range, metavar = 
create_float_metavar(processors_choices.face_editor_lip_open_ratio_range)) + group_processors.add_argument('--face-editor-mouth-grim', help = wording.get('help.face_editor_mouth_grim'), type = float, default = config.get_float_value('processors', 'face_editor_mouth_grim', '0'), choices = processors_choices.face_editor_mouth_grim_range, metavar = create_float_metavar(processors_choices.face_editor_mouth_grim_range)) + group_processors.add_argument('--face-editor-mouth-pout', help = wording.get('help.face_editor_mouth_pout'), type = float, default = config.get_float_value('processors', 'face_editor_mouth_pout', '0'), choices = processors_choices.face_editor_mouth_pout_range, metavar = create_float_metavar(processors_choices.face_editor_mouth_pout_range)) + group_processors.add_argument('--face-editor-mouth-purse', help = wording.get('help.face_editor_mouth_purse'), type = float, default = config.get_float_value('processors', 'face_editor_mouth_purse', '0'), choices = processors_choices.face_editor_mouth_purse_range, metavar = create_float_metavar(processors_choices.face_editor_mouth_purse_range)) + group_processors.add_argument('--face-editor-mouth-smile', help = wording.get('help.face_editor_mouth_smile'), type = float, default = config.get_float_value('processors', 'face_editor_mouth_smile', '0'), choices = processors_choices.face_editor_mouth_smile_range, metavar = create_float_metavar(processors_choices.face_editor_mouth_smile_range)) + group_processors.add_argument('--face-editor-mouth-position-horizontal', help = wording.get('help.face_editor_mouth_position_horizontal'), type = float, default = config.get_float_value('processors', 'face_editor_mouth_position_horizontal', '0'), choices = processors_choices.face_editor_mouth_position_horizontal_range, metavar = create_float_metavar(processors_choices.face_editor_mouth_position_horizontal_range)) + group_processors.add_argument('--face-editor-mouth-position-vertical', help = wording.get('help.face_editor_mouth_position_vertical'), type = float, default = config.get_float_value('processors', 'face_editor_mouth_position_vertical', '0'), choices = processors_choices.face_editor_mouth_position_vertical_range, metavar = create_float_metavar(processors_choices.face_editor_mouth_position_vertical_range)) + group_processors.add_argument('--face-editor-head-pitch', help = wording.get('help.face_editor_head_pitch'), type = float, default = config.get_float_value('processors', 'face_editor_head_pitch', '0'), choices = processors_choices.face_editor_head_pitch_range, metavar = create_float_metavar(processors_choices.face_editor_head_pitch_range)) + group_processors.add_argument('--face-editor-head-yaw', help = wording.get('help.face_editor_head_yaw'), type = float, default = config.get_float_value('processors', 'face_editor_head_yaw', '0'), choices = processors_choices.face_editor_head_yaw_range, metavar = create_float_metavar(processors_choices.face_editor_head_yaw_range)) + group_processors.add_argument('--face-editor-head-roll', help = wording.get('help.face_editor_head_roll'), type = float, default = config.get_float_value('processors', 'face_editor_head_roll', '0'), choices = processors_choices.face_editor_head_roll_range, metavar = create_float_metavar(processors_choices.face_editor_head_roll_range)) + facefusion.jobs.job_store.register_step_keys([ 'face_editor_model', 'face_editor_eyebrow_direction', 'face_editor_eye_gaze_horizontal', 'face_editor_eye_gaze_vertical', 'face_editor_eye_open_ratio', 'face_editor_lip_open_ratio', 
'face_editor_mouth_grim', 'face_editor_mouth_pout', 'face_editor_mouth_purse', 'face_editor_mouth_smile', 'face_editor_mouth_position_horizontal', 'face_editor_mouth_position_vertical', 'face_editor_head_pitch', 'face_editor_head_yaw', 'face_editor_head_roll' ]) + + +def apply_args(args : Args, apply_state_item : ApplyStateItem) -> None: + apply_state_item('face_editor_model', args.get('face_editor_model')) + apply_state_item('face_editor_eyebrow_direction', args.get('face_editor_eyebrow_direction')) + apply_state_item('face_editor_eye_gaze_horizontal', args.get('face_editor_eye_gaze_horizontal')) + apply_state_item('face_editor_eye_gaze_vertical', args.get('face_editor_eye_gaze_vertical')) + apply_state_item('face_editor_eye_open_ratio', args.get('face_editor_eye_open_ratio')) + apply_state_item('face_editor_lip_open_ratio', args.get('face_editor_lip_open_ratio')) + apply_state_item('face_editor_mouth_grim', args.get('face_editor_mouth_grim')) + apply_state_item('face_editor_mouth_pout', args.get('face_editor_mouth_pout')) + apply_state_item('face_editor_mouth_purse', args.get('face_editor_mouth_purse')) + apply_state_item('face_editor_mouth_smile', args.get('face_editor_mouth_smile')) + apply_state_item('face_editor_mouth_position_horizontal', args.get('face_editor_mouth_position_horizontal')) + apply_state_item('face_editor_mouth_position_vertical', args.get('face_editor_mouth_position_vertical')) + apply_state_item('face_editor_head_pitch', args.get('face_editor_head_pitch')) + apply_state_item('face_editor_head_yaw', args.get('face_editor_head_yaw')) + apply_state_item('face_editor_head_roll', args.get('face_editor_head_roll')) + + +def pre_check() -> bool: + model_hash_set = get_model_options().get('hashes') + model_source_set = get_model_options().get('sources') + + return conditional_download_hashes(model_hash_set) and conditional_download_sources(model_source_set) + + +def pre_process(mode : ProcessMode) -> bool: + if mode in [ 'output', 'preview' ] and not is_image(state_manager.get_item('target_path')) and not is_video(state_manager.get_item('target_path')): + logger.error(wording.get('choose_image_or_video_target') + wording.get('exclamation_mark'), __name__) + return False + if mode == 'output' and not in_directory(state_manager.get_item('output_path')): + logger.error(wording.get('specify_image_or_video_output') + wording.get('exclamation_mark'), __name__) + return False + if mode == 'output' and not same_file_extension(state_manager.get_item('target_path'), state_manager.get_item('output_path')): + logger.error(wording.get('match_target_and_output_extension') + wording.get('exclamation_mark'), __name__) + return False + return True + + +def post_process() -> None: + read_static_image.cache_clear() + video_manager.clear_video_pool() + if state_manager.get_item('video_memory_strategy') in [ 'strict', 'moderate' ]: + clear_inference_pool() + if state_manager.get_item('video_memory_strategy') == 'strict': + content_analyser.clear_inference_pool() + face_classifier.clear_inference_pool() + face_detector.clear_inference_pool() + face_landmarker.clear_inference_pool() + face_masker.clear_inference_pool() + face_recognizer.clear_inference_pool() + + +def edit_face(target_face : Face, temp_vision_frame : VisionFrame) -> VisionFrame: + model_template = get_model_options().get('template') + model_size = get_model_options().get('size') + face_landmark_5 = scale_face_landmark_5(target_face.landmark_set.get('5/68'), 1.5) + crop_vision_frame, affine_matrix = 
warp_face_by_face_landmark_5(temp_vision_frame, face_landmark_5, model_template, model_size) + box_mask = create_box_mask(crop_vision_frame, state_manager.get_item('face_mask_blur'), (0, 0, 0, 0)) + crop_vision_frame = prepare_crop_frame(crop_vision_frame) + crop_vision_frame = apply_edit(crop_vision_frame, target_face.landmark_set.get('68')) + crop_vision_frame = normalize_crop_frame(crop_vision_frame) + temp_vision_frame = paste_back(temp_vision_frame, crop_vision_frame, box_mask, affine_matrix) + return temp_vision_frame + + +def apply_edit(crop_vision_frame : VisionFrame, face_landmark_68 : FaceLandmark68) -> VisionFrame: + feature_volume = forward_extract_feature(crop_vision_frame) + pitch, yaw, roll, scale, translation, expression, motion_points = forward_extract_motion(crop_vision_frame) + rotation = create_rotation(pitch, yaw, roll) + motion_points_target = scale * (motion_points @ rotation.T + expression) + translation + expression = edit_eye_gaze(expression) + expression = edit_mouth_grim(expression) + expression = edit_mouth_position(expression) + expression = edit_mouth_pout(expression) + expression = edit_mouth_purse(expression) + expression = edit_mouth_smile(expression) + expression = edit_eyebrow_direction(expression) + expression = limit_expression(expression) + rotation = edit_head_rotation(pitch, yaw, roll) + motion_points_source = motion_points @ rotation.T + motion_points_source += expression + motion_points_source *= scale + motion_points_source += translation + motion_points_source += edit_eye_open(motion_points_target, face_landmark_68) + motion_points_source += edit_lip_open(motion_points_target, face_landmark_68) + motion_points_source = forward_stitch_motion_points(motion_points_source, motion_points_target) + crop_vision_frame = forward_generate_frame(feature_volume, motion_points_source, motion_points_target) + return crop_vision_frame + + +def forward_extract_feature(crop_vision_frame : VisionFrame) -> LivePortraitFeatureVolume: + feature_extractor = get_inference_pool().get('feature_extractor') + + with conditional_thread_semaphore(): + feature_volume = feature_extractor.run(None, + { + 'input': crop_vision_frame + })[0] + + return feature_volume + + +def forward_extract_motion(crop_vision_frame : VisionFrame) -> Tuple[LivePortraitPitch, LivePortraitYaw, LivePortraitRoll, LivePortraitScale, LivePortraitTranslation, LivePortraitExpression, LivePortraitMotionPoints]: + motion_extractor = get_inference_pool().get('motion_extractor') + + with conditional_thread_semaphore(): + pitch, yaw, roll, scale, translation, expression, motion_points = motion_extractor.run(None, + { + 'input': crop_vision_frame + }) + + return pitch, yaw, roll, scale, translation, expression, motion_points + + +def forward_retarget_eye(eye_motion_points : LivePortraitMotionPoints) -> LivePortraitMotionPoints: + eye_retargeter = get_inference_pool().get('eye_retargeter') + + with conditional_thread_semaphore(): + eye_motion_points = eye_retargeter.run(None, + { + 'input': eye_motion_points + })[0] + + return eye_motion_points + + +def forward_retarget_lip(lip_motion_points : LivePortraitMotionPoints) -> LivePortraitMotionPoints: + lip_retargeter = get_inference_pool().get('lip_retargeter') + + with conditional_thread_semaphore(): + lip_motion_points = lip_retargeter.run(None, + { + 'input': lip_motion_points + })[0] + + return lip_motion_points + + +def forward_stitch_motion_points(source_motion_points : LivePortraitMotionPoints, target_motion_points : LivePortraitMotionPoints) -> 
LivePortraitMotionPoints: + stitcher = get_inference_pool().get('stitcher') + + with thread_semaphore(): + motion_points = stitcher.run(None, + { + 'source': source_motion_points, + 'target': target_motion_points + })[0] + + return motion_points + + +def forward_generate_frame(feature_volume : LivePortraitFeatureVolume, source_motion_points : LivePortraitMotionPoints, target_motion_points : LivePortraitMotionPoints) -> VisionFrame: + generator = get_inference_pool().get('generator') + + with thread_semaphore(): + crop_vision_frame = generator.run(None, + { + 'feature_volume': feature_volume, + 'source': source_motion_points, + 'target': target_motion_points + })[0][0] + + return crop_vision_frame + + +def edit_eyebrow_direction(expression : LivePortraitExpression) -> LivePortraitExpression: + face_editor_eyebrow = state_manager.get_item('face_editor_eyebrow_direction') + + if face_editor_eyebrow > 0: + expression[0, 1, 1] += numpy.interp(face_editor_eyebrow, [ -1, 1 ], [ -0.015, 0.015 ]) + expression[0, 2, 1] -= numpy.interp(face_editor_eyebrow, [ -1, 1 ], [ -0.020, 0.020 ]) + else: + expression[0, 1, 0] -= numpy.interp(face_editor_eyebrow, [ -1, 1 ], [ -0.015, 0.015 ]) + expression[0, 2, 0] += numpy.interp(face_editor_eyebrow, [ -1, 1 ], [ -0.020, 0.020 ]) + expression[0, 1, 1] += numpy.interp(face_editor_eyebrow, [ -1, 1 ], [ -0.005, 0.005 ]) + expression[0, 2, 1] -= numpy.interp(face_editor_eyebrow, [ -1, 1 ], [ -0.005, 0.005 ]) + return expression + + +def edit_eye_gaze(expression : LivePortraitExpression) -> LivePortraitExpression: + face_editor_eye_gaze_horizontal = state_manager.get_item('face_editor_eye_gaze_horizontal') + face_editor_eye_gaze_vertical = state_manager.get_item('face_editor_eye_gaze_vertical') + + if face_editor_eye_gaze_horizontal > 0: + expression[0, 11, 0] += numpy.interp(face_editor_eye_gaze_horizontal, [ -1, 1 ], [ -0.015, 0.015 ]) + expression[0, 15, 0] += numpy.interp(face_editor_eye_gaze_horizontal, [ -1, 1 ], [ -0.020, 0.020 ]) + else: + expression[0, 11, 0] += numpy.interp(face_editor_eye_gaze_horizontal, [ -1, 1 ], [ -0.020, 0.020 ]) + expression[0, 15, 0] += numpy.interp(face_editor_eye_gaze_horizontal, [ -1, 1 ], [ -0.015, 0.015 ]) + expression[0, 1, 1] += numpy.interp(face_editor_eye_gaze_vertical, [ -1, 1 ], [ -0.0025, 0.0025 ]) + expression[0, 2, 1] -= numpy.interp(face_editor_eye_gaze_vertical, [ -1, 1 ], [ -0.0025, 0.0025 ]) + expression[0, 11, 1] -= numpy.interp(face_editor_eye_gaze_vertical, [ -1, 1 ], [ -0.010, 0.010 ]) + expression[0, 13, 1] -= numpy.interp(face_editor_eye_gaze_vertical, [ -1, 1 ], [ -0.005, 0.005 ]) + expression[0, 15, 1] -= numpy.interp(face_editor_eye_gaze_vertical, [ -1, 1 ], [ -0.010, 0.010 ]) + expression[0, 16, 1] -= numpy.interp(face_editor_eye_gaze_vertical, [ -1, 1 ], [ -0.005, 0.005 ]) + return expression + + +def edit_eye_open(motion_points : LivePortraitMotionPoints, face_landmark_68 : FaceLandmark68) -> LivePortraitMotionPoints: + face_editor_eye_open_ratio = state_manager.get_item('face_editor_eye_open_ratio') + left_eye_ratio = calc_distance_ratio(face_landmark_68, 37, 40, 39, 36) + right_eye_ratio = calc_distance_ratio(face_landmark_68, 43, 46, 45, 42) + + if face_editor_eye_open_ratio < 0: + eye_motion_points = numpy.concatenate([ motion_points.ravel(), [ left_eye_ratio, right_eye_ratio, 0.0 ] ]) + else: + eye_motion_points = numpy.concatenate([ motion_points.ravel(), [ left_eye_ratio, right_eye_ratio, 0.6 ] ]) + eye_motion_points = eye_motion_points.reshape(1, -1).astype(numpy.float32) + eye_motion_points = 
forward_retarget_eye(eye_motion_points) * numpy.abs(face_editor_eye_open_ratio) + eye_motion_points = eye_motion_points.reshape(-1, 21, 3) + return eye_motion_points + + +def edit_lip_open(motion_points : LivePortraitMotionPoints, face_landmark_68 : FaceLandmark68) -> LivePortraitMotionPoints: + face_editor_lip_open_ratio = state_manager.get_item('face_editor_lip_open_ratio') + lip_ratio = calc_distance_ratio(face_landmark_68, 62, 66, 54, 48) + + if face_editor_lip_open_ratio < 0: + lip_motion_points = numpy.concatenate([ motion_points.ravel(), [ lip_ratio, 0.0 ] ]) + else: + lip_motion_points = numpy.concatenate([ motion_points.ravel(), [ lip_ratio, 1.0 ] ]) + lip_motion_points = lip_motion_points.reshape(1, -1).astype(numpy.float32) + lip_motion_points = forward_retarget_lip(lip_motion_points) * numpy.abs(face_editor_lip_open_ratio) + lip_motion_points = lip_motion_points.reshape(-1, 21, 3) + return lip_motion_points + + +def edit_mouth_grim(expression : LivePortraitExpression) -> LivePortraitExpression: + face_editor_mouth_grim = state_manager.get_item('face_editor_mouth_grim') + if face_editor_mouth_grim > 0: + expression[0, 17, 2] -= numpy.interp(face_editor_mouth_grim, [ -1, 1 ], [ -0.005, 0.005 ]) + expression[0, 19, 2] += numpy.interp(face_editor_mouth_grim, [ -1, 1 ], [ -0.01, 0.01 ]) + expression[0, 20, 1] -= numpy.interp(face_editor_mouth_grim, [ -1, 1 ], [ -0.06, 0.06 ]) + expression[0, 20, 2] -= numpy.interp(face_editor_mouth_grim, [ -1, 1 ], [ -0.03, 0.03 ]) + else: + expression[0, 19, 1] -= numpy.interp(face_editor_mouth_grim, [ -1, 1 ], [ -0.05, 0.05 ]) + expression[0, 19, 2] -= numpy.interp(face_editor_mouth_grim, [ -1, 1 ], [ -0.02, 0.02 ]) + expression[0, 20, 2] -= numpy.interp(face_editor_mouth_grim, [ -1, 1 ], [ -0.03, 0.03 ]) + return expression + + +def edit_mouth_position(expression : LivePortraitExpression) -> LivePortraitExpression: + face_editor_mouth_position_horizontal = state_manager.get_item('face_editor_mouth_position_horizontal') + face_editor_mouth_position_vertical = state_manager.get_item('face_editor_mouth_position_vertical') + expression[0, 19, 0] += numpy.interp(face_editor_mouth_position_horizontal, [ -1, 1 ], [ -0.05, 0.05 ]) + expression[0, 20, 0] += numpy.interp(face_editor_mouth_position_horizontal, [ -1, 1 ], [ -0.04, 0.04 ]) + if face_editor_mouth_position_vertical > 0: + expression[0, 19, 1] -= numpy.interp(face_editor_mouth_position_vertical, [ -1, 1 ], [ -0.04, 0.04 ]) + expression[0, 20, 1] -= numpy.interp(face_editor_mouth_position_vertical, [ -1, 1 ], [ -0.02, 0.02 ]) + else: + expression[0, 19, 1] -= numpy.interp(face_editor_mouth_position_vertical, [ -1, 1 ], [ -0.05, 0.05 ]) + expression[0, 20, 1] -= numpy.interp(face_editor_mouth_position_vertical, [ -1, 1 ], [ -0.04, 0.04 ]) + return expression + + +def edit_mouth_pout(expression : LivePortraitExpression) -> LivePortraitExpression: + face_editor_mouth_pout = state_manager.get_item('face_editor_mouth_pout') + if face_editor_mouth_pout > 0: + expression[0, 19, 1] -= numpy.interp(face_editor_mouth_pout, [ -1, 1 ], [ -0.022, 0.022 ]) + expression[0, 19, 2] += numpy.interp(face_editor_mouth_pout, [ -1, 1 ], [ -0.025, 0.025 ]) + expression[0, 20, 2] -= numpy.interp(face_editor_mouth_pout, [ -1, 1 ], [ -0.002, 0.002 ]) + else: + expression[0, 19, 1] += numpy.interp(face_editor_mouth_pout, [ -1, 1 ], [ -0.022, 0.022 ]) + expression[0, 19, 2] += numpy.interp(face_editor_mouth_pout, [ -1, 1 ], [ -0.025, 0.025 ]) + expression[0, 20, 2] -= numpy.interp(face_editor_mouth_pout, [ -1, 1 ], [ 
-0.002, 0.002 ]) + return expression + + +def edit_mouth_purse(expression : LivePortraitExpression) -> LivePortraitExpression: + face_editor_mouth_purse = state_manager.get_item('face_editor_mouth_purse') + if face_editor_mouth_purse > 0: + expression[0, 19, 1] -= numpy.interp(face_editor_mouth_purse, [ -1, 1 ], [ -0.04, 0.04 ]) + expression[0, 19, 2] -= numpy.interp(face_editor_mouth_purse, [ -1, 1 ], [ -0.02, 0.02 ]) + else: + expression[0, 14, 1] -= numpy.interp(face_editor_mouth_purse, [ -1, 1 ], [ -0.02, 0.02 ]) + expression[0, 17, 2] += numpy.interp(face_editor_mouth_purse, [ -1, 1 ], [ -0.01, 0.01 ]) + expression[0, 19, 2] -= numpy.interp(face_editor_mouth_purse, [ -1, 1 ], [ -0.015, 0.015 ]) + expression[0, 20, 2] -= numpy.interp(face_editor_mouth_purse, [ -1, 1 ], [ -0.002, 0.002 ]) + return expression + + +def edit_mouth_smile(expression : LivePortraitExpression) -> LivePortraitExpression: + face_editor_mouth_smile = state_manager.get_item('face_editor_mouth_smile') + if face_editor_mouth_smile > 0: + expression[0, 20, 1] -= numpy.interp(face_editor_mouth_smile, [ -1, 1 ], [ -0.015, 0.015 ]) + expression[0, 14, 1] -= numpy.interp(face_editor_mouth_smile, [ -1, 1 ], [ -0.025, 0.025 ]) + expression[0, 17, 1] += numpy.interp(face_editor_mouth_smile, [ -1, 1 ], [ -0.01, 0.01 ]) + expression[0, 17, 2] += numpy.interp(face_editor_mouth_smile, [ -1, 1 ], [ -0.004, 0.004 ]) + expression[0, 3, 1] -= numpy.interp(face_editor_mouth_smile, [ -1, 1 ], [ -0.0045, 0.0045 ]) + expression[0, 7, 1] -= numpy.interp(face_editor_mouth_smile, [ -1, 1 ], [ -0.0045, 0.0045 ]) + else: + expression[0, 14, 1] -= numpy.interp(face_editor_mouth_smile, [ -1, 1 ], [ -0.02, 0.02 ]) + expression[0, 17, 1] += numpy.interp(face_editor_mouth_smile, [ -1, 1 ], [ -0.003, 0.003 ]) + expression[0, 19, 1] += numpy.interp(face_editor_mouth_smile, [ -1, 1 ], [ -0.02, 0.02 ]) + expression[0, 19, 2] -= numpy.interp(face_editor_mouth_smile, [ -1, 1 ], [ -0.005, 0.005 ]) + expression[0, 20, 2] += numpy.interp(face_editor_mouth_smile, [ -1, 1 ], [ -0.01, 0.01 ]) + expression[0, 3, 1] += numpy.interp(face_editor_mouth_smile, [ -1, 1 ], [ -0.0045, 0.0045 ]) + expression[0, 7, 1] += numpy.interp(face_editor_mouth_smile, [ -1, 1 ], [ -0.0045, 0.0045 ]) + return expression + + +def edit_head_rotation(pitch : LivePortraitPitch, yaw : LivePortraitYaw, roll : LivePortraitRoll) -> LivePortraitRotation: + face_editor_head_pitch = state_manager.get_item('face_editor_head_pitch') + face_editor_head_yaw = state_manager.get_item('face_editor_head_yaw') + face_editor_head_roll = state_manager.get_item('face_editor_head_roll') + edit_pitch = pitch + float(numpy.interp(face_editor_head_pitch, [ -1, 1 ], [ 20, -20 ])) + edit_yaw = yaw + float(numpy.interp(face_editor_head_yaw, [ -1, 1 ], [ 60, -60 ])) + edit_roll = roll + float(numpy.interp(face_editor_head_roll, [ -1, 1 ], [ -15, 15 ])) + edit_pitch, edit_yaw, edit_roll = limit_euler_angles(pitch, yaw, roll, edit_pitch, edit_yaw, edit_roll) + rotation = create_rotation(edit_pitch, edit_yaw, edit_roll) + return rotation + + +def calc_distance_ratio(face_landmark_68 : FaceLandmark68, top_index : int, bottom_index : int, left_index : int, right_index : int) -> float: + vertical_direction = face_landmark_68[top_index] - face_landmark_68[bottom_index] + horizontal_direction = face_landmark_68[left_index] - face_landmark_68[right_index] + distance_ratio = float(numpy.linalg.norm(vertical_direction) / (numpy.linalg.norm(horizontal_direction) + 1e-6)) + return distance_ratio + + +def 
prepare_crop_frame(crop_vision_frame : VisionFrame) -> VisionFrame: + model_size = get_model_options().get('size') + prepare_size = (model_size[0] // 2, model_size[1] // 2) + crop_vision_frame = cv2.resize(crop_vision_frame, prepare_size, interpolation = cv2.INTER_AREA) + crop_vision_frame = crop_vision_frame[:, :, ::-1] / 255.0 + crop_vision_frame = numpy.expand_dims(crop_vision_frame.transpose(2, 0, 1), axis = 0).astype(numpy.float32) + return crop_vision_frame + + +def normalize_crop_frame(crop_vision_frame : VisionFrame) -> VisionFrame: + crop_vision_frame = crop_vision_frame.transpose(1, 2, 0).clip(0, 1) + crop_vision_frame = (crop_vision_frame * 255.0) + crop_vision_frame = crop_vision_frame.astype(numpy.uint8)[:, :, ::-1] + return crop_vision_frame + + +def get_reference_frame(source_face : Face, target_face : Face, temp_vision_frame : VisionFrame) -> VisionFrame: + pass + + +def process_frame(inputs : FaceEditorInputs) -> VisionFrame: + reference_faces = inputs.get('reference_faces') + target_vision_frame = inputs.get('target_vision_frame') + many_faces = sort_and_filter_faces(get_many_faces([ target_vision_frame ])) + + if state_manager.get_item('face_selector_mode') == 'many': + if many_faces: + for target_face in many_faces: + target_vision_frame = edit_face(target_face, target_vision_frame) + if state_manager.get_item('face_selector_mode') == 'one': + target_face = get_one_face(many_faces) + if target_face: + target_vision_frame = edit_face(target_face, target_vision_frame) + if state_manager.get_item('face_selector_mode') == 'reference': + similar_faces = find_similar_faces(many_faces, reference_faces, state_manager.get_item('reference_face_distance')) + if similar_faces: + for similar_face in similar_faces: + target_vision_frame = edit_face(similar_face, target_vision_frame) + return target_vision_frame + + +def process_frames(source_path : List[str], queue_payloads : List[QueuePayload], update_progress : UpdateProgress) -> None: + reference_faces = get_reference_faces() if 'reference' in state_manager.get_item('face_selector_mode') else None + + for queue_payload in process_manager.manage(queue_payloads): + target_vision_path = queue_payload['frame_path'] + target_vision_frame = read_image(target_vision_path) + output_vision_frame = process_frame( + { + 'reference_faces': reference_faces, + 'target_vision_frame': target_vision_frame + }) + write_image(target_vision_path, output_vision_frame) + update_progress(1) + + +def process_image(source_path : str, target_path : str, output_path : str) -> None: + reference_faces = get_reference_faces() if 'reference' in state_manager.get_item('face_selector_mode') else None + target_vision_frame = read_static_image(target_path) + output_vision_frame = process_frame( + { + 'reference_faces': reference_faces, + 'target_vision_frame': target_vision_frame + }) + write_image(output_path, output_vision_frame) + + +def process_video(source_paths : List[str], temp_frame_paths : List[str]) -> None: + processors.multi_process_frames(None, temp_frame_paths, process_frames) diff --git a/facefusion/processors/modules/face_enhancer.py b/facefusion/processors/modules/face_enhancer.py new file mode 100644 index 0000000000000000000000000000000000000000..dce358e4c3b8f8a3aa8268e8dff04b422b64e966 --- /dev/null +++ b/facefusion/processors/modules/face_enhancer.py @@ -0,0 +1,414 @@ +from argparse import ArgumentParser +from functools import lru_cache +from typing import List + +import cv2 +import numpy + +import facefusion.jobs.job_manager +import 
facefusion.jobs.job_store +import facefusion.processors.core as processors +from facefusion import config, content_analyser, face_classifier, face_detector, face_landmarker, face_masker, face_recognizer, inference_manager, logger, process_manager, state_manager, video_manager, wording +from facefusion.common_helper import create_float_metavar, create_int_metavar +from facefusion.download import conditional_download_hashes, conditional_download_sources, resolve_download_url +from facefusion.face_analyser import get_many_faces, get_one_face +from facefusion.face_helper import paste_back, warp_face_by_face_landmark_5 +from facefusion.face_masker import create_box_mask, create_occlusion_mask +from facefusion.face_selector import find_similar_faces, sort_and_filter_faces +from facefusion.face_store import get_reference_faces +from facefusion.filesystem import in_directory, is_image, is_video, resolve_relative_path, same_file_extension +from facefusion.processors import choices as processors_choices +from facefusion.processors.types import FaceEnhancerInputs, FaceEnhancerWeight +from facefusion.program_helper import find_argument_group +from facefusion.thread_helper import thread_semaphore +from facefusion.types import ApplyStateItem, Args, DownloadScope, Face, InferencePool, ModelOptions, ModelSet, ProcessMode, QueuePayload, UpdateProgress, VisionFrame +from facefusion.vision import read_image, read_static_image, write_image + + +@lru_cache(maxsize = None) +def create_static_model_set(download_scope : DownloadScope) -> ModelSet: + return\ + { + 'codeformer': + { + 'hashes': + { + 'face_enhancer': + { + 'url': resolve_download_url('models-3.0.0', 'codeformer.hash'), + 'path': resolve_relative_path('../.assets/models/codeformer.hash') + } + }, + 'sources': + { + 'face_enhancer': + { + 'url': resolve_download_url('models-3.0.0', 'codeformer.onnx'), + 'path': resolve_relative_path('../.assets/models/codeformer.onnx') + } + }, + 'template': 'ffhq_512', + 'size': (512, 512) + }, + 'gfpgan_1.2': + { + 'hashes': + { + 'face_enhancer': + { + 'url': resolve_download_url('models-3.0.0', 'gfpgan_1.2.hash'), + 'path': resolve_relative_path('../.assets/models/gfpgan_1.2.hash') + } + }, + 'sources': + { + 'face_enhancer': + { + 'url': resolve_download_url('models-3.0.0', 'gfpgan_1.2.onnx'), + 'path': resolve_relative_path('../.assets/models/gfpgan_1.2.onnx') + } + }, + 'template': 'ffhq_512', + 'size': (512, 512) + }, + 'gfpgan_1.3': + { + 'hashes': + { + 'face_enhancer': + { + 'url': resolve_download_url('models-3.0.0', 'gfpgan_1.3.hash'), + 'path': resolve_relative_path('../.assets/models/gfpgan_1.3.hash') + } + }, + 'sources': + { + 'face_enhancer': + { + 'url': resolve_download_url('models-3.0.0', 'gfpgan_1.3.onnx'), + 'path': resolve_relative_path('../.assets/models/gfpgan_1.3.onnx') + } + }, + 'template': 'ffhq_512', + 'size': (512, 512) + }, + 'gfpgan_1.4': + { + 'hashes': + { + 'face_enhancer': + { + 'url': resolve_download_url('models-3.0.0', 'gfpgan_1.4.hash'), + 'path': resolve_relative_path('../.assets/models/gfpgan_1.4.hash') + } + }, + 'sources': + { + 'face_enhancer': + { + 'url': resolve_download_url('models-3.0.0', 'gfpgan_1.4.onnx'), + 'path': resolve_relative_path('../.assets/models/gfpgan_1.4.onnx') + } + }, + 'template': 'ffhq_512', + 'size': (512, 512) + }, + 'gpen_bfr_256': + { + 'hashes': + { + 'face_enhancer': + { + 'url': resolve_download_url('models-3.0.0', 'gpen_bfr_256.hash'), + 'path': resolve_relative_path('../.assets/models/gpen_bfr_256.hash') + } + }, + 'sources': + { + 
'face_enhancer': + { + 'url': resolve_download_url('models-3.0.0', 'gpen_bfr_256.onnx'), + 'path': resolve_relative_path('../.assets/models/gpen_bfr_256.onnx') + } + }, + 'template': 'arcface_128', + 'size': (256, 256) + }, + 'gpen_bfr_512': + { + 'hashes': + { + 'face_enhancer': + { + 'url': resolve_download_url('models-3.0.0', 'gpen_bfr_512.hash'), + 'path': resolve_relative_path('../.assets/models/gpen_bfr_512.hash') + } + }, + 'sources': + { + 'face_enhancer': + { + 'url': resolve_download_url('models-3.0.0', 'gpen_bfr_512.onnx'), + 'path': resolve_relative_path('../.assets/models/gpen_bfr_512.onnx') + } + }, + 'template': 'ffhq_512', + 'size': (512, 512) + }, + 'gpen_bfr_1024': + { + 'hashes': + { + 'face_enhancer': + { + 'url': resolve_download_url('models-3.0.0', 'gpen_bfr_1024.hash'), + 'path': resolve_relative_path('../.assets/models/gpen_bfr_1024.hash') + } + }, + 'sources': + { + 'face_enhancer': + { + 'url': resolve_download_url('models-3.0.0', 'gpen_bfr_1024.onnx'), + 'path': resolve_relative_path('../.assets/models/gpen_bfr_1024.onnx') + } + }, + 'template': 'ffhq_512', + 'size': (1024, 1024) + }, + 'gpen_bfr_2048': + { + 'hashes': + { + 'face_enhancer': + { + 'url': resolve_download_url('models-3.0.0', 'gpen_bfr_2048.hash'), + 'path': resolve_relative_path('../.assets/models/gpen_bfr_2048.hash') + } + }, + 'sources': + { + 'face_enhancer': + { + 'url': resolve_download_url('models-3.0.0', 'gpen_bfr_2048.onnx'), + 'path': resolve_relative_path('../.assets/models/gpen_bfr_2048.onnx') + } + }, + 'template': 'ffhq_512', + 'size': (2048, 2048) + }, + 'restoreformer_plus_plus': + { + 'hashes': + { + 'face_enhancer': + { + 'url': resolve_download_url('models-3.0.0', 'restoreformer_plus_plus.hash'), + 'path': resolve_relative_path('../.assets/models/restoreformer_plus_plus.hash') + } + }, + 'sources': + { + 'face_enhancer': + { + 'url': resolve_download_url('models-3.0.0', 'restoreformer_plus_plus.onnx'), + 'path': resolve_relative_path('../.assets/models/restoreformer_plus_plus.onnx') + } + }, + 'template': 'ffhq_512', + 'size': (512, 512) + } + } + + +def get_inference_pool() -> InferencePool: + model_names = [ state_manager.get_item('face_enhancer_model') ] + model_source_set = get_model_options().get('sources') + + return inference_manager.get_inference_pool(__name__, model_names, model_source_set) + + +def clear_inference_pool() -> None: + model_names = [ state_manager.get_item('face_enhancer_model') ] + inference_manager.clear_inference_pool(__name__, model_names) + + +def get_model_options() -> ModelOptions: + model_name = state_manager.get_item('face_enhancer_model') + return create_static_model_set('full').get(model_name) + + +def register_args(program : ArgumentParser) -> None: + group_processors = find_argument_group(program, 'processors') + if group_processors: + group_processors.add_argument('--face-enhancer-model', help = wording.get('help.face_enhancer_model'), default = config.get_str_value('processors', 'face_enhancer_model', 'gfpgan_1.4'), choices = processors_choices.face_enhancer_models) + group_processors.add_argument('--face-enhancer-blend', help = wording.get('help.face_enhancer_blend'), type = int, default = config.get_int_value('processors', 'face_enhancer_blend', '80'), choices = processors_choices.face_enhancer_blend_range, metavar = create_int_metavar(processors_choices.face_enhancer_blend_range)) + group_processors.add_argument('--face-enhancer-weight', help = wording.get('help.face_enhancer_weight'), type = float, default = 
config.get_float_value('processors', 'face_enhancer_weight', '1.0'), choices = processors_choices.face_enhancer_weight_range, metavar = create_float_metavar(processors_choices.face_enhancer_weight_range)) + facefusion.jobs.job_store.register_step_keys([ 'face_enhancer_model', 'face_enhancer_blend', 'face_enhancer_weight' ]) + + +def apply_args(args : Args, apply_state_item : ApplyStateItem) -> None: + apply_state_item('face_enhancer_model', args.get('face_enhancer_model')) + apply_state_item('face_enhancer_blend', args.get('face_enhancer_blend')) + apply_state_item('face_enhancer_weight', args.get('face_enhancer_weight')) + + +def pre_check() -> bool: + model_hash_set = get_model_options().get('hashes') + model_source_set = get_model_options().get('sources') + + return conditional_download_hashes(model_hash_set) and conditional_download_sources(model_source_set) + + +def pre_process(mode : ProcessMode) -> bool: + if mode in [ 'output', 'preview' ] and not is_image(state_manager.get_item('target_path')) and not is_video(state_manager.get_item('target_path')): + logger.error(wording.get('choose_image_or_video_target') + wording.get('exclamation_mark'), __name__) + return False + if mode == 'output' and not in_directory(state_manager.get_item('output_path')): + logger.error(wording.get('specify_image_or_video_output') + wording.get('exclamation_mark'), __name__) + return False + if mode == 'output' and not same_file_extension(state_manager.get_item('target_path'), state_manager.get_item('output_path')): + logger.error(wording.get('match_target_and_output_extension') + wording.get('exclamation_mark'), __name__) + return False + return True + + +def post_process() -> None: + read_static_image.cache_clear() + video_manager.clear_video_pool() + if state_manager.get_item('video_memory_strategy') in [ 'strict', 'moderate' ]: + clear_inference_pool() + if state_manager.get_item('video_memory_strategy') == 'strict': + content_analyser.clear_inference_pool() + face_classifier.clear_inference_pool() + face_detector.clear_inference_pool() + face_landmarker.clear_inference_pool() + face_masker.clear_inference_pool() + face_recognizer.clear_inference_pool() + + +def enhance_face(target_face : Face, temp_vision_frame : VisionFrame) -> VisionFrame: + model_template = get_model_options().get('template') + model_size = get_model_options().get('size') + crop_vision_frame, affine_matrix = warp_face_by_face_landmark_5(temp_vision_frame, target_face.landmark_set.get('5/68'), model_template, model_size) + box_mask = create_box_mask(crop_vision_frame, state_manager.get_item('face_mask_blur'), (0, 0, 0, 0)) + crop_masks =\ + [ + box_mask + ] + + if 'occlusion' in state_manager.get_item('face_mask_types'): + occlusion_mask = create_occlusion_mask(crop_vision_frame) + crop_masks.append(occlusion_mask) + + crop_vision_frame = prepare_crop_frame(crop_vision_frame) + face_enhancer_weight = numpy.array([ state_manager.get_item('face_enhancer_weight') ]).astype(numpy.double) + crop_vision_frame = forward(crop_vision_frame, face_enhancer_weight) + crop_vision_frame = normalize_crop_frame(crop_vision_frame) + crop_mask = numpy.minimum.reduce(crop_masks).clip(0, 1) + paste_vision_frame = paste_back(temp_vision_frame, crop_vision_frame, crop_mask, affine_matrix) + temp_vision_frame = blend_frame(temp_vision_frame, paste_vision_frame) + return temp_vision_frame + + +def forward(crop_vision_frame : VisionFrame, face_enhancer_weight : FaceEnhancerWeight) -> VisionFrame: + face_enhancer = get_inference_pool().get('face_enhancer') 
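+	# assemble the input feed from the session's declared inputs, since only some enhancer models expose an optional 'weight' input alongside 'input'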
+ face_enhancer_inputs = {} + + for face_enhancer_input in face_enhancer.get_inputs(): + if face_enhancer_input.name == 'input': + face_enhancer_inputs[face_enhancer_input.name] = crop_vision_frame + if face_enhancer_input.name == 'weight': + face_enhancer_inputs[face_enhancer_input.name] = face_enhancer_weight + + with thread_semaphore(): + crop_vision_frame = face_enhancer.run(None, face_enhancer_inputs)[0][0] + + return crop_vision_frame + + +def has_weight_input() -> bool: + face_enhancer = get_inference_pool().get('face_enhancer') + + for deep_swapper_input in face_enhancer.get_inputs(): + if deep_swapper_input.name == 'weight': + return True + + return False + + +def prepare_crop_frame(crop_vision_frame : VisionFrame) -> VisionFrame: + crop_vision_frame = crop_vision_frame[:, :, ::-1] / 255.0 + crop_vision_frame = (crop_vision_frame - 0.5) / 0.5 + crop_vision_frame = numpy.expand_dims(crop_vision_frame.transpose(2, 0, 1), axis = 0).astype(numpy.float32) + return crop_vision_frame + + +def normalize_crop_frame(crop_vision_frame : VisionFrame) -> VisionFrame: + crop_vision_frame = numpy.clip(crop_vision_frame, -1, 1) + crop_vision_frame = (crop_vision_frame + 1) / 2 + crop_vision_frame = crop_vision_frame.transpose(1, 2, 0) + crop_vision_frame = (crop_vision_frame * 255.0).round() + crop_vision_frame = crop_vision_frame.astype(numpy.uint8)[:, :, ::-1] + return crop_vision_frame + + +def blend_frame(temp_vision_frame : VisionFrame, paste_vision_frame : VisionFrame) -> VisionFrame: + face_enhancer_blend = 1 - (state_manager.get_item('face_enhancer_blend') / 100) + temp_vision_frame = cv2.addWeighted(temp_vision_frame, face_enhancer_blend, paste_vision_frame, 1 - face_enhancer_blend, 0) + return temp_vision_frame + + +def get_reference_frame(source_face : Face, target_face : Face, temp_vision_frame : VisionFrame) -> VisionFrame: + return enhance_face(target_face, temp_vision_frame) + + +def process_frame(inputs : FaceEnhancerInputs) -> VisionFrame: + reference_faces = inputs.get('reference_faces') + target_vision_frame = inputs.get('target_vision_frame') + many_faces = sort_and_filter_faces(get_many_faces([ target_vision_frame ])) + + if state_manager.get_item('face_selector_mode') == 'many': + if many_faces: + for target_face in many_faces: + target_vision_frame = enhance_face(target_face, target_vision_frame) + if state_manager.get_item('face_selector_mode') == 'one': + target_face = get_one_face(many_faces) + if target_face: + target_vision_frame = enhance_face(target_face, target_vision_frame) + if state_manager.get_item('face_selector_mode') == 'reference': + similar_faces = find_similar_faces(many_faces, reference_faces, state_manager.get_item('reference_face_distance')) + if similar_faces: + for similar_face in similar_faces: + target_vision_frame = enhance_face(similar_face, target_vision_frame) + return target_vision_frame + + +def process_frames(source_path : List[str], queue_payloads : List[QueuePayload], update_progress : UpdateProgress) -> None: + reference_faces = get_reference_faces() if 'reference' in state_manager.get_item('face_selector_mode') else None + + for queue_payload in process_manager.manage(queue_payloads): + target_vision_path = queue_payload['frame_path'] + target_vision_frame = read_image(target_vision_path) + output_vision_frame = process_frame( + { + 'reference_faces': reference_faces, + 'target_vision_frame': target_vision_frame + }) + write_image(target_vision_path, output_vision_frame) + update_progress(1) + + +def process_image(source_path : str, 
target_path : str, output_path : str) -> None: + reference_faces = get_reference_faces() if 'reference' in state_manager.get_item('face_selector_mode') else None + target_vision_frame = read_static_image(target_path) + output_vision_frame = process_frame( + { + 'reference_faces': reference_faces, + 'target_vision_frame': target_vision_frame + }) + write_image(output_path, output_vision_frame) + + +def process_video(source_paths : List[str], temp_frame_paths : List[str]) -> None: + processors.multi_process_frames(None, temp_frame_paths, process_frames) diff --git a/facefusion/processors/modules/face_swapper.py b/facefusion/processors/modules/face_swapper.py new file mode 100644 index 0000000000000000000000000000000000000000..3cd44452431dd499db53423cb163f9f2a1ba566c --- /dev/null +++ b/facefusion/processors/modules/face_swapper.py @@ -0,0 +1,712 @@ +from argparse import ArgumentParser +from functools import lru_cache +from typing import List, Tuple + +import cv2 +import numpy + +import facefusion.choices +import facefusion.jobs.job_manager +import facefusion.jobs.job_store +import facefusion.processors.core as processors +from facefusion import config, content_analyser, face_classifier, face_detector, face_landmarker, face_masker, face_recognizer, inference_manager, logger, process_manager, state_manager, video_manager, wording +from facefusion.common_helper import get_first +from facefusion.download import conditional_download_hashes, conditional_download_sources, resolve_download_url +from facefusion.execution import has_execution_provider +from facefusion.face_analyser import get_average_face, get_many_faces, get_one_face +from facefusion.face_helper import paste_back, warp_face_by_face_landmark_5 +from facefusion.face_masker import create_area_mask, create_box_mask, create_occlusion_mask, create_region_mask +from facefusion.face_selector import find_similar_faces, sort_and_filter_faces, sort_faces_by_order +from facefusion.face_store import get_reference_faces +from facefusion.filesystem import filter_image_paths, has_image, in_directory, is_image, is_video, resolve_relative_path, same_file_extension +from facefusion.model_helper import get_static_model_initializer +from facefusion.processors import choices as processors_choices +from facefusion.processors.pixel_boost import explode_pixel_boost, implode_pixel_boost +from facefusion.processors.types import FaceSwapperInputs +from facefusion.program_helper import find_argument_group +from facefusion.thread_helper import conditional_thread_semaphore +from facefusion.types import ApplyStateItem, Args, DownloadScope, Embedding, Face, InferencePool, ModelOptions, ModelSet, ProcessMode, QueuePayload, UpdateProgress, VisionFrame +from facefusion.vision import read_image, read_static_image, read_static_images, unpack_resolution, write_image + + +@lru_cache(maxsize = None) +def create_static_model_set(download_scope : DownloadScope) -> ModelSet: + return\ + { + 'blendswap_256': + { + 'hashes': + { + 'face_swapper': + { + 'url': resolve_download_url('models-3.0.0', 'blendswap_256.hash'), + 'path': resolve_relative_path('../.assets/models/blendswap_256.hash') + } + }, + 'sources': + { + 'face_swapper': + { + 'url': resolve_download_url('models-3.0.0', 'blendswap_256.onnx'), + 'path': resolve_relative_path('../.assets/models/blendswap_256.onnx') + } + }, + 'type': 'blendswap', + 'template': 'ffhq_512', + 'size': (256, 256), + 'mean': [ 0.0, 0.0, 0.0 ], + 'standard_deviation': [ 1.0, 1.0, 1.0 ] + }, + 'ghost_1_256': + { + 'hashes': + { + 
'face_swapper': + { + 'url': resolve_download_url('models-3.0.0', 'ghost_1_256.hash'), + 'path': resolve_relative_path('../.assets/models/ghost_1_256.hash') + }, + 'embedding_converter': + { + 'url': resolve_download_url('models-3.0.0', 'arcface_converter_ghost.hash'), + 'path': resolve_relative_path('../.assets/models/arcface_converter_ghost.hash') + } + }, + 'sources': + { + 'face_swapper': + { + 'url': resolve_download_url('models-3.0.0', 'ghost_1_256.onnx'), + 'path': resolve_relative_path('../.assets/models/ghost_1_256.onnx') + }, + 'embedding_converter': + { + 'url': resolve_download_url('models-3.0.0', 'arcface_converter_ghost.onnx'), + 'path': resolve_relative_path('../.assets/models/arcface_converter_ghost.onnx') + } + }, + 'type': 'ghost', + 'template': 'arcface_112_v1', + 'size': (256, 256), + 'mean': [ 0.5, 0.5, 0.5 ], + 'standard_deviation': [ 0.5, 0.5, 0.5 ] + }, + 'ghost_2_256': + { + 'hashes': + { + 'face_swapper': + { + 'url': resolve_download_url('models-3.0.0', 'ghost_2_256.hash'), + 'path': resolve_relative_path('../.assets/models/ghost_2_256.hash') + }, + 'embedding_converter': + { + 'url': resolve_download_url('models-3.0.0', 'arcface_converter_ghost.hash'), + 'path': resolve_relative_path('../.assets/models/arcface_converter_ghost.hash') + } + }, + 'sources': + { + 'face_swapper': + { + 'url': resolve_download_url('models-3.0.0', 'ghost_2_256.onnx'), + 'path': resolve_relative_path('../.assets/models/ghost_2_256.onnx') + }, + 'embedding_converter': + { + 'url': resolve_download_url('models-3.0.0', 'arcface_converter_ghost.onnx'), + 'path': resolve_relative_path('../.assets/models/arcface_converter_ghost.onnx') + } + }, + 'type': 'ghost', + 'template': 'arcface_112_v1', + 'size': (256, 256), + 'mean': [ 0.5, 0.5, 0.5 ], + 'standard_deviation': [ 0.5, 0.5, 0.5 ] + }, + 'ghost_3_256': + { + 'hashes': + { + 'face_swapper': + { + 'url': resolve_download_url('models-3.0.0', 'ghost_3_256.hash'), + 'path': resolve_relative_path('../.assets/models/ghost_3_256.hash') + }, + 'embedding_converter': + { + 'url': resolve_download_url('models-3.0.0', 'arcface_converter_ghost.hash'), + 'path': resolve_relative_path('../.assets/models/arcface_converter_ghost.hash') + } + }, + 'sources': + { + 'face_swapper': + { + 'url': resolve_download_url('models-3.0.0', 'ghost_3_256.onnx'), + 'path': resolve_relative_path('../.assets/models/ghost_3_256.onnx') + }, + 'embedding_converter': + { + 'url': resolve_download_url('models-3.0.0', 'arcface_converter_ghost.onnx'), + 'path': resolve_relative_path('../.assets/models/arcface_converter_ghost.onnx') + } + }, + 'type': 'ghost', + 'template': 'arcface_112_v1', + 'size': (256, 256), + 'mean': [ 0.5, 0.5, 0.5 ], + 'standard_deviation': [ 0.5, 0.5, 0.5 ] + }, + 'hififace_unofficial_256': + { + 'hashes': + { + 'face_swapper': + { + 'url': resolve_download_url('models-3.1.0', 'hififace_unofficial_256.hash'), + 'path': resolve_relative_path('../.assets/models/hififace_unofficial_256.hash') + }, + 'embedding_converter': + { + 'url': resolve_download_url('models-3.1.0', 'arcface_converter_hififace.hash'), + 'path': resolve_relative_path('../.assets/models/arcface_converter_hififace.hash') + } + }, + 'sources': + { + 'face_swapper': + { + 'url': resolve_download_url('models-3.1.0', 'hififace_unofficial_256.onnx'), + 'path': resolve_relative_path('../.assets/models/hififace_unofficial_256.onnx') + }, + 'embedding_converter': + { + 'url': resolve_download_url('models-3.1.0', 'arcface_converter_hififace.onnx'), + 'path': 
resolve_relative_path('../.assets/models/arcface_converter_hififace.onnx') + } + }, + 'type': 'hififace', + 'template': 'mtcnn_512', + 'size': (256, 256), + 'mean': [ 0.5, 0.5, 0.5 ], + 'standard_deviation': [ 0.5, 0.5, 0.5 ] + }, + 'hyperswap_1a_256': + { + 'hashes': + { + 'face_swapper': + { + 'url': resolve_download_url('models-3.3.0', 'hyperswap_1a_256.hash'), + 'path': resolve_relative_path('../.assets/models/hyperswap_1a_256.hash') + } + }, + 'sources': + { + 'face_swapper': + { + 'url': resolve_download_url('models-3.3.0', 'hyperswap_1a_256.onnx'), + 'path': resolve_relative_path('../.assets/models/hyperswap_1a_256.onnx') + } + }, + 'type': 'hyperswap', + 'template': 'arcface_128', + 'size': (256, 256), + 'mean': [ 0.5, 0.5, 0.5 ], + 'standard_deviation': [ 0.5, 0.5, 0.5 ] + }, + 'hyperswap_1b_256': + { + 'hashes': + { + 'face_swapper': + { + 'url': resolve_download_url('models-3.3.0', 'hyperswap_1b_256.hash'), + 'path': resolve_relative_path('../.assets/models/hyperswap_1b_256.hash') + } + }, + 'sources': + { + 'face_swapper': + { + 'url': resolve_download_url('models-3.3.0', 'hyperswap_1b_256.onnx'), + 'path': resolve_relative_path('../.assets/models/hyperswap_1b_256.onnx') + } + }, + 'type': 'hyperswap', + 'template': 'arcface_128', + 'size': (256, 256), + 'mean': [ 0.5, 0.5, 0.5 ], + 'standard_deviation': [ 0.5, 0.5, 0.5 ] + }, + 'hyperswap_1c_256': + { + 'hashes': + { + 'face_swapper': + { + 'url': resolve_download_url('models-3.3.0', 'hyperswap_1c_256.hash'), + 'path': resolve_relative_path('../.assets/models/hyperswap_1c_256.hash') + } + }, + 'sources': + { + 'face_swapper': + { + 'url': resolve_download_url('models-3.3.0', 'hyperswap_1c_256.onnx'), + 'path': resolve_relative_path('../.assets/models/hyperswap_1c_256.onnx') + } + }, + 'type': 'hyperswap', + 'template': 'arcface_128', + 'size': (256, 256), + 'mean': [ 0.5, 0.5, 0.5 ], + 'standard_deviation': [ 0.5, 0.5, 0.5 ] + }, + 'inswapper_128': + { + 'hashes': + { + 'face_swapper': + { + 'url': resolve_download_url('models-3.0.0', 'inswapper_128.hash'), + 'path': resolve_relative_path('../.assets/models/inswapper_128.hash') + } + }, + 'sources': + { + 'face_swapper': + { + 'url': resolve_download_url('models-3.0.0', 'inswapper_128.onnx'), + 'path': resolve_relative_path('../.assets/models/inswapper_128.onnx') + } + }, + 'type': 'inswapper', + 'template': 'arcface_128', + 'size': (128, 128), + 'mean': [ 0.0, 0.0, 0.0 ], + 'standard_deviation': [ 1.0, 1.0, 1.0 ] + }, + 'inswapper_128_fp16': + { + 'hashes': + { + 'face_swapper': + { + 'url': resolve_download_url('models-3.0.0', 'inswapper_128_fp16.hash'), + 'path': resolve_relative_path('../.assets/models/inswapper_128_fp16.hash') + } + }, + 'sources': + { + 'face_swapper': + { + 'url': resolve_download_url('models-3.0.0', 'inswapper_128_fp16.onnx'), + 'path': resolve_relative_path('../.assets/models/inswapper_128_fp16.onnx') + } + }, + 'type': 'inswapper', + 'template': 'arcface_128', + 'size': (128, 128), + 'mean': [ 0.0, 0.0, 0.0 ], + 'standard_deviation': [ 1.0, 1.0, 1.0 ] + }, + 'simswap_256': + { + 'hashes': + { + 'face_swapper': + { + 'url': resolve_download_url('models-3.0.0', 'simswap_256.hash'), + 'path': resolve_relative_path('../.assets/models/simswap_256.hash') + }, + 'embedding_converter': + { + 'url': resolve_download_url('models-3.0.0', 'arcface_converter_simswap.hash'), + 'path': resolve_relative_path('../.assets/models/arcface_converter_simswap.hash') + } + }, + 'sources': + { + 'face_swapper': + { + 'url': resolve_download_url('models-3.0.0', 
'simswap_256.onnx'), + 'path': resolve_relative_path('../.assets/models/simswap_256.onnx') + }, + 'embedding_converter': + { + 'url': resolve_download_url('models-3.0.0', 'arcface_converter_simswap.onnx'), + 'path': resolve_relative_path('../.assets/models/arcface_converter_simswap.onnx') + } + }, + 'type': 'simswap', + 'template': 'arcface_112_v1', + 'size': (256, 256), + 'mean': [ 0.485, 0.456, 0.406 ], + 'standard_deviation': [ 0.229, 0.224, 0.225 ] + }, + 'simswap_unofficial_512': + { + 'hashes': + { + 'face_swapper': + { + 'url': resolve_download_url('models-3.0.0', 'simswap_unofficial_512.hash'), + 'path': resolve_relative_path('../.assets/models/simswap_unofficial_512.hash') + }, + 'embedding_converter': + { + 'url': resolve_download_url('models-3.0.0', 'arcface_converter_simswap.hash'), + 'path': resolve_relative_path('../.assets/models/arcface_converter_simswap.hash') + } + }, + 'sources': + { + 'face_swapper': + { + 'url': resolve_download_url('models-3.0.0', 'simswap_unofficial_512.onnx'), + 'path': resolve_relative_path('../.assets/models/simswap_unofficial_512.onnx') + }, + 'embedding_converter': + { + 'url': resolve_download_url('models-3.0.0', 'arcface_converter_simswap.onnx'), + 'path': resolve_relative_path('../.assets/models/arcface_converter_simswap.onnx') + } + }, + 'type': 'simswap', + 'template': 'arcface_112_v1', + 'size': (512, 512), + 'mean': [ 0.0, 0.0, 0.0 ], + 'standard_deviation': [ 1.0, 1.0, 1.0 ] + }, + 'uniface_256': + { + 'hashes': + { + 'face_swapper': + { + 'url': resolve_download_url('models-3.0.0', 'uniface_256.hash'), + 'path': resolve_relative_path('../.assets/models/uniface_256.hash') + } + }, + 'sources': + { + 'face_swapper': + { + 'url': resolve_download_url('models-3.0.0', 'uniface_256.onnx'), + 'path': resolve_relative_path('../.assets/models/uniface_256.onnx') + } + }, + 'type': 'uniface', + 'template': 'ffhq_512', + 'size': (256, 256), + 'mean': [ 0.5, 0.5, 0.5 ], + 'standard_deviation': [ 0.5, 0.5, 0.5 ] + } + } + + +def get_inference_pool() -> InferencePool: + model_names = [ get_model_name() ] + model_source_set = get_model_options().get('sources') + + return inference_manager.get_inference_pool(__name__, model_names, model_source_set) + + +def clear_inference_pool() -> None: + model_names = [ get_model_name() ] + inference_manager.clear_inference_pool(__name__, model_names) + + +def get_model_options() -> ModelOptions: + model_name = get_model_name() + return create_static_model_set('full').get(model_name) + + +def get_model_name() -> str: + model_name = state_manager.get_item('face_swapper_model') + + if has_execution_provider('coreml') and model_name == 'inswapper_128_fp16': + return 'inswapper_128' + return model_name + + +def register_args(program : ArgumentParser) -> None: + group_processors = find_argument_group(program, 'processors') + if group_processors: + group_processors.add_argument('--face-swapper-model', help = wording.get('help.face_swapper_model'), default = config.get_str_value('processors', 'face_swapper_model', 'hyperswap_1a_256'), choices = processors_choices.face_swapper_models) + known_args, _ = program.parse_known_args() + face_swapper_pixel_boost_choices = processors_choices.face_swapper_set.get(known_args.face_swapper_model) + group_processors.add_argument('--face-swapper-pixel-boost', help = wording.get('help.face_swapper_pixel_boost'), default = config.get_str_value('processors', 'face_swapper_pixel_boost', get_first(face_swapper_pixel_boost_choices)), choices = face_swapper_pixel_boost_choices) + 
facefusion.jobs.job_store.register_step_keys([ 'face_swapper_model', 'face_swapper_pixel_boost' ]) + + +def apply_args(args : Args, apply_state_item : ApplyStateItem) -> None: + apply_state_item('face_swapper_model', args.get('face_swapper_model')) + apply_state_item('face_swapper_pixel_boost', args.get('face_swapper_pixel_boost')) + + +def pre_check() -> bool: + model_hash_set = get_model_options().get('hashes') + model_source_set = get_model_options().get('sources') + + return conditional_download_hashes(model_hash_set) and conditional_download_sources(model_source_set) + + +def pre_process(mode : ProcessMode) -> bool: + if not has_image(state_manager.get_item('source_paths')): + logger.error(wording.get('choose_image_source') + wording.get('exclamation_mark'), __name__) + return False + source_image_paths = filter_image_paths(state_manager.get_item('source_paths')) + source_frames = read_static_images(source_image_paths) + source_faces = get_many_faces(source_frames) + if not get_one_face(source_faces): + logger.error(wording.get('no_source_face_detected') + wording.get('exclamation_mark'), __name__) + return False + if mode in [ 'output', 'preview' ] and not is_image(state_manager.get_item('target_path')) and not is_video(state_manager.get_item('target_path')): + logger.error(wording.get('choose_image_or_video_target') + wording.get('exclamation_mark'), __name__) + return False + if mode == 'output' and not in_directory(state_manager.get_item('output_path')): + logger.error(wording.get('specify_image_or_video_output') + wording.get('exclamation_mark'), __name__) + return False + if mode == 'output' and not same_file_extension(state_manager.get_item('target_path'), state_manager.get_item('output_path')): + logger.error(wording.get('match_target_and_output_extension') + wording.get('exclamation_mark'), __name__) + return False + return True + + +def post_process() -> None: + read_static_image.cache_clear() + video_manager.clear_video_pool() + if state_manager.get_item('video_memory_strategy') in [ 'strict', 'moderate' ]: + get_static_model_initializer.cache_clear() + clear_inference_pool() + if state_manager.get_item('video_memory_strategy') == 'strict': + content_analyser.clear_inference_pool() + face_classifier.clear_inference_pool() + face_detector.clear_inference_pool() + face_landmarker.clear_inference_pool() + face_masker.clear_inference_pool() + face_recognizer.clear_inference_pool() + + +def swap_face(source_face : Face, target_face : Face, temp_vision_frame : VisionFrame) -> VisionFrame: + model_template = get_model_options().get('template') + model_size = get_model_options().get('size') + pixel_boost_size = unpack_resolution(state_manager.get_item('face_swapper_pixel_boost')) + pixel_boost_total = pixel_boost_size[0] // model_size[0] + crop_vision_frame, affine_matrix = warp_face_by_face_landmark_5(temp_vision_frame, target_face.landmark_set.get('5/68'), model_template, pixel_boost_size) + temp_vision_frames = [] + crop_masks = [] + + if 'box' in state_manager.get_item('face_mask_types'): + box_mask = create_box_mask(crop_vision_frame, state_manager.get_item('face_mask_blur'), state_manager.get_item('face_mask_padding')) + crop_masks.append(box_mask) + + if 'occlusion' in state_manager.get_item('face_mask_types'): + occlusion_mask = create_occlusion_mask(crop_vision_frame) + crop_masks.append(occlusion_mask) + + pixel_boost_vision_frames = implode_pixel_boost(crop_vision_frame, pixel_boost_total, model_size) + for pixel_boost_vision_frame in pixel_boost_vision_frames: + 
pixel_boost_vision_frame = prepare_crop_frame(pixel_boost_vision_frame) + pixel_boost_vision_frame = forward_swap_face(source_face, pixel_boost_vision_frame) + pixel_boost_vision_frame = normalize_crop_frame(pixel_boost_vision_frame) + temp_vision_frames.append(pixel_boost_vision_frame) + crop_vision_frame = explode_pixel_boost(temp_vision_frames, pixel_boost_total, model_size, pixel_boost_size) + + if 'area' in state_manager.get_item('face_mask_types'): + face_landmark_68 = cv2.transform(target_face.landmark_set.get('68').reshape(1, -1, 2), affine_matrix).reshape(-1, 2) + area_mask = create_area_mask(crop_vision_frame, face_landmark_68, state_manager.get_item('face_mask_areas')) + crop_masks.append(area_mask) + + if 'region' in state_manager.get_item('face_mask_types'): + region_mask = create_region_mask(crop_vision_frame, state_manager.get_item('face_mask_regions')) + crop_masks.append(region_mask) + + crop_mask = numpy.minimum.reduce(crop_masks).clip(0, 1) + temp_vision_frame = paste_back(temp_vision_frame, crop_vision_frame, crop_mask, affine_matrix) + return temp_vision_frame + + +def forward_swap_face(source_face : Face, crop_vision_frame : VisionFrame) -> VisionFrame: + face_swapper = get_inference_pool().get('face_swapper') + model_type = get_model_options().get('type') + face_swapper_inputs = {} + + if has_execution_provider('coreml') and model_type in [ 'ghost', 'uniface' ]: + face_swapper.set_providers([ facefusion.choices.execution_provider_set.get('cpu') ]) + + for face_swapper_input in face_swapper.get_inputs(): + if face_swapper_input.name == 'source': + if model_type in [ 'blendswap', 'uniface' ]: + face_swapper_inputs[face_swapper_input.name] = prepare_source_frame(source_face) + else: + face_swapper_inputs[face_swapper_input.name] = prepare_source_embedding(source_face) + if face_swapper_input.name == 'target': + face_swapper_inputs[face_swapper_input.name] = crop_vision_frame + + with conditional_thread_semaphore(): + crop_vision_frame = face_swapper.run(None, face_swapper_inputs)[0][0] + + return crop_vision_frame + + +def forward_convert_embedding(embedding : Embedding) -> Embedding: + embedding_converter = get_inference_pool().get('embedding_converter') + + with conditional_thread_semaphore(): + embedding = embedding_converter.run(None, + { + 'input': embedding + })[0] + + return embedding + + +def prepare_source_frame(source_face : Face) -> VisionFrame: + model_type = get_model_options().get('type') + source_vision_frame = read_static_image(get_first(state_manager.get_item('source_paths'))) + + if model_type == 'blendswap': + source_vision_frame, _ = warp_face_by_face_landmark_5(source_vision_frame, source_face.landmark_set.get('5/68'), 'arcface_112_v2', (112, 112)) + if model_type == 'uniface': + source_vision_frame, _ = warp_face_by_face_landmark_5(source_vision_frame, source_face.landmark_set.get('5/68'), 'ffhq_512', (256, 256)) + source_vision_frame = source_vision_frame[:, :, ::-1] / 255.0 + source_vision_frame = source_vision_frame.transpose(2, 0, 1) + source_vision_frame = numpy.expand_dims(source_vision_frame, axis = 0).astype(numpy.float32) + return source_vision_frame + + +def prepare_source_embedding(source_face : Face) -> Embedding: + model_type = get_model_options().get('type') + + if model_type == 'ghost': + source_embedding, _ = convert_embedding(source_face) + source_embedding = source_embedding.reshape(1, -1) + return source_embedding + + if model_type == 'hyperswap': + source_embedding = source_face.normed_embedding.reshape((1, -1)) + return 
source_embedding + + if model_type == 'inswapper': + model_path = get_model_options().get('sources').get('face_swapper').get('path') + model_initializer = get_static_model_initializer(model_path) + source_embedding = source_face.embedding.reshape((1, -1)) + source_embedding = numpy.dot(source_embedding, model_initializer) / numpy.linalg.norm(source_embedding) + return source_embedding + + _, source_normed_embedding = convert_embedding(source_face) + source_embedding = source_normed_embedding.reshape(1, -1) + return source_embedding + + +def convert_embedding(source_face : Face) -> Tuple[Embedding, Embedding]: + embedding = source_face.embedding.reshape(-1, 512) + embedding = forward_convert_embedding(embedding) + embedding = embedding.ravel() + normed_embedding = embedding / numpy.linalg.norm(embedding) + return embedding, normed_embedding + + +def prepare_crop_frame(crop_vision_frame : VisionFrame) -> VisionFrame: + model_mean = get_model_options().get('mean') + model_standard_deviation = get_model_options().get('standard_deviation') + + crop_vision_frame = crop_vision_frame[:, :, ::-1] / 255.0 + crop_vision_frame = (crop_vision_frame - model_mean) / model_standard_deviation + crop_vision_frame = crop_vision_frame.transpose(2, 0, 1) + crop_vision_frame = numpy.expand_dims(crop_vision_frame, axis = 0).astype(numpy.float32) + return crop_vision_frame + + +def normalize_crop_frame(crop_vision_frame : VisionFrame) -> VisionFrame: + model_type = get_model_options().get('type') + model_mean = get_model_options().get('mean') + model_standard_deviation = get_model_options().get('standard_deviation') + + crop_vision_frame = crop_vision_frame.transpose(1, 2, 0) + if model_type in [ 'ghost', 'hififace', 'hyperswap', 'uniface' ]: + crop_vision_frame = crop_vision_frame * model_standard_deviation + model_mean + crop_vision_frame = crop_vision_frame.clip(0, 1) + crop_vision_frame = crop_vision_frame[:, :, ::-1] * 255 + return crop_vision_frame + + +def get_reference_frame(source_face : Face, target_face : Face, temp_vision_frame : VisionFrame) -> VisionFrame: + return swap_face(source_face, target_face, temp_vision_frame) + + +def process_frame(inputs : FaceSwapperInputs) -> VisionFrame: + reference_faces = inputs.get('reference_faces') + source_face = inputs.get('source_face') + target_vision_frame = inputs.get('target_vision_frame') + many_faces = sort_and_filter_faces(get_many_faces([ target_vision_frame ])) + + if state_manager.get_item('face_selector_mode') == 'many': + if many_faces: + for target_face in many_faces: + target_vision_frame = swap_face(source_face, target_face, target_vision_frame) + if state_manager.get_item('face_selector_mode') == 'one': + target_face = get_one_face(many_faces) + if target_face: + target_vision_frame = swap_face(source_face, target_face, target_vision_frame) + if state_manager.get_item('face_selector_mode') == 'reference': + similar_faces = find_similar_faces(many_faces, reference_faces, state_manager.get_item('reference_face_distance')) + if similar_faces: + for similar_face in similar_faces: + target_vision_frame = swap_face(source_face, similar_face, target_vision_frame) + return target_vision_frame + + +def process_frames(source_paths : List[str], queue_payloads : List[QueuePayload], update_progress : UpdateProgress) -> None: + reference_faces = get_reference_faces() if 'reference' in state_manager.get_item('face_selector_mode') else None + source_frames = read_static_images(source_paths) + source_faces = [] + + for source_frame in source_frames: + temp_faces 
= get_many_faces([ source_frame ]) + temp_faces = sort_faces_by_order(temp_faces, 'large-small') + if temp_faces: + source_faces.append(get_first(temp_faces)) + source_face = get_average_face(source_faces) + + for queue_payload in process_manager.manage(queue_payloads): + target_vision_path = queue_payload['frame_path'] + target_vision_frame = read_image(target_vision_path) + output_vision_frame = process_frame( + { + 'reference_faces': reference_faces, + 'source_face': source_face, + 'target_vision_frame': target_vision_frame + }) + write_image(target_vision_path, output_vision_frame) + update_progress(1) + + +def process_image(source_paths : List[str], target_path : str, output_path : str) -> None: + reference_faces = get_reference_faces() if 'reference' in state_manager.get_item('face_selector_mode') else None + source_frames = read_static_images(source_paths) + source_faces = [] + + for source_frame in source_frames: + temp_faces = get_many_faces([ source_frame ]) + temp_faces = sort_faces_by_order(temp_faces, 'large-small') + if temp_faces: + source_faces.append(get_first(temp_faces)) + source_face = get_average_face(source_faces) + target_vision_frame = read_static_image(target_path) + output_vision_frame = process_frame( + { + 'reference_faces': reference_faces, + 'source_face': source_face, + 'target_vision_frame': target_vision_frame + }) + write_image(output_path, output_vision_frame) + + +def process_video(source_paths : List[str], temp_frame_paths : List[str]) -> None: + processors.multi_process_frames(source_paths, temp_frame_paths, process_frames) diff --git a/facefusion/processors/modules/frame_colorizer.py b/facefusion/processors/modules/frame_colorizer.py new file mode 100644 index 0000000000000000000000000000000000000000..abf3fc91410194e7ad3098b56b694ac35a0edcbf --- /dev/null +++ b/facefusion/processors/modules/frame_colorizer.py @@ -0,0 +1,295 @@ +from argparse import ArgumentParser +from functools import lru_cache +from typing import List + +import cv2 +import numpy + +import facefusion.jobs.job_manager +import facefusion.jobs.job_store +import facefusion.processors.core as processors +from facefusion import config, content_analyser, inference_manager, logger, process_manager, state_manager, video_manager, wording +from facefusion.common_helper import create_int_metavar +from facefusion.download import conditional_download_hashes, conditional_download_sources, resolve_download_url +from facefusion.execution import has_execution_provider +from facefusion.filesystem import in_directory, is_image, is_video, resolve_relative_path, same_file_extension +from facefusion.processors import choices as processors_choices +from facefusion.processors.types import FrameColorizerInputs +from facefusion.program_helper import find_argument_group +from facefusion.thread_helper import thread_semaphore +from facefusion.types import ApplyStateItem, Args, DownloadScope, ExecutionProvider, Face, InferencePool, ModelOptions, ModelSet, ProcessMode, QueuePayload, UpdateProgress, VisionFrame +from facefusion.vision import read_image, read_static_image, unpack_resolution, write_image + + +@lru_cache(maxsize = None) +def create_static_model_set(download_scope : DownloadScope) -> ModelSet: + return\ + { + 'ddcolor': + { + 'hashes': + { + 'frame_colorizer': + { + 'url': resolve_download_url('models-3.0.0', 'ddcolor.hash'), + 'path': resolve_relative_path('../.assets/models/ddcolor.hash') + } + }, + 'sources': + { + 'frame_colorizer': + { + 'url': resolve_download_url('models-3.0.0', 'ddcolor.onnx'), 
+ 'path': resolve_relative_path('../.assets/models/ddcolor.onnx') + } + }, + 'type': 'ddcolor' + }, + 'ddcolor_artistic': + { + 'hashes': + { + 'frame_colorizer': + { + 'url': resolve_download_url('models-3.0.0', 'ddcolor_artistic.hash'), + 'path': resolve_relative_path('../.assets/models/ddcolor_artistic.hash') + } + }, + 'sources': + { + 'frame_colorizer': + { + 'url': resolve_download_url('models-3.0.0', 'ddcolor_artistic.onnx'), + 'path': resolve_relative_path('../.assets/models/ddcolor_artistic.onnx') + } + }, + 'type': 'ddcolor' + }, + 'deoldify': + { + 'hashes': + { + 'frame_colorizer': + { + 'url': resolve_download_url('models-3.0.0', 'deoldify.hash'), + 'path': resolve_relative_path('../.assets/models/deoldify.hash') + } + }, + 'sources': + { + 'frame_colorizer': + { + 'url': resolve_download_url('models-3.0.0', 'deoldify.onnx'), + 'path': resolve_relative_path('../.assets/models/deoldify.onnx') + } + }, + 'type': 'deoldify' + }, + 'deoldify_artistic': + { + 'hashes': + { + 'frame_colorizer': + { + 'url': resolve_download_url('models-3.0.0', 'deoldify_artistic.hash'), + 'path': resolve_relative_path('../.assets/models/deoldify_artistic.hash') + } + }, + 'sources': + { + 'frame_colorizer': + { + 'url': resolve_download_url('models-3.0.0', 'deoldify_artistic.onnx'), + 'path': resolve_relative_path('../.assets/models/deoldify_artistic.onnx') + } + }, + 'type': 'deoldify' + }, + 'deoldify_stable': + { + 'hashes': + { + 'frame_colorizer': + { + 'url': resolve_download_url('models-3.0.0', 'deoldify_stable.hash'), + 'path': resolve_relative_path('../.assets/models/deoldify_stable.hash') + } + }, + 'sources': + { + 'frame_colorizer': + { + 'url': resolve_download_url('models-3.0.0', 'deoldify_stable.onnx'), + 'path': resolve_relative_path('../.assets/models/deoldify_stable.onnx') + } + }, + 'type': 'deoldify' + } + } + + +def get_inference_pool() -> InferencePool: + model_names = [ state_manager.get_item('frame_colorizer_model') ] + model_source_set = get_model_options().get('sources') + + return inference_manager.get_inference_pool(__name__, model_names, model_source_set) + + +def clear_inference_pool() -> None: + model_names = [ state_manager.get_item('frame_colorizer_model') ] + inference_manager.clear_inference_pool(__name__, model_names) + + +def resolve_execution_providers() -> List[ExecutionProvider]: + if has_execution_provider('coreml'): + return [ 'cpu' ] + return state_manager.get_item('execution_providers') + + +def get_model_options() -> ModelOptions: + model_name = state_manager.get_item('frame_colorizer_model') + return create_static_model_set('full').get(model_name) + + +def register_args(program : ArgumentParser) -> None: + group_processors = find_argument_group(program, 'processors') + if group_processors: + group_processors.add_argument('--frame-colorizer-model', help = wording.get('help.frame_colorizer_model'), default = config.get_str_value('processors', 'frame_colorizer_model', 'ddcolor'), choices = processors_choices.frame_colorizer_models) + group_processors.add_argument('--frame-colorizer-size', help = wording.get('help.frame_colorizer_size'), type = str, default = config.get_str_value('processors', 'frame_colorizer_size', '256x256'), choices = processors_choices.frame_colorizer_sizes) + group_processors.add_argument('--frame-colorizer-blend', help = wording.get('help.frame_colorizer_blend'), type = int, default = config.get_int_value('processors', 'frame_colorizer_blend', '100'), choices = processors_choices.frame_colorizer_blend_range, metavar = 
create_int_metavar(processors_choices.frame_colorizer_blend_range)) + facefusion.jobs.job_store.register_step_keys([ 'frame_colorizer_model', 'frame_colorizer_blend', 'frame_colorizer_size' ]) + + +def apply_args(args : Args, apply_state_item : ApplyStateItem) -> None: + apply_state_item('frame_colorizer_model', args.get('frame_colorizer_model')) + apply_state_item('frame_colorizer_blend', args.get('frame_colorizer_blend')) + apply_state_item('frame_colorizer_size', args.get('frame_colorizer_size')) + + +def pre_check() -> bool: + model_hash_set = get_model_options().get('hashes') + model_source_set = get_model_options().get('sources') + + return conditional_download_hashes(model_hash_set) and conditional_download_sources(model_source_set) + + +def pre_process(mode : ProcessMode) -> bool: + if mode in [ 'output', 'preview' ] and not is_image(state_manager.get_item('target_path')) and not is_video(state_manager.get_item('target_path')): + logger.error(wording.get('choose_image_or_video_target') + wording.get('exclamation_mark'), __name__) + return False + if mode == 'output' and not in_directory(state_manager.get_item('output_path')): + logger.error(wording.get('specify_image_or_video_output') + wording.get('exclamation_mark'), __name__) + return False + if mode == 'output' and not same_file_extension(state_manager.get_item('target_path'), state_manager.get_item('output_path')): + logger.error(wording.get('match_target_and_output_extension') + wording.get('exclamation_mark'), __name__) + return False + return True + + +def post_process() -> None: + read_static_image.cache_clear() + video_manager.clear_video_pool() + if state_manager.get_item('video_memory_strategy') in [ 'strict', 'moderate' ]: + clear_inference_pool() + if state_manager.get_item('video_memory_strategy') == 'strict': + content_analyser.clear_inference_pool() + + +def colorize_frame(temp_vision_frame : VisionFrame) -> VisionFrame: + color_vision_frame = prepare_temp_frame(temp_vision_frame) + color_vision_frame = forward(color_vision_frame) + color_vision_frame = merge_color_frame(temp_vision_frame, color_vision_frame) + color_vision_frame = blend_frame(temp_vision_frame, color_vision_frame) + return color_vision_frame + + +def forward(color_vision_frame : VisionFrame) -> VisionFrame: + frame_colorizer = get_inference_pool().get('frame_colorizer') + + with thread_semaphore(): + color_vision_frame = frame_colorizer.run(None, + { + 'input': color_vision_frame + })[0][0] + + return color_vision_frame + + +def prepare_temp_frame(temp_vision_frame : VisionFrame) -> VisionFrame: + model_size = unpack_resolution(state_manager.get_item('frame_colorizer_size')) + model_type = get_model_options().get('type') + temp_vision_frame = cv2.cvtColor(temp_vision_frame, cv2.COLOR_BGR2GRAY) + temp_vision_frame = cv2.cvtColor(temp_vision_frame, cv2.COLOR_GRAY2RGB) + + if model_type == 'ddcolor': + temp_vision_frame = (temp_vision_frame / 255.0).astype(numpy.float32) #type:ignore[operator] + temp_vision_frame = cv2.cvtColor(temp_vision_frame, cv2.COLOR_RGB2LAB)[:, :, :1] + temp_vision_frame = numpy.concatenate((temp_vision_frame, numpy.zeros_like(temp_vision_frame), numpy.zeros_like(temp_vision_frame)), axis = -1) + temp_vision_frame = cv2.cvtColor(temp_vision_frame, cv2.COLOR_LAB2RGB) + + temp_vision_frame = cv2.resize(temp_vision_frame, model_size) + temp_vision_frame = temp_vision_frame.transpose((2, 0, 1)) + temp_vision_frame = numpy.expand_dims(temp_vision_frame, axis = 0).astype(numpy.float32) + return temp_vision_frame + + +def 
merge_color_frame(temp_vision_frame : VisionFrame, color_vision_frame : VisionFrame) -> VisionFrame: + model_type = get_model_options().get('type') + color_vision_frame = color_vision_frame.transpose(1, 2, 0) + color_vision_frame = cv2.resize(color_vision_frame, (temp_vision_frame.shape[1], temp_vision_frame.shape[0])) + + if model_type == 'ddcolor': + temp_vision_frame = (temp_vision_frame / 255.0).astype(numpy.float32) + temp_vision_frame = cv2.cvtColor(temp_vision_frame, cv2.COLOR_BGR2LAB)[:, :, :1] + color_vision_frame = numpy.concatenate((temp_vision_frame, color_vision_frame), axis = -1) + color_vision_frame = cv2.cvtColor(color_vision_frame, cv2.COLOR_LAB2BGR) + color_vision_frame = (color_vision_frame * 255.0).round().astype(numpy.uint8) #type:ignore[operator] + + if model_type == 'deoldify': + temp_blue_channel, _, _ = cv2.split(temp_vision_frame) + color_vision_frame = cv2.cvtColor(color_vision_frame, cv2.COLOR_BGR2RGB).astype(numpy.uint8) + color_vision_frame = cv2.cvtColor(color_vision_frame, cv2.COLOR_BGR2LAB) + _, color_green_channel, color_red_channel = cv2.split(color_vision_frame) + color_vision_frame = cv2.merge((temp_blue_channel, color_green_channel, color_red_channel)) + color_vision_frame = cv2.cvtColor(color_vision_frame, cv2.COLOR_LAB2BGR) + return color_vision_frame + + +def blend_frame(temp_vision_frame : VisionFrame, paste_vision_frame : VisionFrame) -> VisionFrame: + frame_colorizer_blend = 1 - (state_manager.get_item('frame_colorizer_blend') / 100) + temp_vision_frame = cv2.addWeighted(temp_vision_frame, frame_colorizer_blend, paste_vision_frame, 1 - frame_colorizer_blend, 0) + return temp_vision_frame + + +def get_reference_frame(source_face : Face, target_face : Face, temp_vision_frame : VisionFrame) -> VisionFrame: + pass + + +def process_frame(inputs : FrameColorizerInputs) -> VisionFrame: + target_vision_frame = inputs.get('target_vision_frame') + return colorize_frame(target_vision_frame) + + +def process_frames(source_paths : List[str], queue_payloads : List[QueuePayload], update_progress : UpdateProgress) -> None: + for queue_payload in process_manager.manage(queue_payloads): + target_vision_path = queue_payload['frame_path'] + target_vision_frame = read_image(target_vision_path) + output_vision_frame = process_frame( + { + 'target_vision_frame': target_vision_frame + }) + write_image(target_vision_path, output_vision_frame) + update_progress(1) + + +def process_image(source_paths : List[str], target_path : str, output_path : str) -> None: + target_vision_frame = read_static_image(target_path) + output_vision_frame = process_frame( + { + 'target_vision_frame': target_vision_frame + }) + write_image(output_path, output_vision_frame) + + +def process_video(source_paths : List[str], temp_frame_paths : List[str]) -> None: + processors.multi_process_frames(None, temp_frame_paths, process_frames) diff --git a/facefusion/processors/modules/frame_enhancer.py b/facefusion/processors/modules/frame_enhancer.py new file mode 100644 index 0000000000000000000000000000000000000000..833e191664fe96eeef7e0eaecc48680e8fa24883 --- /dev/null +++ b/facefusion/processors/modules/frame_enhancer.py @@ -0,0 +1,560 @@ +from argparse import ArgumentParser +from functools import lru_cache +from typing import List + +import cv2 +import numpy + +import facefusion.jobs.job_manager +import facefusion.jobs.job_store +import facefusion.processors.core as processors +from facefusion import config, content_analyser, inference_manager, logger, process_manager, state_manager, video_manager, 
wording +from facefusion.common_helper import create_int_metavar +from facefusion.download import conditional_download_hashes, conditional_download_sources, resolve_download_url +from facefusion.execution import has_execution_provider +from facefusion.filesystem import in_directory, is_image, is_video, resolve_relative_path, same_file_extension +from facefusion.processors import choices as processors_choices +from facefusion.processors.types import FrameEnhancerInputs +from facefusion.program_helper import find_argument_group +from facefusion.thread_helper import conditional_thread_semaphore +from facefusion.types import ApplyStateItem, Args, DownloadScope, Face, InferencePool, ModelOptions, ModelSet, ProcessMode, QueuePayload, UpdateProgress, VisionFrame +from facefusion.vision import create_tile_frames, merge_tile_frames, read_image, read_static_image, write_image + + +@lru_cache(maxsize = None) +def create_static_model_set(download_scope : DownloadScope) -> ModelSet: + return\ + { + 'clear_reality_x4': + { + 'hashes': + { + 'frame_enhancer': + { + 'url': resolve_download_url('models-3.0.0', 'clear_reality_x4.hash'), + 'path': resolve_relative_path('../.assets/models/clear_reality_x4.hash') + } + }, + 'sources': + { + 'frame_enhancer': + { + 'url': resolve_download_url('models-3.0.0', 'clear_reality_x4.onnx'), + 'path': resolve_relative_path('../.assets/models/clear_reality_x4.onnx') + } + }, + 'size': (128, 8, 4), + 'scale': 4 + }, + 'lsdir_x4': + { + 'hashes': + { + 'frame_enhancer': + { + 'url': resolve_download_url('models-3.0.0', 'lsdir_x4.hash'), + 'path': resolve_relative_path('../.assets/models/lsdir_x4.hash') + } + }, + 'sources': + { + 'frame_enhancer': + { + 'url': resolve_download_url('models-3.0.0', 'lsdir_x4.onnx'), + 'path': resolve_relative_path('../.assets/models/lsdir_x4.onnx') + } + }, + 'size': (128, 8, 4), + 'scale': 4 + }, + 'nomos8k_sc_x4': + { + 'hashes': + { + 'frame_enhancer': + { + 'url': resolve_download_url('models-3.0.0', 'nomos8k_sc_x4.hash'), + 'path': resolve_relative_path('../.assets/models/nomos8k_sc_x4.hash') + } + }, + 'sources': + { + 'frame_enhancer': + { + 'url': resolve_download_url('models-3.0.0', 'nomos8k_sc_x4.onnx'), + 'path': resolve_relative_path('../.assets/models/nomos8k_sc_x4.onnx') + } + }, + 'size': (128, 8, 4), + 'scale': 4 + }, + 'real_esrgan_x2': + { + 'hashes': + { + 'frame_enhancer': + { + 'url': resolve_download_url('models-3.0.0', 'real_esrgan_x2.hash'), + 'path': resolve_relative_path('../.assets/models/real_esrgan_x2.hash') + } + }, + 'sources': + { + 'frame_enhancer': + { + 'url': resolve_download_url('models-3.0.0', 'real_esrgan_x2.onnx'), + 'path': resolve_relative_path('../.assets/models/real_esrgan_x2.onnx') + } + }, + 'size': (256, 16, 8), + 'scale': 2 + }, + 'real_esrgan_x2_fp16': + { + 'hashes': + { + 'frame_enhancer': + { + 'url': resolve_download_url('models-3.0.0', 'real_esrgan_x2_fp16.hash'), + 'path': resolve_relative_path('../.assets/models/real_esrgan_x2_fp16.hash') + } + }, + 'sources': + { + 'frame_enhancer': + { + 'url': resolve_download_url('models-3.0.0', 'real_esrgan_x2_fp16.onnx'), + 'path': resolve_relative_path('../.assets/models/real_esrgan_x2_fp16.onnx') + } + }, + 'size': (256, 16, 8), + 'scale': 2 + }, + 'real_esrgan_x4': + { + 'hashes': + { + 'frame_enhancer': + { + 'url': resolve_download_url('models-3.0.0', 'real_esrgan_x4.hash'), + 'path': resolve_relative_path('../.assets/models/real_esrgan_x4.hash') + } + }, + 'sources': + { + 'frame_enhancer': + { + 'url': resolve_download_url('models-3.0.0', 
'real_esrgan_x4.onnx'), + 'path': resolve_relative_path('../.assets/models/real_esrgan_x4.onnx') + } + }, + 'size': (256, 16, 8), + 'scale': 4 + }, + 'real_esrgan_x4_fp16': + { + 'hashes': + { + 'frame_enhancer': + { + 'url': resolve_download_url('models-3.0.0', 'real_esrgan_x4_fp16.hash'), + 'path': resolve_relative_path('../.assets/models/real_esrgan_x4_fp16.hash') + } + }, + 'sources': + { + 'frame_enhancer': + { + 'url': resolve_download_url('models-3.0.0', 'real_esrgan_x4_fp16.onnx'), + 'path': resolve_relative_path('../.assets/models/real_esrgan_x4_fp16.onnx') + } + }, + 'size': (256, 16, 8), + 'scale': 4 + }, + 'real_esrgan_x8': + { + 'hashes': + { + 'frame_enhancer': + { + 'url': resolve_download_url('models-3.0.0', 'real_esrgan_x8.hash'), + 'path': resolve_relative_path('../.assets/models/real_esrgan_x8.hash') + } + }, + 'sources': + { + 'frame_enhancer': + { + 'url': resolve_download_url('models-3.0.0', 'real_esrgan_x8.onnx'), + 'path': resolve_relative_path('../.assets/models/real_esrgan_x8.onnx') + } + }, + 'size': (256, 16, 8), + 'scale': 8 + }, + 'real_esrgan_x8_fp16': + { + 'hashes': + { + 'frame_enhancer': + { + 'url': resolve_download_url('models-3.0.0', 'real_esrgan_x8_fp16.hash'), + 'path': resolve_relative_path('../.assets/models/real_esrgan_x8_fp16.hash') + } + }, + 'sources': + { + 'frame_enhancer': + { + 'url': resolve_download_url('models-3.0.0', 'real_esrgan_x8_fp16.onnx'), + 'path': resolve_relative_path('../.assets/models/real_esrgan_x8_fp16.onnx') + } + }, + 'size': (256, 16, 8), + 'scale': 8 + }, + 'real_hatgan_x4': + { + 'hashes': + { + 'frame_enhancer': + { + 'url': resolve_download_url('models-3.0.0', 'real_hatgan_x4.hash'), + 'path': resolve_relative_path('../.assets/models/real_hatgan_x4.hash') + } + }, + 'sources': + { + 'frame_enhancer': + { + 'url': resolve_download_url('models-3.0.0', 'real_hatgan_x4.onnx'), + 'path': resolve_relative_path('../.assets/models/real_hatgan_x4.onnx') + } + }, + 'size': (256, 16, 8), + 'scale': 4 + }, + 'real_web_photo_x4': + { + 'hashes': + { + 'frame_enhancer': + { + 'url': resolve_download_url('models-3.1.0', 'real_web_photo_x4.hash'), + 'path': resolve_relative_path('../.assets/models/real_web_photo_x4.hash') + } + }, + 'sources': + { + 'frame_enhancer': + { + 'url': resolve_download_url('models-3.1.0', 'real_web_photo_x4.onnx'), + 'path': resolve_relative_path('../.assets/models/real_web_photo_x4.onnx') + } + }, + 'size': (64, 4, 2), + 'scale': 4 + }, + 'realistic_rescaler_x4': + { + 'hashes': + { + 'frame_enhancer': + { + 'url': resolve_download_url('models-3.1.0', 'realistic_rescaler_x4.hash'), + 'path': resolve_relative_path('../.assets/models/realistic_rescaler_x4.hash') + } + }, + 'sources': + { + 'frame_enhancer': + { + 'url': resolve_download_url('models-3.1.0', 'realistic_rescaler_x4.onnx'), + 'path': resolve_relative_path('../.assets/models/realistic_rescaler_x4.onnx') + } + }, + 'size': (128, 8, 4), + 'scale': 4 + }, + 'remacri_x4': + { + 'hashes': + { + 'frame_enhancer': + { + 'url': resolve_download_url('models-3.1.0', 'remacri_x4.hash'), + 'path': resolve_relative_path('../.assets/models/remacri_x4.hash') + } + }, + 'sources': + { + 'frame_enhancer': + { + 'url': resolve_download_url('models-3.1.0', 'remacri_x4.onnx'), + 'path': resolve_relative_path('../.assets/models/remacri_x4.onnx') + } + }, + 'size': (128, 8, 4), + 'scale': 4 + }, + 'siax_x4': + { + 'hashes': + { + 'frame_enhancer': + { + 'url': resolve_download_url('models-3.1.0', 'siax_x4.hash'), + 'path': 
resolve_relative_path('../.assets/models/siax_x4.hash') + } + }, + 'sources': + { + 'frame_enhancer': + { + 'url': resolve_download_url('models-3.1.0', 'siax_x4.onnx'), + 'path': resolve_relative_path('../.assets/models/siax_x4.onnx') + } + }, + 'size': (128, 8, 4), + 'scale': 4 + }, + 'span_kendata_x4': + { + 'hashes': + { + 'frame_enhancer': + { + 'url': resolve_download_url('models-3.0.0', 'span_kendata_x4.hash'), + 'path': resolve_relative_path('../.assets/models/span_kendata_x4.hash') + } + }, + 'sources': + { + 'frame_enhancer': + { + 'url': resolve_download_url('models-3.0.0', 'span_kendata_x4.onnx'), + 'path': resolve_relative_path('../.assets/models/span_kendata_x4.onnx') + } + }, + 'size': (128, 8, 4), + 'scale': 4 + }, + 'swin2_sr_x4': + { + 'hashes': + { + 'frame_enhancer': + { + 'url': resolve_download_url('models-3.1.0', 'swin2_sr_x4.hash'), + 'path': resolve_relative_path('../.assets/models/swin2_sr_x4.hash') + } + }, + 'sources': + { + 'frame_enhancer': + { + 'url': resolve_download_url('models-3.1.0', 'swin2_sr_x4.onnx'), + 'path': resolve_relative_path('../.assets/models/swin2_sr_x4.onnx') + } + }, + 'size': (128, 8, 4), + 'scale': 4 + }, + 'ultra_sharp_x4': + { + 'hashes': + { + 'frame_enhancer': + { + 'url': resolve_download_url('models-3.0.0', 'ultra_sharp_x4.hash'), + 'path': resolve_relative_path('../.assets/models/ultra_sharp_x4.hash') + } + }, + 'sources': + { + 'frame_enhancer': + { + 'url': resolve_download_url('models-3.0.0', 'ultra_sharp_x4.onnx'), + 'path': resolve_relative_path('../.assets/models/ultra_sharp_x4.onnx') + } + }, + 'size': (128, 8, 4), + 'scale': 4 + }, + 'ultra_sharp_2_x4': + { + 'hashes': + { + 'frame_enhancer': + { + 'url': resolve_download_url('models-3.3.0', 'ultra_sharp_2_x4.hash'), + 'path': resolve_relative_path('../.assets/models/ultra_sharp_2_x4.hash') + } + }, + 'sources': + { + 'frame_enhancer': + { + 'url': resolve_download_url('models-3.3.0', 'ultra_sharp_2_x4.onnx'), + 'path': resolve_relative_path('../.assets/models/ultra_sharp_2_x4.onnx') + } + }, + 'size': (1024, 64, 32), + 'scale': 4 + } + } + + +def get_inference_pool() -> InferencePool: + model_names = [ get_frame_enhancer_model() ] + model_source_set = get_model_options().get('sources') + + return inference_manager.get_inference_pool(__name__, model_names, model_source_set) + + +def clear_inference_pool() -> None: + model_names = [ get_frame_enhancer_model() ] + inference_manager.clear_inference_pool(__name__, model_names) + + +def get_model_options() -> ModelOptions: + model_name = get_frame_enhancer_model() + return create_static_model_set('full').get(model_name) + + +def get_frame_enhancer_model() -> str: + frame_enhancer_model = state_manager.get_item('frame_enhancer_model') + + if has_execution_provider('coreml'): + if frame_enhancer_model == 'real_esrgan_x2_fp16': + return 'real_esrgan_x2' + if frame_enhancer_model == 'real_esrgan_x4_fp16': + return 'real_esrgan_x4' + if frame_enhancer_model == 'real_esrgan_x8_fp16': + return 'real_esrgan_x8' + return frame_enhancer_model + + +def register_args(program : ArgumentParser) -> None: + group_processors = find_argument_group(program, 'processors') + if group_processors: + group_processors.add_argument('--frame-enhancer-model', help = wording.get('help.frame_enhancer_model'), default = config.get_str_value('processors', 'frame_enhancer_model', 'span_kendata_x4'), choices = processors_choices.frame_enhancer_models) + group_processors.add_argument('--frame-enhancer-blend', help = wording.get('help.frame_enhancer_blend'), type 
= int, default = config.get_int_value('processors', 'frame_enhancer_blend', '80'), choices = processors_choices.frame_enhancer_blend_range, metavar = create_int_metavar(processors_choices.frame_enhancer_blend_range)) + facefusion.jobs.job_store.register_step_keys([ 'frame_enhancer_model', 'frame_enhancer_blend' ]) + + +def apply_args(args : Args, apply_state_item : ApplyStateItem) -> None: + apply_state_item('frame_enhancer_model', args.get('frame_enhancer_model')) + apply_state_item('frame_enhancer_blend', args.get('frame_enhancer_blend')) + + +def pre_check() -> bool: + model_hash_set = get_model_options().get('hashes') + model_source_set = get_model_options().get('sources') + + return conditional_download_hashes(model_hash_set) and conditional_download_sources(model_source_set) + + +def pre_process(mode : ProcessMode) -> bool: + if mode in [ 'output', 'preview' ] and not is_image(state_manager.get_item('target_path')) and not is_video(state_manager.get_item('target_path')): + logger.error(wording.get('choose_image_or_video_target') + wording.get('exclamation_mark'), __name__) + return False + if mode == 'output' and not in_directory(state_manager.get_item('output_path')): + logger.error(wording.get('specify_image_or_video_output') + wording.get('exclamation_mark'), __name__) + return False + if mode == 'output' and not same_file_extension(state_manager.get_item('target_path'), state_manager.get_item('output_path')): + logger.error(wording.get('match_target_and_output_extension') + wording.get('exclamation_mark'), __name__) + return False + return True + + +def post_process() -> None: + read_static_image.cache_clear() + video_manager.clear_video_pool() + if state_manager.get_item('video_memory_strategy') in [ 'strict', 'moderate' ]: + clear_inference_pool() + if state_manager.get_item('video_memory_strategy') == 'strict': + content_analyser.clear_inference_pool() + + +def enhance_frame(temp_vision_frame : VisionFrame) -> VisionFrame: + model_size = get_model_options().get('size') + model_scale = get_model_options().get('scale') + temp_height, temp_width = temp_vision_frame.shape[:2] + tile_vision_frames, pad_width, pad_height = create_tile_frames(temp_vision_frame, model_size) + + for index, tile_vision_frame in enumerate(tile_vision_frames): + tile_vision_frame = prepare_tile_frame(tile_vision_frame) + tile_vision_frame = forward(tile_vision_frame) + tile_vision_frames[index] = normalize_tile_frame(tile_vision_frame) + + merge_vision_frame = merge_tile_frames(tile_vision_frames, temp_width * model_scale, temp_height * model_scale, pad_width * model_scale, pad_height * model_scale, (model_size[0] * model_scale, model_size[1] * model_scale, model_size[2] * model_scale)) + temp_vision_frame = blend_frame(temp_vision_frame, merge_vision_frame) + return temp_vision_frame + + +def forward(tile_vision_frame : VisionFrame) -> VisionFrame: + frame_enhancer = get_inference_pool().get('frame_enhancer') + + with conditional_thread_semaphore(): + tile_vision_frame = frame_enhancer.run(None, + { + 'input': tile_vision_frame + })[0] + + return tile_vision_frame + + +def prepare_tile_frame(vision_tile_frame : VisionFrame) -> VisionFrame: + vision_tile_frame = numpy.expand_dims(vision_tile_frame[:, :, ::-1], axis = 0) + vision_tile_frame = vision_tile_frame.transpose(0, 3, 1, 2) + vision_tile_frame = vision_tile_frame.astype(numpy.float32) / 255.0 + return vision_tile_frame + + +def normalize_tile_frame(vision_tile_frame : VisionFrame) -> VisionFrame: + vision_tile_frame = vision_tile_frame.transpose(0, 
2, 3, 1).squeeze(0) * 255 + vision_tile_frame = vision_tile_frame.clip(0, 255).astype(numpy.uint8)[:, :, ::-1] + return vision_tile_frame + + +def blend_frame(temp_vision_frame : VisionFrame, merge_vision_frame : VisionFrame) -> VisionFrame: + frame_enhancer_blend = 1 - (state_manager.get_item('frame_enhancer_blend') / 100) + temp_vision_frame = cv2.resize(temp_vision_frame, (merge_vision_frame.shape[1], merge_vision_frame.shape[0])) + temp_vision_frame = cv2.addWeighted(temp_vision_frame, frame_enhancer_blend, merge_vision_frame, 1 - frame_enhancer_blend, 0) + return temp_vision_frame + + +def get_reference_frame(source_face : Face, target_face : Face, temp_vision_frame : VisionFrame) -> VisionFrame: + pass + + +def process_frame(inputs : FrameEnhancerInputs) -> VisionFrame: + target_vision_frame = inputs.get('target_vision_frame') + return enhance_frame(target_vision_frame) + + +def process_frames(source_paths : List[str], queue_payloads : List[QueuePayload], update_progress : UpdateProgress) -> None: + for queue_payload in process_manager.manage(queue_payloads): + target_vision_path = queue_payload['frame_path'] + target_vision_frame = read_image(target_vision_path) + output_vision_frame = process_frame( + { + 'target_vision_frame': target_vision_frame + }) + write_image(target_vision_path, output_vision_frame) + update_progress(1) + + +def process_image(source_paths : List[str], target_path : str, output_path : str) -> None: + target_vision_frame = read_static_image(target_path) + output_vision_frame = process_frame( + { + 'target_vision_frame': target_vision_frame + }) + write_image(output_path, output_vision_frame) + + +def process_video(source_paths : List[str], temp_frame_paths : List[str]) -> None: + processors.multi_process_frames(None, temp_frame_paths, process_frames) diff --git a/facefusion/processors/modules/lip_syncer.py b/facefusion/processors/modules/lip_syncer.py new file mode 100644 index 0000000000000000000000000000000000000000..16d4b6898681a8a2562f70eec7fef68b45d53c4f --- /dev/null +++ b/facefusion/processors/modules/lip_syncer.py @@ -0,0 +1,348 @@ +from argparse import ArgumentParser +from functools import lru_cache +from typing import List + +import cv2 +import numpy + +import facefusion.jobs.job_manager +import facefusion.jobs.job_store +import facefusion.processors.core as processors +from facefusion import config, content_analyser, face_classifier, face_detector, face_landmarker, face_masker, face_recognizer, inference_manager, logger, process_manager, state_manager, video_manager, voice_extractor, wording +from facefusion.audio import create_empty_audio_frame, get_voice_frame, read_static_voice +from facefusion.common_helper import create_float_metavar +from facefusion.common_helper import get_first +from facefusion.download import conditional_download_hashes, conditional_download_sources, resolve_download_url +from facefusion.face_analyser import get_many_faces, get_one_face +from facefusion.face_helper import create_bounding_box, paste_back, warp_face_by_bounding_box, warp_face_by_face_landmark_5 +from facefusion.face_masker import create_area_mask, create_box_mask, create_occlusion_mask +from facefusion.face_selector import find_similar_faces, sort_and_filter_faces +from facefusion.face_store import get_reference_faces +from facefusion.filesystem import filter_audio_paths, has_audio, in_directory, is_image, is_video, resolve_relative_path, same_file_extension +from facefusion.processors import choices as processors_choices +from facefusion.processors.types 
import LipSyncerInputs, LipSyncerWeight +from facefusion.program_helper import find_argument_group +from facefusion.thread_helper import conditional_thread_semaphore +from facefusion.types import ApplyStateItem, Args, AudioFrame, BoundingBox, DownloadScope, Face, InferencePool, ModelOptions, ModelSet, ProcessMode, QueuePayload, UpdateProgress, VisionFrame +from facefusion.vision import read_image, read_static_image, restrict_video_fps, write_image + + +@lru_cache(maxsize = None) +def create_static_model_set(download_scope : DownloadScope) -> ModelSet: + return\ + { + 'edtalk_256': + { + 'hashes': + { + 'lip_syncer': + { + 'url': resolve_download_url('models-3.3.0', 'edtalk_256.hash'), + 'path': resolve_relative_path('../.assets/models/edtalk_256.hash') + } + }, + 'sources': + { + 'lip_syncer': + { + 'url': resolve_download_url('models-3.3.0', 'edtalk_256.onnx'), + 'path': resolve_relative_path('../.assets/models/edtalk_256.onnx') + } + }, + 'type': 'edtalk', + 'size': (256, 256) + }, + 'wav2lip_96': + { + 'hashes': + { + 'lip_syncer': + { + 'url': resolve_download_url('models-3.0.0', 'wav2lip_96.hash'), + 'path': resolve_relative_path('../.assets/models/wav2lip_96.hash') + } + }, + 'sources': + { + 'lip_syncer': + { + 'url': resolve_download_url('models-3.0.0', 'wav2lip_96.onnx'), + 'path': resolve_relative_path('../.assets/models/wav2lip_96.onnx') + } + }, + 'type': 'wav2lip', + 'size': (96, 96) + }, + 'wav2lip_gan_96': + { + 'hashes': + { + 'lip_syncer': + { + 'url': resolve_download_url('models-3.0.0', 'wav2lip_gan_96.hash'), + 'path': resolve_relative_path('../.assets/models/wav2lip_gan_96.hash') + } + }, + 'sources': + { + 'lip_syncer': + { + 'url': resolve_download_url('models-3.0.0', 'wav2lip_gan_96.onnx'), + 'path': resolve_relative_path('../.assets/models/wav2lip_gan_96.onnx') + } + }, + 'type': 'wav2lip', + 'size': (96, 96) + } + } + + +def get_inference_pool() -> InferencePool: + model_names = [ state_manager.get_item('lip_syncer_model') ] + model_source_set = get_model_options().get('sources') + + return inference_manager.get_inference_pool(__name__, model_names, model_source_set) + + +def clear_inference_pool() -> None: + model_names = [ state_manager.get_item('lip_syncer_model') ] + inference_manager.clear_inference_pool(__name__, model_names) + + +def get_model_options() -> ModelOptions: + model_name = state_manager.get_item('lip_syncer_model') + return create_static_model_set('full').get(model_name) + + +def register_args(program : ArgumentParser) -> None: + group_processors = find_argument_group(program, 'processors') + if group_processors: + group_processors.add_argument('--lip-syncer-model', help = wording.get('help.lip_syncer_model'), default = config.get_str_value('processors', 'lip_syncer_model', 'wav2lip_gan_96'), choices = processors_choices.lip_syncer_models) + group_processors.add_argument('--lip-syncer-weight', help = wording.get('help.lip_syncer_weight'), type = float, default = config.get_float_value('processors', 'lip_syncer_weight', '0.5'), choices = processors_choices.lip_syncer_weight_range, metavar = create_float_metavar(processors_choices.lip_syncer_weight_range)) + facefusion.jobs.job_store.register_step_keys([ 'lip_syncer_model', 'lip_syncer_weight' ]) + + +def apply_args(args : Args, apply_state_item : ApplyStateItem) -> None: + apply_state_item('lip_syncer_model', args.get('lip_syncer_model')) + apply_state_item('lip_syncer_weight', args.get('lip_syncer_weight')) + + +def pre_check() -> bool: + model_hash_set = get_model_options().get('hashes') + 
model_source_set = get_model_options().get('sources') + + return conditional_download_hashes(model_hash_set) and conditional_download_sources(model_source_set) + + +def pre_process(mode : ProcessMode) -> bool: + if not has_audio(state_manager.get_item('source_paths')): + logger.error(wording.get('choose_audio_source') + wording.get('exclamation_mark'), __name__) + return False + if mode in [ 'output', 'preview' ] and not is_image(state_manager.get_item('target_path')) and not is_video(state_manager.get_item('target_path')): + logger.error(wording.get('choose_image_or_video_target') + wording.get('exclamation_mark'), __name__) + return False + if mode == 'output' and not in_directory(state_manager.get_item('output_path')): + logger.error(wording.get('specify_image_or_video_output') + wording.get('exclamation_mark'), __name__) + return False + if mode == 'output' and not same_file_extension(state_manager.get_item('target_path'), state_manager.get_item('output_path')): + logger.error(wording.get('match_target_and_output_extension') + wording.get('exclamation_mark'), __name__) + return False + return True + + +def post_process() -> None: + read_static_image.cache_clear() + read_static_voice.cache_clear() + video_manager.clear_video_pool() + if state_manager.get_item('video_memory_strategy') in [ 'strict', 'moderate' ]: + clear_inference_pool() + if state_manager.get_item('video_memory_strategy') == 'strict': + content_analyser.clear_inference_pool() + face_classifier.clear_inference_pool() + face_detector.clear_inference_pool() + face_landmarker.clear_inference_pool() + face_masker.clear_inference_pool() + face_recognizer.clear_inference_pool() + voice_extractor.clear_inference_pool() + + +def sync_lip(target_face : Face, temp_audio_frame : AudioFrame, temp_vision_frame : VisionFrame) -> VisionFrame: + model_type = get_model_options().get('type') + model_size = get_model_options().get('size') + temp_audio_frame = prepare_audio_frame(temp_audio_frame) + crop_vision_frame, affine_matrix = warp_face_by_face_landmark_5(temp_vision_frame, target_face.landmark_set.get('5/68'), 'ffhq_512', (512, 512)) + crop_masks = [] + + if 'occlusion' in state_manager.get_item('face_mask_types'): + occlusion_mask = create_occlusion_mask(crop_vision_frame) + crop_masks.append(occlusion_mask) + + if model_type == 'edtalk': + lip_syncer_weight = numpy.array([ state_manager.get_item('lip_syncer_weight') ]).astype(numpy.float32) + box_mask = create_box_mask(crop_vision_frame, state_manager.get_item('face_mask_blur'), state_manager.get_item('face_mask_padding')) + crop_masks.append(box_mask) + crop_vision_frame = prepare_crop_frame(crop_vision_frame) + crop_vision_frame = forward_edtalk(temp_audio_frame, crop_vision_frame, lip_syncer_weight) + crop_vision_frame = normalize_crop_frame(crop_vision_frame) + if model_type == 'wav2lip': + face_landmark_68 = cv2.transform(target_face.landmark_set.get('68').reshape(1, -1, 2), affine_matrix).reshape(-1, 2) + area_mask = create_area_mask(crop_vision_frame, face_landmark_68, [ 'lower-face' ]) + crop_masks.append(area_mask) + bounding_box = create_bounding_box(face_landmark_68) + bounding_box = resize_bounding_box(bounding_box, 1 / 8) + area_vision_frame, area_matrix = warp_face_by_bounding_box(crop_vision_frame, bounding_box, model_size) + area_vision_frame = prepare_crop_frame(area_vision_frame) + area_vision_frame = forward_wav2lip(temp_audio_frame, area_vision_frame) + area_vision_frame = normalize_crop_frame(area_vision_frame) + crop_vision_frame = 
cv2.warpAffine(area_vision_frame, cv2.invertAffineTransform(area_matrix), (512, 512), borderMode = cv2.BORDER_REPLICATE) + + crop_mask = numpy.minimum.reduce(crop_masks) + paste_vision_frame = paste_back(temp_vision_frame, crop_vision_frame, crop_mask, affine_matrix) + return paste_vision_frame + + +def forward_edtalk(temp_audio_frame : AudioFrame, crop_vision_frame : VisionFrame, lip_syncer_weight : LipSyncerWeight) -> VisionFrame: + lip_syncer = get_inference_pool().get('lip_syncer') + + with conditional_thread_semaphore(): + crop_vision_frame = lip_syncer.run(None, + { + 'source': temp_audio_frame, + 'target': crop_vision_frame, + 'weight': lip_syncer_weight + })[0] + + return crop_vision_frame + + +def forward_wav2lip(temp_audio_frame : AudioFrame, area_vision_frame : VisionFrame) -> VisionFrame: + lip_syncer = get_inference_pool().get('lip_syncer') + + with conditional_thread_semaphore(): + area_vision_frame = lip_syncer.run(None, + { + 'source': temp_audio_frame, + 'target': area_vision_frame + })[0] + + return area_vision_frame + + +def prepare_audio_frame(temp_audio_frame : AudioFrame) -> AudioFrame: + model_type = get_model_options().get('type') + temp_audio_frame = numpy.maximum(numpy.exp(-5 * numpy.log(10)), temp_audio_frame) + temp_audio_frame = numpy.log10(temp_audio_frame) * 1.6 + 3.2 + temp_audio_frame = temp_audio_frame.clip(-4, 4).astype(numpy.float32) + + if model_type == 'wav2lip': + temp_audio_frame = temp_audio_frame * state_manager.get_item('lip_syncer_weight') * 2.0 + + temp_audio_frame = numpy.expand_dims(temp_audio_frame, axis = (0, 1)) + return temp_audio_frame + + +def prepare_crop_frame(crop_vision_frame : VisionFrame) -> VisionFrame: + model_type = get_model_options().get('type') + model_size = get_model_options().get('size') + + if model_type == 'edtalk': + crop_vision_frame = cv2.resize(crop_vision_frame, model_size, interpolation = cv2.INTER_AREA) + crop_vision_frame = crop_vision_frame[:, :, ::-1] / 255.0 + crop_vision_frame = numpy.expand_dims(crop_vision_frame.transpose(2, 0, 1), axis = 0).astype(numpy.float32) + if model_type == 'wav2lip': + crop_vision_frame = numpy.expand_dims(crop_vision_frame, axis = 0) + prepare_vision_frame = crop_vision_frame.copy() + prepare_vision_frame[:, model_size[0] // 2:] = 0 + crop_vision_frame = numpy.concatenate((prepare_vision_frame, crop_vision_frame), axis = 3) + crop_vision_frame = crop_vision_frame.transpose(0, 3, 1, 2).astype('float32') / 255.0 + + return crop_vision_frame + + +def resize_bounding_box(bounding_box : BoundingBox, aspect_ratio : float) -> BoundingBox: + x1, y1, x2, y2 = bounding_box + y1 -= numpy.abs(y2 - y1) * aspect_ratio + bounding_box[1] = max(y1, 0) + return bounding_box + + +def normalize_crop_frame(crop_vision_frame : VisionFrame) -> VisionFrame: + model_type = get_model_options().get('type') + crop_vision_frame = crop_vision_frame[0].transpose(1, 2, 0) + crop_vision_frame = crop_vision_frame.clip(0, 1) * 255 + crop_vision_frame = crop_vision_frame.astype(numpy.uint8) + + if model_type == 'edtalk': + crop_vision_frame = crop_vision_frame[:, :, ::-1] + crop_vision_frame = cv2.resize(crop_vision_frame, (512, 512), interpolation = cv2.INTER_CUBIC) + + return crop_vision_frame + + +def get_reference_frame(source_face : Face, target_face : Face, temp_vision_frame : VisionFrame) -> VisionFrame: + pass + + +def process_frame(inputs : LipSyncerInputs) -> VisionFrame: + reference_faces = inputs.get('reference_faces') + source_audio_frame = inputs.get('source_audio_frame') + target_vision_frame = 
inputs.get('target_vision_frame') + many_faces = sort_and_filter_faces(get_many_faces([ target_vision_frame ])) + + if state_manager.get_item('face_selector_mode') == 'many': + if many_faces: + for target_face in many_faces: + target_vision_frame = sync_lip(target_face, source_audio_frame, target_vision_frame) + if state_manager.get_item('face_selector_mode') == 'one': + target_face = get_one_face(many_faces) + if target_face: + target_vision_frame = sync_lip(target_face, source_audio_frame, target_vision_frame) + if state_manager.get_item('face_selector_mode') == 'reference': + similar_faces = find_similar_faces(many_faces, reference_faces, state_manager.get_item('reference_face_distance')) + if similar_faces: + for similar_face in similar_faces: + target_vision_frame = sync_lip(similar_face, source_audio_frame, target_vision_frame) + return target_vision_frame + + +def process_frames(source_paths : List[str], queue_payloads : List[QueuePayload], update_progress : UpdateProgress) -> None: + reference_faces = get_reference_faces() if 'reference' in state_manager.get_item('face_selector_mode') else None + source_audio_path = get_first(filter_audio_paths(source_paths)) + temp_video_fps = restrict_video_fps(state_manager.get_item('target_path'), state_manager.get_item('output_video_fps')) + + for queue_payload in process_manager.manage(queue_payloads): + frame_number = queue_payload.get('frame_number') + target_vision_path = queue_payload.get('frame_path') + source_audio_frame = get_voice_frame(source_audio_path, temp_video_fps, frame_number) + if not numpy.any(source_audio_frame): + source_audio_frame = create_empty_audio_frame() + target_vision_frame = read_image(target_vision_path) + output_vision_frame = process_frame( + { + 'reference_faces': reference_faces, + 'source_audio_frame': source_audio_frame, + 'target_vision_frame': target_vision_frame + }) + write_image(target_vision_path, output_vision_frame) + update_progress(1) + + +def process_image(source_paths : List[str], target_path : str, output_path : str) -> None: + reference_faces = get_reference_faces() if 'reference' in state_manager.get_item('face_selector_mode') else None + source_audio_frame = create_empty_audio_frame() + target_vision_frame = read_static_image(target_path) + output_vision_frame = process_frame( + { + 'reference_faces': reference_faces, + 'source_audio_frame': source_audio_frame, + 'target_vision_frame': target_vision_frame + }) + write_image(output_path, output_vision_frame) + + +def process_video(source_paths : List[str], temp_frame_paths : List[str]) -> None: + source_audio_paths = filter_audio_paths(state_manager.get_item('source_paths')) + temp_video_fps = restrict_video_fps(state_manager.get_item('target_path'), state_manager.get_item('output_video_fps')) + for source_audio_path in source_audio_paths: + read_static_voice(source_audio_path, temp_video_fps) + processors.multi_process_frames(source_paths, temp_frame_paths, process_frames) diff --git a/facefusion/processors/pixel_boost.py b/facefusion/processors/pixel_boost.py new file mode 100644 index 0000000000000000000000000000000000000000..3b857d14854fabaeaf5952f219483db480c1ea56 --- /dev/null +++ b/facefusion/processors/pixel_boost.py @@ -0,0 +1,18 @@ +from typing import List + +import numpy +from cv2.typing import Size + +from facefusion.types import VisionFrame + + +def implode_pixel_boost(crop_vision_frame : VisionFrame, pixel_boost_total : int, model_size : Size) -> VisionFrame: + pixel_boost_vision_frame = crop_vision_frame.reshape(model_size[0], 
pixel_boost_total, model_size[1], pixel_boost_total, 3) + pixel_boost_vision_frame = pixel_boost_vision_frame.transpose(1, 3, 0, 2, 4).reshape(pixel_boost_total ** 2, model_size[0], model_size[1], 3) + return pixel_boost_vision_frame + + +def explode_pixel_boost(temp_vision_frames : List[VisionFrame], pixel_boost_total : int, model_size : Size, pixel_boost_size : Size) -> VisionFrame: + crop_vision_frame = numpy.stack(temp_vision_frames).reshape(pixel_boost_total, pixel_boost_total, model_size[0], model_size[1], 3) + crop_vision_frame = crop_vision_frame.transpose(2, 0, 3, 1, 4).reshape(pixel_boost_size[0], pixel_boost_size[1], 3) + return crop_vision_frame diff --git a/facefusion/processors/types.py b/facefusion/processors/types.py new file mode 100644 index 0000000000000000000000000000000000000000..bf9b9f668df2d6f9746c9f711ead18329cc8ec70 --- /dev/null +++ b/facefusion/processors/types.py @@ -0,0 +1,159 @@ +from typing import Any, Dict, List, Literal, TypeAlias, TypedDict + +from numpy.typing import NDArray + +from facefusion.types import AppContext, AudioFrame, Face, FaceSet, VisionFrame + +AgeModifierModel = Literal['styleganex_age'] +DeepSwapperModel : TypeAlias = str +ExpressionRestorerModel = Literal['live_portrait'] +FaceDebuggerItem = Literal['bounding-box', 'face-landmark-5', 'face-landmark-5/68', 'face-landmark-68', 'face-landmark-68/5', 'face-mask', 'face-detector-score', 'face-landmarker-score', 'age', 'gender', 'race'] +FaceEditorModel = Literal['live_portrait'] +FaceEnhancerModel = Literal['codeformer', 'gfpgan_1.2', 'gfpgan_1.3', 'gfpgan_1.4', 'gpen_bfr_256', 'gpen_bfr_512', 'gpen_bfr_1024', 'gpen_bfr_2048', 'restoreformer_plus_plus'] +FaceSwapperModel = Literal['blendswap_256', 'ghost_1_256', 'ghost_2_256', 'ghost_3_256', 'hififace_unofficial_256', 'hyperswap_1a_256', 'hyperswap_1b_256', 'hyperswap_1c_256', 'inswapper_128', 'inswapper_128_fp16', 'simswap_256', 'simswap_unofficial_512', 'uniface_256'] +FrameColorizerModel = Literal['ddcolor', 'ddcolor_artistic', 'deoldify', 'deoldify_artistic', 'deoldify_stable'] +FrameEnhancerModel = Literal['clear_reality_x4', 'lsdir_x4', 'nomos8k_sc_x4', 'real_esrgan_x2', 'real_esrgan_x2_fp16', 'real_esrgan_x4', 'real_esrgan_x4_fp16', 'real_esrgan_x8', 'real_esrgan_x8_fp16', 'real_hatgan_x4', 'real_web_photo_x4', 'realistic_rescaler_x4', 'remacri_x4', 'siax_x4', 'span_kendata_x4', 'swin2_sr_x4', 'ultra_sharp_x4', 'ultra_sharp_2_x4'] +LipSyncerModel = Literal['edtalk_256', 'wav2lip_96', 'wav2lip_gan_96'] + +FaceSwapperSet : TypeAlias = Dict[FaceSwapperModel, List[str]] + +AgeModifierInputs = TypedDict('AgeModifierInputs', +{ + 'reference_faces' : FaceSet, + 'target_vision_frame' : VisionFrame +}) +DeepSwapperInputs = TypedDict('DeepSwapperInputs', +{ + 'reference_faces' : FaceSet, + 'target_vision_frame' : VisionFrame +}) +ExpressionRestorerInputs = TypedDict('ExpressionRestorerInputs', +{ + 'reference_faces' : FaceSet, + 'source_vision_frame' : VisionFrame, + 'target_vision_frame' : VisionFrame +}) +FaceDebuggerInputs = TypedDict('FaceDebuggerInputs', +{ + 'reference_faces' : FaceSet, + 'target_vision_frame' : VisionFrame +}) +FaceEditorInputs = TypedDict('FaceEditorInputs', +{ + 'reference_faces' : FaceSet, + 'target_vision_frame' : VisionFrame +}) +FaceEnhancerInputs = TypedDict('FaceEnhancerInputs', +{ + 'reference_faces' : FaceSet, + 'target_vision_frame' : VisionFrame +}) +FaceSwapperInputs = TypedDict('FaceSwapperInputs', +{ + 'reference_faces' : FaceSet, + 'source_face' : Face, + 'target_vision_frame' : VisionFrame +}) 
+FrameColorizerInputs = TypedDict('FrameColorizerInputs', +{ + 'target_vision_frame' : VisionFrame +}) +FrameEnhancerInputs = TypedDict('FrameEnhancerInputs', +{ + 'target_vision_frame' : VisionFrame +}) +LipSyncerInputs = TypedDict('LipSyncerInputs', +{ + 'reference_faces' : FaceSet, + 'source_audio_frame' : AudioFrame, + 'target_vision_frame' : VisionFrame +}) + +ProcessorStateKey = Literal\ +[ + 'age_modifier_model', + 'age_modifier_direction', + 'deep_swapper_model', + 'deep_swapper_morph', + 'expression_restorer_model', + 'expression_restorer_factor', + 'face_debugger_items', + 'face_editor_model', + 'face_editor_eyebrow_direction', + 'face_editor_eye_gaze_horizontal', + 'face_editor_eye_gaze_vertical', + 'face_editor_eye_open_ratio', + 'face_editor_lip_open_ratio', + 'face_editor_mouth_grim', + 'face_editor_mouth_pout', + 'face_editor_mouth_purse', + 'face_editor_mouth_smile', + 'face_editor_mouth_position_horizontal', + 'face_editor_mouth_position_vertical', + 'face_editor_head_pitch', + 'face_editor_head_yaw', + 'face_editor_head_roll', + 'face_enhancer_model', + 'face_enhancer_blend', + 'face_enhancer_weight', + 'face_swapper_model', + 'face_swapper_pixel_boost', + 'frame_colorizer_model', + 'frame_colorizer_size', + 'frame_colorizer_blend', + 'frame_enhancer_model', + 'frame_enhancer_blend', + 'lip_syncer_model', + 'lip_syncer_weight' +] +ProcessorState = TypedDict('ProcessorState', +{ + 'age_modifier_model' : AgeModifierModel, + 'age_modifier_direction' : int, + 'deep_swapper_model' : DeepSwapperModel, + 'deep_swapper_morph' : int, + 'expression_restorer_model' : ExpressionRestorerModel, + 'expression_restorer_factor' : int, + 'face_debugger_items' : List[FaceDebuggerItem], + 'face_editor_model' : FaceEditorModel, + 'face_editor_eyebrow_direction' : float, + 'face_editor_eye_gaze_horizontal' : float, + 'face_editor_eye_gaze_vertical' : float, + 'face_editor_eye_open_ratio' : float, + 'face_editor_lip_open_ratio' : float, + 'face_editor_mouth_grim' : float, + 'face_editor_mouth_pout' : float, + 'face_editor_mouth_purse' : float, + 'face_editor_mouth_smile' : float, + 'face_editor_mouth_position_horizontal' : float, + 'face_editor_mouth_position_vertical' : float, + 'face_editor_head_pitch' : float, + 'face_editor_head_yaw' : float, + 'face_editor_head_roll' : float, + 'face_enhancer_model' : FaceEnhancerModel, + 'face_enhancer_blend' : int, + 'face_enhancer_weight' : float, + 'face_swapper_model' : FaceSwapperModel, + 'face_swapper_pixel_boost' : str, + 'frame_colorizer_model' : FrameColorizerModel, + 'frame_colorizer_size' : str, + 'frame_colorizer_blend' : int, + 'frame_enhancer_model' : FrameEnhancerModel, + 'frame_enhancer_blend' : int, + 'lip_syncer_model' : LipSyncerModel +}) +ProcessorStateSet : TypeAlias = Dict[AppContext, ProcessorState] + +AgeModifierDirection : TypeAlias = NDArray[Any] +DeepSwapperMorph : TypeAlias = NDArray[Any] +FaceEnhancerWeight : TypeAlias = NDArray[Any] +LipSyncerWeight : TypeAlias = NDArray[Any] +LivePortraitPitch : TypeAlias = float +LivePortraitYaw : TypeAlias = float +LivePortraitRoll : TypeAlias = float +LivePortraitExpression : TypeAlias = NDArray[Any] +LivePortraitFeatureVolume : TypeAlias = NDArray[Any] +LivePortraitMotionPoints : TypeAlias = NDArray[Any] +LivePortraitRotation : TypeAlias = NDArray[Any] +LivePortraitScale : TypeAlias = NDArray[Any] +LivePortraitTranslation : TypeAlias = NDArray[Any] diff --git a/facefusion/program.py b/facefusion/program.py new file mode 100644 index 
0000000000000000000000000000000000000000..5d3d62db31f14efde7e77a82737983b36b486c19 --- /dev/null +++ b/facefusion/program.py @@ -0,0 +1,317 @@ +import tempfile +from argparse import ArgumentParser, HelpFormatter + +import facefusion.choices +from facefusion import config, metadata, state_manager, wording +from facefusion.common_helper import create_float_metavar, create_int_metavar, get_first, get_last +from facefusion.execution import get_available_execution_providers +from facefusion.ffmpeg import get_available_encoder_set +from facefusion.filesystem import get_file_name, resolve_file_paths +from facefusion.jobs import job_store +from facefusion.processors.core import get_processors_modules + + +def create_help_formatter_small(prog : str) -> HelpFormatter: + return HelpFormatter(prog, max_help_position = 50) + + +def create_help_formatter_large(prog : str) -> HelpFormatter: + return HelpFormatter(prog, max_help_position = 300) + + +def create_config_path_program() -> ArgumentParser: + program = ArgumentParser(add_help = False) + group_paths = program.add_argument_group('paths') + group_paths.add_argument('--config-path', help = wording.get('help.config_path'), default = 'facefusion.ini') + job_store.register_job_keys([ 'config_path' ]) + apply_config_path(program) + return program + + +def create_temp_path_program() -> ArgumentParser: + program = ArgumentParser(add_help = False) + group_paths = program.add_argument_group('paths') + group_paths.add_argument('--temp-path', help = wording.get('help.temp_path'), default = config.get_str_value('paths', 'temp_path', tempfile.gettempdir())) + job_store.register_job_keys([ 'temp_path' ]) + return program + + +def create_jobs_path_program() -> ArgumentParser: + program = ArgumentParser(add_help = False) + group_paths = program.add_argument_group('paths') + group_paths.add_argument('--jobs-path', help = wording.get('help.jobs_path'), default = config.get_str_value('paths', 'jobs_path', '.jobs')) + job_store.register_job_keys([ 'jobs_path' ]) + return program + + +def create_source_paths_program() -> ArgumentParser: + program = ArgumentParser(add_help = False) + group_paths = program.add_argument_group('paths') + group_paths.add_argument('-s', '--source-paths', help = wording.get('help.source_paths'), default = config.get_str_list('paths', 'source_paths'), nargs = '+') + job_store.register_step_keys([ 'source_paths' ]) + return program + + +def create_target_path_program() -> ArgumentParser: + program = ArgumentParser(add_help = False) + group_paths = program.add_argument_group('paths') + group_paths.add_argument('-t', '--target-path', help = wording.get('help.target_path'), default = config.get_str_value('paths', 'target_path')) + job_store.register_step_keys([ 'target_path' ]) + return program + + +def create_output_path_program() -> ArgumentParser: + program = ArgumentParser(add_help = False) + group_paths = program.add_argument_group('paths') + group_paths.add_argument('-o', '--output-path', help = wording.get('help.output_path'), default = config.get_str_value('paths', 'output_path')) + job_store.register_step_keys([ 'output_path' ]) + return program + + +def create_source_pattern_program() -> ArgumentParser: + program = ArgumentParser(add_help = False) + group_patterns = program.add_argument_group('patterns') + group_patterns.add_argument('-s', '--source-pattern', help = wording.get('help.source_pattern'), default = config.get_str_value('patterns', 'source_pattern')) + job_store.register_job_keys([ 'source_pattern' ]) + return program + + 
+def create_target_pattern_program() -> ArgumentParser: + program = ArgumentParser(add_help = False) + group_patterns = program.add_argument_group('patterns') + group_patterns.add_argument('-t', '--target-pattern', help = wording.get('help.target_pattern'), default = config.get_str_value('patterns', 'target_pattern')) + job_store.register_job_keys([ 'target_pattern' ]) + return program + + +def create_output_pattern_program() -> ArgumentParser: + program = ArgumentParser(add_help = False) + group_patterns = program.add_argument_group('patterns') + group_patterns.add_argument('-o', '--output-pattern', help = wording.get('help.output_pattern'), default = config.get_str_value('patterns', 'output_pattern')) + job_store.register_job_keys([ 'output_pattern' ]) + return program + + +def create_face_detector_program() -> ArgumentParser: + program = ArgumentParser(add_help = False) + group_face_detector = program.add_argument_group('face detector') + group_face_detector.add_argument('--face-detector-model', help = wording.get('help.face_detector_model'), default = config.get_str_value('face_detector', 'face_detector_model', 'yolo_face'), choices = facefusion.choices.face_detector_models) + known_args, _ = program.parse_known_args() + face_detector_size_choices = facefusion.choices.face_detector_set.get(known_args.face_detector_model) + group_face_detector.add_argument('--face-detector-size', help = wording.get('help.face_detector_size'), default = config.get_str_value('face_detector', 'face_detector_size', get_last(face_detector_size_choices)), choices = face_detector_size_choices) + group_face_detector.add_argument('--face-detector-angles', help = wording.get('help.face_detector_angles'), type = int, default = config.get_int_list('face_detector', 'face_detector_angles', '0'), choices = facefusion.choices.face_detector_angles, nargs = '+', metavar = 'FACE_DETECTOR_ANGLES') + group_face_detector.add_argument('--face-detector-score', help = wording.get('help.face_detector_score'), type = float, default = config.get_float_value('face_detector', 'face_detector_score', '0.5'), choices = facefusion.choices.face_detector_score_range, metavar = create_float_metavar(facefusion.choices.face_detector_score_range)) + job_store.register_step_keys([ 'face_detector_model', 'face_detector_angles', 'face_detector_size', 'face_detector_score' ]) + return program + + +def create_face_landmarker_program() -> ArgumentParser: + program = ArgumentParser(add_help = False) + group_face_landmarker = program.add_argument_group('face landmarker') + group_face_landmarker.add_argument('--face-landmarker-model', help = wording.get('help.face_landmarker_model'), default = config.get_str_value('face_landmarker', 'face_landmarker_model', '2dfan4'), choices = facefusion.choices.face_landmarker_models) + group_face_landmarker.add_argument('--face-landmarker-score', help = wording.get('help.face_landmarker_score'), type = float, default = config.get_float_value('face_landmarker', 'face_landmarker_score', '0.5'), choices = facefusion.choices.face_landmarker_score_range, metavar = create_float_metavar(facefusion.choices.face_landmarker_score_range)) + job_store.register_step_keys([ 'face_landmarker_model', 'face_landmarker_score' ]) + return program + + +def create_face_selector_program() -> ArgumentParser: + program = ArgumentParser(add_help = False) + group_face_selector = program.add_argument_group('face selector') + group_face_selector.add_argument('--face-selector-mode', help = wording.get('help.face_selector_mode'), default = 
config.get_str_value('face_selector', 'face_selector_mode', 'reference'), choices = facefusion.choices.face_selector_modes) + group_face_selector.add_argument('--face-selector-order', help = wording.get('help.face_selector_order'), default = config.get_str_value('face_selector', 'face_selector_order', 'large-small'), choices = facefusion.choices.face_selector_orders) + group_face_selector.add_argument('--face-selector-age-start', help = wording.get('help.face_selector_age_start'), type = int, default = config.get_int_value('face_selector', 'face_selector_age_start'), choices = facefusion.choices.face_selector_age_range, metavar = create_int_metavar(facefusion.choices.face_selector_age_range)) + group_face_selector.add_argument('--face-selector-age-end', help = wording.get('help.face_selector_age_end'), type = int, default = config.get_int_value('face_selector', 'face_selector_age_end'), choices = facefusion.choices.face_selector_age_range, metavar = create_int_metavar(facefusion.choices.face_selector_age_range)) + group_face_selector.add_argument('--face-selector-gender', help = wording.get('help.face_selector_gender'), default = config.get_str_value('face_selector', 'face_selector_gender'), choices = facefusion.choices.face_selector_genders) + group_face_selector.add_argument('--face-selector-race', help = wording.get('help.face_selector_race'), default = config.get_str_value('face_selector', 'face_selector_race'), choices = facefusion.choices.face_selector_races) + group_face_selector.add_argument('--reference-face-position', help = wording.get('help.reference_face_position'), type = int, default = config.get_int_value('face_selector', 'reference_face_position', '0')) + group_face_selector.add_argument('--reference-face-distance', help = wording.get('help.reference_face_distance'), type = float, default = config.get_float_value('face_selector', 'reference_face_distance', '0.3'), choices = facefusion.choices.reference_face_distance_range, metavar = create_float_metavar(facefusion.choices.reference_face_distance_range)) + group_face_selector.add_argument('--reference-frame-number', help = wording.get('help.reference_frame_number'), type = int, default = config.get_int_value('face_selector', 'reference_frame_number', '0')) + job_store.register_step_keys([ 'face_selector_mode', 'face_selector_order', 'face_selector_gender', 'face_selector_race', 'face_selector_age_start', 'face_selector_age_end', 'reference_face_position', 'reference_face_distance', 'reference_frame_number' ]) + return program + + +def create_face_masker_program() -> ArgumentParser: + program = ArgumentParser(add_help = False) + group_face_masker = program.add_argument_group('face masker') + group_face_masker.add_argument('--face-occluder-model', help = wording.get('help.face_occluder_model'), default = config.get_str_value('face_masker', 'face_occluder_model', 'xseg_1'), choices = facefusion.choices.face_occluder_models) + group_face_masker.add_argument('--face-parser-model', help = wording.get('help.face_parser_model'), default = config.get_str_value('face_masker', 'face_parser_model', 'bisenet_resnet_34'), choices = facefusion.choices.face_parser_models) + group_face_masker.add_argument('--face-mask-types', help = wording.get('help.face_mask_types').format(choices = ', '.join(facefusion.choices.face_mask_types)), default = config.get_str_list('face_masker', 'face_mask_types', 'box'), choices = facefusion.choices.face_mask_types, nargs = '+', metavar = 'FACE_MASK_TYPES') + 
group_face_masker.add_argument('--face-mask-areas', help = wording.get('help.face_mask_areas').format(choices = ', '.join(facefusion.choices.face_mask_areas)), default = config.get_str_list('face_masker', 'face_mask_areas', ' '.join(facefusion.choices.face_mask_areas)), choices = facefusion.choices.face_mask_areas, nargs = '+', metavar = 'FACE_MASK_AREAS') + group_face_masker.add_argument('--face-mask-regions', help = wording.get('help.face_mask_regions').format(choices = ', '.join(facefusion.choices.face_mask_regions)), default = config.get_str_list('face_masker', 'face_mask_regions', ' '.join(facefusion.choices.face_mask_regions)), choices = facefusion.choices.face_mask_regions, nargs = '+', metavar = 'FACE_MASK_REGIONS') + group_face_masker.add_argument('--face-mask-blur', help = wording.get('help.face_mask_blur'), type = float, default = config.get_float_value('face_masker', 'face_mask_blur', '0.3'), choices = facefusion.choices.face_mask_blur_range, metavar = create_float_metavar(facefusion.choices.face_mask_blur_range)) + group_face_masker.add_argument('--face-mask-padding', help = wording.get('help.face_mask_padding'), type = int, default = config.get_int_list('face_masker', 'face_mask_padding', '0 0 0 0'), nargs = '+') + job_store.register_step_keys([ 'face_occluder_model', 'face_parser_model', 'face_mask_types', 'face_mask_areas', 'face_mask_regions', 'face_mask_blur', 'face_mask_padding' ]) + return program + + +def create_frame_extraction_program() -> ArgumentParser: + program = ArgumentParser(add_help = False) + group_frame_extraction = program.add_argument_group('frame extraction') + group_frame_extraction.add_argument('--trim-frame-start', help = wording.get('help.trim_frame_start'), type = int, default = facefusion.config.get_int_value('frame_extraction', 'trim_frame_start')) + group_frame_extraction.add_argument('--trim-frame-end', help = wording.get('help.trim_frame_end'), type = int, default = facefusion.config.get_int_value('frame_extraction', 'trim_frame_end')) + group_frame_extraction.add_argument('--temp-frame-format', help = wording.get('help.temp_frame_format'), default = config.get_str_value('frame_extraction', 'temp_frame_format', 'png'), choices = facefusion.choices.temp_frame_formats) + group_frame_extraction.add_argument('--keep-temp', help = wording.get('help.keep_temp'), action = 'store_true', default = config.get_bool_value('frame_extraction', 'keep_temp')) + job_store.register_step_keys([ 'trim_frame_start', 'trim_frame_end', 'temp_frame_format', 'keep_temp' ]) + return program + + +def create_output_creation_program() -> ArgumentParser: + program = ArgumentParser(add_help = False) + available_encoder_set = get_available_encoder_set() + group_output_creation = program.add_argument_group('output creation') + group_output_creation.add_argument('--output-image-quality', help = wording.get('help.output_image_quality'), type = int, default = config.get_int_value('output_creation', 'output_image_quality', '80'), choices = facefusion.choices.output_image_quality_range, metavar = create_int_metavar(facefusion.choices.output_image_quality_range)) + group_output_creation.add_argument('--output-image-resolution', help = wording.get('help.output_image_resolution'), default = config.get_str_value('output_creation', 'output_image_resolution')) + group_output_creation.add_argument('--output-audio-encoder', help = wording.get('help.output_audio_encoder'), default = config.get_str_value('output_creation', 'output_audio_encoder', 
get_first(available_encoder_set.get('audio'))), choices = available_encoder_set.get('audio')) + group_output_creation.add_argument('--output-audio-quality', help = wording.get('help.output_audio_quality'), type = int, default = config.get_int_value('output_creation', 'output_audio_quality', '80'), choices = facefusion.choices.output_audio_quality_range, metavar = create_int_metavar(facefusion.choices.output_audio_quality_range)) + group_output_creation.add_argument('--output-audio-volume', help = wording.get('help.output_audio_volume'), type = int, default = config.get_int_value('output_creation', 'output_audio_volume', '100'), choices = facefusion.choices.output_audio_volume_range, metavar = create_int_metavar(facefusion.choices.output_audio_volume_range)) + group_output_creation.add_argument('--output-video-encoder', help = wording.get('help.output_video_encoder'), default = config.get_str_value('output_creation', 'output_video_encoder', get_first(available_encoder_set.get('video'))), choices = available_encoder_set.get('video')) + group_output_creation.add_argument('--output-video-preset', help = wording.get('help.output_video_preset'), default = config.get_str_value('output_creation', 'output_video_preset', 'veryfast'), choices = facefusion.choices.output_video_presets) + group_output_creation.add_argument('--output-video-quality', help = wording.get('help.output_video_quality'), type = int, default = config.get_int_value('output_creation', 'output_video_quality', '80'), choices = facefusion.choices.output_video_quality_range, metavar = create_int_metavar(facefusion.choices.output_video_quality_range)) + group_output_creation.add_argument('--output-video-resolution', help = wording.get('help.output_video_resolution'), default = config.get_str_value('output_creation', 'output_video_resolution')) + group_output_creation.add_argument('--output-video-fps', help = wording.get('help.output_video_fps'), type = float, default = config.get_str_value('output_creation', 'output_video_fps')) + job_store.register_step_keys([ 'output_image_quality', 'output_image_resolution', 'output_audio_encoder', 'output_audio_quality', 'output_audio_volume', 'output_video_encoder', 'output_video_preset', 'output_video_quality', 'output_video_resolution', 'output_video_fps' ]) + return program + + +def create_processors_program() -> ArgumentParser: + program = ArgumentParser(add_help = False) + available_processors = [ get_file_name(file_path) for file_path in resolve_file_paths('facefusion/processors/modules') ] + group_processors = program.add_argument_group('processors') + group_processors.add_argument('--processors', help = wording.get('help.processors').format(choices = ', '.join(available_processors)), default = config.get_str_list('processors', 'processors', 'face_swapper'), nargs = '+') + job_store.register_step_keys([ 'processors' ]) + for processor_module in get_processors_modules(available_processors): + processor_module.register_args(program) + return program + + +def create_uis_program() -> ArgumentParser: + program = ArgumentParser(add_help = False) + available_ui_layouts = [ get_file_name(file_path) for file_path in resolve_file_paths('facefusion/uis/layouts') ] + group_uis = program.add_argument_group('uis') + group_uis.add_argument('--open-browser', help = wording.get('help.open_browser'), action = 'store_true', default = config.get_bool_value('uis', 'open_browser')) + group_uis.add_argument('--ui-layouts', help = wording.get('help.ui_layouts').format(choices = ', '.join(available_ui_layouts)), 
default = config.get_str_list('uis', 'ui_layouts', 'default'), nargs = '+') + group_uis.add_argument('--ui-workflow', help = wording.get('help.ui_workflow'), default = config.get_str_value('uis', 'ui_workflow', 'instant_runner'), choices = facefusion.choices.ui_workflows) + return program + + +def create_download_providers_program() -> ArgumentParser: + program = ArgumentParser(add_help = False) + group_download = program.add_argument_group('download') + group_download.add_argument('--download-providers', help = wording.get('help.download_providers').format(choices = ', '.join(facefusion.choices.download_providers)), default = config.get_str_list('download', 'download_providers', ' '.join(facefusion.choices.download_providers)), choices = facefusion.choices.download_providers, nargs = '+', metavar = 'DOWNLOAD_PROVIDERS') + job_store.register_job_keys([ 'download_providers' ]) + return program + + +def create_download_scope_program() -> ArgumentParser: + program = ArgumentParser(add_help = False) + group_download = program.add_argument_group('download') + group_download.add_argument('--download-scope', help = wording.get('help.download_scope'), default = config.get_str_value('download', 'download_scope', 'lite'), choices = facefusion.choices.download_scopes) + job_store.register_job_keys([ 'download_scope' ]) + return program + + +def create_benchmark_program() -> ArgumentParser: + program = ArgumentParser(add_help = False) + group_benchmark = program.add_argument_group('benchmark') + group_benchmark.add_argument('--benchmark-resolutions', help = wording.get('help.benchmark_resolutions'), default = config.get_str_list('benchmark', 'benchmark_resolutions', get_first(facefusion.choices.benchmark_resolutions)), choices = facefusion.choices.benchmark_resolutions, nargs = '+') + group_benchmark.add_argument('--benchmark-cycle-count', help = wording.get('help.benchmark_cycle_count'), type = int, default = config.get_int_value('benchmark', 'benchmark_cycle_count', '5'), choices = facefusion.choices.benchmark_cycle_count_range) + return program + + +def create_execution_program() -> ArgumentParser: + program = ArgumentParser(add_help = False) + available_execution_providers = get_available_execution_providers() + group_execution = program.add_argument_group('execution') + group_execution.add_argument('--execution-device-id', help = wording.get('help.execution_device_id'), default = config.get_str_value('execution', 'execution_device_id', '0')) + group_execution.add_argument('--execution-providers', help = wording.get('help.execution_providers').format(choices = ', '.join(available_execution_providers)), default = config.get_str_list('execution', 'execution_providers', get_first(available_execution_providers)), choices = available_execution_providers, nargs = '+', metavar = 'EXECUTION_PROVIDERS') + group_execution.add_argument('--execution-thread-count', help = wording.get('help.execution_thread_count'), type = int, default = config.get_int_value('execution', 'execution_thread_count', '4'), choices = facefusion.choices.execution_thread_count_range, metavar = create_int_metavar(facefusion.choices.execution_thread_count_range)) + group_execution.add_argument('--execution-queue-count', help = wording.get('help.execution_queue_count'), type = int, default = config.get_int_value('execution', 'execution_queue_count', '1'), choices = facefusion.choices.execution_queue_count_range, metavar = create_int_metavar(facefusion.choices.execution_queue_count_range)) + job_store.register_job_keys([ 
'execution_device_id', 'execution_providers', 'execution_thread_count', 'execution_queue_count' ]) + return program + + +def create_memory_program() -> ArgumentParser: + program = ArgumentParser(add_help = False) + group_memory = program.add_argument_group('memory') + group_memory.add_argument('--video-memory-strategy', help = wording.get('help.video_memory_strategy'), default = config.get_str_value('memory', 'video_memory_strategy', 'strict'), choices = facefusion.choices.video_memory_strategies) + group_memory.add_argument('--system-memory-limit', help = wording.get('help.system_memory_limit'), type = int, default = config.get_int_value('memory', 'system_memory_limit', '0'), choices = facefusion.choices.system_memory_limit_range, metavar = create_int_metavar(facefusion.choices.system_memory_limit_range)) + job_store.register_job_keys([ 'video_memory_strategy', 'system_memory_limit' ]) + return program + + +def create_log_level_program() -> ArgumentParser: + program = ArgumentParser(add_help = False) + group_misc = program.add_argument_group('misc') + group_misc.add_argument('--log-level', help = wording.get('help.log_level'), default = config.get_str_value('misc', 'log_level', 'info'), choices = facefusion.choices.log_levels) + job_store.register_job_keys([ 'log_level' ]) + return program + + +def create_halt_on_error_program() -> ArgumentParser: + program = ArgumentParser(add_help = False) + group_misc = program.add_argument_group('misc') + group_misc.add_argument('--halt-on-error', help = wording.get('help.halt_on_error'), action = 'store_true', default = config.get_bool_value('misc', 'halt_on_error')) + job_store.register_job_keys([ 'halt_on_error' ]) + return program + + +def create_job_id_program() -> ArgumentParser: + program = ArgumentParser(add_help = False) + program.add_argument('job_id', help = wording.get('help.job_id')) + job_store.register_job_keys([ 'job_id' ]) + return program + + +def create_job_status_program() -> ArgumentParser: + program = ArgumentParser(add_help = False) + program.add_argument('job_status', help = wording.get('help.job_status'), choices = facefusion.choices.job_statuses) + return program + + +def create_step_index_program() -> ArgumentParser: + program = ArgumentParser(add_help = False) + program.add_argument('step_index', help = wording.get('help.step_index'), type = int) + return program + + +def collect_step_program() -> ArgumentParser: + return ArgumentParser(parents = [ create_face_detector_program(), create_face_landmarker_program(), create_face_selector_program(), create_face_masker_program(), create_frame_extraction_program(), create_output_creation_program(), create_processors_program() ], add_help = False) + + +def collect_job_program() -> ArgumentParser: + return ArgumentParser(parents = [ create_execution_program(), create_download_providers_program(), create_memory_program(), create_log_level_program() ], add_help = False) + + +def create_program() -> ArgumentParser: + program = ArgumentParser(formatter_class = create_help_formatter_large, add_help = False) + program._positionals.title = 'commands' + program.add_argument('-v', '--version', version = metadata.get('name') + ' ' + metadata.get('version'), action = 'version') + sub_program = program.add_subparsers(dest = 'command') + # general + sub_program.add_parser('run', help = wording.get('help.run'), parents = [ create_config_path_program(), create_temp_path_program(), create_jobs_path_program(), create_source_paths_program(), create_target_path_program(), create_output_path_program(), 
collect_step_program(), create_uis_program(), collect_job_program() ], formatter_class = create_help_formatter_large) + sub_program.add_parser('headless-run', help = wording.get('help.headless_run'), parents = [ create_config_path_program(), create_temp_path_program(), create_jobs_path_program(), create_source_paths_program(), create_target_path_program(), create_output_path_program(), collect_step_program(), collect_job_program() ], formatter_class = create_help_formatter_large) + sub_program.add_parser('batch-run', help = wording.get('help.batch_run'), parents = [ create_config_path_program(), create_temp_path_program(), create_jobs_path_program(), create_source_pattern_program(), create_target_pattern_program(), create_output_pattern_program(), collect_step_program(), collect_job_program() ], formatter_class = create_help_formatter_large) + sub_program.add_parser('force-download', help = wording.get('help.force_download'), parents = [ create_download_providers_program(), create_download_scope_program(), create_log_level_program() ], formatter_class = create_help_formatter_large) + sub_program.add_parser('benchmark', help = wording.get('help.benchmark'), parents = [ create_temp_path_program(), collect_step_program(), create_benchmark_program(), collect_job_program() ], formatter_class = create_help_formatter_large) + # job manager + sub_program.add_parser('job-list', help = wording.get('help.job_list'), parents = [ create_job_status_program(), create_jobs_path_program(), create_log_level_program() ], formatter_class = create_help_formatter_large) + sub_program.add_parser('job-create', help = wording.get('help.job_create'), parents = [ create_job_id_program(), create_jobs_path_program(), create_log_level_program() ], formatter_class = create_help_formatter_large) + sub_program.add_parser('job-submit', help = wording.get('help.job_submit'), parents = [ create_job_id_program(), create_jobs_path_program(), create_log_level_program() ], formatter_class = create_help_formatter_large) + sub_program.add_parser('job-submit-all', help = wording.get('help.job_submit_all'), parents = [ create_jobs_path_program(), create_log_level_program(), create_halt_on_error_program() ], formatter_class = create_help_formatter_large) + sub_program.add_parser('job-delete', help = wording.get('help.job_delete'), parents = [ create_job_id_program(), create_jobs_path_program(), create_log_level_program() ], formatter_class = create_help_formatter_large) + sub_program.add_parser('job-delete-all', help = wording.get('help.job_delete_all'), parents = [ create_jobs_path_program(), create_log_level_program(), create_halt_on_error_program() ], formatter_class = create_help_formatter_large) + sub_program.add_parser('job-add-step', help = wording.get('help.job_add_step'), parents = [ create_job_id_program(), create_config_path_program(), create_jobs_path_program(), create_source_paths_program(), create_target_path_program(), create_output_path_program(), collect_step_program(), create_log_level_program() ], formatter_class = create_help_formatter_large) + sub_program.add_parser('job-remix-step', help = wording.get('help.job_remix_step'), parents = [ create_job_id_program(), create_step_index_program(), create_config_path_program(), create_jobs_path_program(), create_source_paths_program(), create_output_path_program(), collect_step_program(), create_log_level_program() ], formatter_class = create_help_formatter_large) + sub_program.add_parser('job-insert-step', help = wording.get('help.job_insert_step'), parents = [ 
create_job_id_program(), create_step_index_program(), create_config_path_program(), create_jobs_path_program(), create_source_paths_program(), create_target_path_program(), create_output_path_program(), collect_step_program(), create_log_level_program() ], formatter_class = create_help_formatter_large) + sub_program.add_parser('job-remove-step', help = wording.get('help.job_remove_step'), parents = [ create_job_id_program(), create_step_index_program(), create_jobs_path_program(), create_log_level_program() ], formatter_class = create_help_formatter_large) + # job runner + sub_program.add_parser('job-run', help = wording.get('help.job_run'), parents = [ create_job_id_program(), create_config_path_program(), create_temp_path_program(), create_jobs_path_program(), collect_job_program() ], formatter_class = create_help_formatter_large) + sub_program.add_parser('job-run-all', help = wording.get('help.job_run_all'), parents = [ create_config_path_program(), create_temp_path_program(), create_jobs_path_program(), collect_job_program(), create_halt_on_error_program() ], formatter_class = create_help_formatter_large) + sub_program.add_parser('job-retry', help = wording.get('help.job_retry'), parents = [ create_job_id_program(), create_config_path_program(), create_temp_path_program(), create_jobs_path_program(), collect_job_program() ], formatter_class = create_help_formatter_large) + sub_program.add_parser('job-retry-all', help = wording.get('help.job_retry_all'), parents = [ create_config_path_program(), create_temp_path_program(), create_jobs_path_program(), collect_job_program(), create_halt_on_error_program() ], formatter_class = create_help_formatter_large) + return ArgumentParser(parents = [ program ], formatter_class = create_help_formatter_small) + + +def apply_config_path(program : ArgumentParser) -> None: + known_args, _ = program.parse_known_args() + state_manager.init_item('config_path', known_args.config_path) diff --git a/facefusion/program_helper.py b/facefusion/program_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..bec9bf1e307b2b119f32392cc98556f503780583 --- /dev/null +++ b/facefusion/program_helper.py @@ -0,0 +1,31 @@ +from argparse import ArgumentParser, _ArgumentGroup, _SubParsersAction +from typing import Optional + + +def find_argument_group(program : ArgumentParser, group_name : str) -> Optional[_ArgumentGroup]: + for group in program._action_groups: + if group.title == group_name: + return group + return None + + +def validate_args(program : ArgumentParser) -> bool: + if validate_actions(program): + for action in program._actions: + if isinstance(action, _SubParsersAction): + for _, sub_program in action._name_parser_map.items(): + if not validate_args(sub_program): + return False + return True + return False + + +def validate_actions(program : ArgumentParser) -> bool: + for action in program._actions: + if action.default and action.choices: + if isinstance(action.default, list): + if any(default not in action.choices for default in action.default): + return False + elif action.default not in action.choices: + return False + return True diff --git a/facefusion/state_manager.py b/facefusion/state_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..aba6c575c3baf5f14f98b9d1aa28c37469f5feda --- /dev/null +++ b/facefusion/state_manager.py @@ -0,0 +1,38 @@ +from typing import Any, Union + +from facefusion.app_context import detect_app_context +from facefusion.processors.types import ProcessorState, ProcessorStateKey, 
ProcessorStateSet +from facefusion.types import State, StateKey, StateSet + +STATE_SET : Union[StateSet, ProcessorStateSet] =\ +{ + 'cli': {}, #type:ignore[assignment] + 'ui': {} #type:ignore[assignment] +} + + +def get_state() -> Union[State, ProcessorState]: + app_context = detect_app_context() + return STATE_SET.get(app_context) + + +def init_item(key : Union[StateKey, ProcessorStateKey], value : Any) -> None: + STATE_SET['cli'][key] = value #type:ignore[literal-required] + STATE_SET['ui'][key] = value #type:ignore[literal-required] + + +def get_item(key : Union[StateKey, ProcessorStateKey]) -> Any: + return get_state().get(key) #type:ignore[literal-required] + + +def set_item(key : Union[StateKey, ProcessorStateKey], value : Any) -> None: + app_context = detect_app_context() + STATE_SET[app_context][key] = value #type:ignore[literal-required] + + +def sync_item(key : Union[StateKey, ProcessorStateKey]) -> None: + STATE_SET['cli'][key] = STATE_SET.get('ui').get(key) #type:ignore[literal-required] + + +def clear_item(key : Union[StateKey, ProcessorStateKey]) -> None: + set_item(key, None) diff --git a/facefusion/temp_helper.py b/facefusion/temp_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..9622621b5ef51db084f8f2eb1d3a2787d9ce2940 --- /dev/null +++ b/facefusion/temp_helper.py @@ -0,0 +1,43 @@ +import os +from typing import List + +from facefusion import state_manager +from facefusion.filesystem import create_directory, get_file_extension, get_file_name, move_file, remove_directory, resolve_file_pattern + + +def get_temp_file_path(file_path : str) -> str: + temp_directory_path = get_temp_directory_path(file_path) + temp_file_extension = get_file_extension(file_path) + return os.path.join(temp_directory_path, 'temp' + temp_file_extension) + + +def move_temp_file(file_path : str, move_path : str) -> bool: + temp_file_path = get_temp_file_path(file_path) + return move_file(temp_file_path, move_path) + + +def resolve_temp_frame_paths(target_path : str) -> List[str]: + temp_frames_pattern = get_temp_frames_pattern(target_path, '*') + return resolve_file_pattern(temp_frames_pattern) + + +def get_temp_frames_pattern(target_path : str, temp_frame_prefix : str) -> str: + temp_directory_path = get_temp_directory_path(target_path) + return os.path.join(temp_directory_path, temp_frame_prefix + '.' 
+ state_manager.get_item('temp_frame_format')) + + +def get_temp_directory_path(file_path : str) -> str: + temp_file_name = get_file_name(file_path) + return os.path.join(state_manager.get_item('temp_path'), 'facefusion', temp_file_name) + + +def create_temp_directory(file_path : str) -> bool: + temp_directory_path = get_temp_directory_path(file_path) + return create_directory(temp_directory_path) + + +def clear_temp_directory(file_path : str) -> bool: + if not state_manager.get_item('keep_temp'): + temp_directory_path = get_temp_directory_path(file_path) + return remove_directory(temp_directory_path) + return True diff --git a/facefusion/thread_helper.py b/facefusion/thread_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..84717f9d95bacc493ca4481f9ae6ea0341ca872a --- /dev/null +++ b/facefusion/thread_helper.py @@ -0,0 +1,23 @@ +import threading +from contextlib import nullcontext +from typing import ContextManager, Union + +from facefusion.execution import has_execution_provider + +THREAD_LOCK : threading.Lock = threading.Lock() +THREAD_SEMAPHORE : threading.Semaphore = threading.Semaphore() +NULL_CONTEXT : ContextManager[None] = nullcontext() + + +def thread_lock() -> threading.Lock: + return THREAD_LOCK + + +def thread_semaphore() -> threading.Semaphore: + return THREAD_SEMAPHORE + + +def conditional_thread_semaphore() -> Union[threading.Semaphore, ContextManager[None]]: + if has_execution_provider('directml') or has_execution_provider('rocm'): + return THREAD_SEMAPHORE + return NULL_CONTEXT diff --git a/facefusion/types.py b/facefusion/types.py new file mode 100644 index 0000000000000000000000000000000000000000..45be13fd725f0becaf57307731a6f8686927d22c --- /dev/null +++ b/facefusion/types.py @@ -0,0 +1,379 @@ +from collections import namedtuple +from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, TypeAlias, TypedDict + +import cv2 +import numpy +from numpy.typing import NDArray +from onnxruntime import InferenceSession + +Scale : TypeAlias = float +Score : TypeAlias = float +Angle : TypeAlias = int + +Detection : TypeAlias = NDArray[Any] +Prediction : TypeAlias = NDArray[Any] + +BoundingBox : TypeAlias = NDArray[Any] +FaceLandmark5 : TypeAlias = NDArray[Any] +FaceLandmark68 : TypeAlias = NDArray[Any] +FaceLandmarkSet = TypedDict('FaceLandmarkSet', +{ + '5' : FaceLandmark5, #type:ignore[valid-type] + '5/68' : FaceLandmark5, #type:ignore[valid-type] + '68' : FaceLandmark68, #type:ignore[valid-type] + '68/5' : FaceLandmark68 #type:ignore[valid-type] +}) +FaceScoreSet = TypedDict('FaceScoreSet', +{ + 'detector' : Score, + 'landmarker' : Score +}) +Embedding : TypeAlias = NDArray[numpy.float64] +Gender = Literal['female', 'male'] +Age : TypeAlias = range +Race = Literal['white', 'black', 'latino', 'asian', 'indian', 'arabic'] +Face = namedtuple('Face', +[ + 'bounding_box', + 'score_set', + 'landmark_set', + 'angle', + 'embedding', + 'normed_embedding', + 'gender', + 'age', + 'race' +]) +FaceSet : TypeAlias = Dict[str, List[Face]] +FaceStore = TypedDict('FaceStore', +{ + 'static_faces' : FaceSet, + 'reference_faces' : FaceSet +}) +VideoPoolSet : TypeAlias = Dict[str, cv2.VideoCapture] + +VisionFrame : TypeAlias = NDArray[Any] +Mask : TypeAlias = NDArray[Any] +Points : TypeAlias = NDArray[Any] +Distance : TypeAlias = NDArray[Any] +Matrix : TypeAlias = NDArray[Any] +Anchors : TypeAlias = NDArray[Any] +Translation : TypeAlias = NDArray[Any] + +AudioBuffer : TypeAlias = bytes +Audio : TypeAlias = NDArray[Any] +AudioChunk : TypeAlias = NDArray[Any] 
+AudioFrame : TypeAlias = NDArray[Any] +Spectrogram : TypeAlias = NDArray[Any] +Mel : TypeAlias = NDArray[Any] +MelFilterBank : TypeAlias = NDArray[Any] + +Fps : TypeAlias = float +Duration : TypeAlias = float +Padding : TypeAlias = Tuple[int, int, int, int] +Orientation = Literal['landscape', 'portrait'] +Resolution : TypeAlias = Tuple[int, int] + +ProcessState = Literal['checking', 'processing', 'stopping', 'pending'] +QueuePayload = TypedDict('QueuePayload', +{ + 'frame_number' : int, + 'frame_path' : str +}) +Args : TypeAlias = Dict[str, Any] +UpdateProgress : TypeAlias = Callable[[int], None] +ProcessFrames : TypeAlias = Callable[[List[str], List[QueuePayload], UpdateProgress], None] +ProcessStep : TypeAlias = Callable[[str, int, Args], bool] + +Content : TypeAlias = Dict[str, Any] + +Commands : TypeAlias = List[str] + +WarpTemplate = Literal['arcface_112_v1', 'arcface_112_v2', 'arcface_128', 'dfl_whole_face', 'ffhq_512', 'mtcnn_512', 'styleganex_384'] +WarpTemplateSet : TypeAlias = Dict[WarpTemplate, NDArray[Any]] +ProcessMode = Literal['output', 'preview', 'stream'] + +ErrorCode = Literal[0, 1, 2, 3, 4] +LogLevel = Literal['error', 'warn', 'info', 'debug'] +LogLevelSet : TypeAlias = Dict[LogLevel, int] + +TableHeaders = List[str] +TableContents = List[List[Any]] + +FaceDetectorModel = Literal['many', 'retinaface', 'scrfd', 'yolo_face'] +FaceLandmarkerModel = Literal['many', '2dfan4', 'peppa_wutz'] +FaceDetectorSet : TypeAlias = Dict[FaceDetectorModel, List[str]] +FaceSelectorMode = Literal['many', 'one', 'reference'] +FaceSelectorOrder = Literal['left-right', 'right-left', 'top-bottom', 'bottom-top', 'small-large', 'large-small', 'best-worst', 'worst-best'] +FaceOccluderModel = Literal['xseg_1', 'xseg_2', 'xseg_3'] +FaceParserModel = Literal['bisenet_resnet_18', 'bisenet_resnet_34'] +FaceMaskType = Literal['box', 'occlusion', 'area', 'region'] +FaceMaskArea = Literal['upper-face', 'lower-face', 'mouth'] +FaceMaskRegion = Literal['skin', 'left-eyebrow', 'right-eyebrow', 'left-eye', 'right-eye', 'glasses', 'nose', 'mouth', 'upper-lip', 'lower-lip'] +FaceMaskRegionSet : TypeAlias = Dict[FaceMaskRegion, int] +FaceMaskAreaSet : TypeAlias = Dict[FaceMaskArea, List[int]] + +AudioFormat = Literal['flac', 'm4a', 'mp3', 'ogg', 'opus', 'wav'] +ImageFormat = Literal['bmp', 'jpeg', 'png', 'tiff', 'webp'] +VideoFormat = Literal['avi', 'm4v', 'mkv', 'mov', 'mp4', 'webm'] +TempFrameFormat = Literal['bmp', 'jpeg', 'png', 'tiff'] +AudioTypeSet : TypeAlias = Dict[AudioFormat, str] +ImageTypeSet : TypeAlias = Dict[ImageFormat, str] +VideoTypeSet : TypeAlias = Dict[VideoFormat, str] + +AudioEncoder = Literal['flac', 'aac', 'libmp3lame', 'libopus', 'libvorbis', 'pcm_s16le', 'pcm_s32le'] +VideoEncoder = Literal['libx264', 'libx265', 'libvpx-vp9', 'h264_nvenc', 'hevc_nvenc', 'h264_amf', 'hevc_amf', 'h264_qsv', 'hevc_qsv', 'h264_videotoolbox', 'hevc_videotoolbox', 'rawvideo'] +EncoderSet = TypedDict('EncoderSet', +{ + 'audio' : List[AudioEncoder], + 'video' : List[VideoEncoder] +}) +VideoPreset = Literal['ultrafast', 'superfast', 'veryfast', 'faster', 'fast', 'medium', 'slow', 'slower', 'veryslow'] + +BenchmarkResolution = Literal['240p', '360p', '540p', '720p', '1080p', '1440p', '2160p'] +BenchmarkSet : TypeAlias = Dict[BenchmarkResolution, str] +BenchmarkCycleSet = TypedDict('BenchmarkCycleSet', +{ + 'target_path' : str, + 'cycle_count' : int, + 'average_run' : float, + 'fastest_run' : float, + 'slowest_run' : float, + 'relative_fps' : float +}) + +WebcamMode = Literal['inline', 'udp', 'v4l2'] +StreamMode 
= Literal['udp', 'v4l2'] + +ModelOptions : TypeAlias = Dict[str, Any] +ModelSet : TypeAlias = Dict[str, ModelOptions] +ModelInitializer : TypeAlias = NDArray[Any] + +ExecutionProvider = Literal['cpu', 'coreml', 'cuda', 'directml', 'openvino', 'rocm', 'tensorrt'] +ExecutionProviderValue = Literal['CPUExecutionProvider', 'CoreMLExecutionProvider', 'CUDAExecutionProvider', 'DmlExecutionProvider', 'OpenVINOExecutionProvider', 'ROCMExecutionProvider', 'TensorrtExecutionProvider'] +ExecutionProviderSet : TypeAlias = Dict[ExecutionProvider, ExecutionProviderValue] +InferenceSessionProvider : TypeAlias = Any +ValueAndUnit = TypedDict('ValueAndUnit', +{ + 'value' : int, + 'unit' : str +}) +ExecutionDeviceFramework = TypedDict('ExecutionDeviceFramework', +{ + 'name' : str, + 'version' : str +}) +ExecutionDeviceProduct = TypedDict('ExecutionDeviceProduct', +{ + 'vendor' : str, + 'name' : str +}) +ExecutionDeviceVideoMemory = TypedDict('ExecutionDeviceVideoMemory', +{ + 'total' : Optional[ValueAndUnit], + 'free' : Optional[ValueAndUnit] +}) +ExecutionDeviceTemperature = TypedDict('ExecutionDeviceTemperature', +{ + 'gpu' : Optional[ValueAndUnit], + 'memory' : Optional[ValueAndUnit] +}) +ExecutionDeviceUtilization = TypedDict('ExecutionDeviceUtilization', +{ + 'gpu' : Optional[ValueAndUnit], + 'memory' : Optional[ValueAndUnit] +}) +ExecutionDevice = TypedDict('ExecutionDevice', +{ + 'driver_version' : str, + 'framework' : ExecutionDeviceFramework, + 'product' : ExecutionDeviceProduct, + 'video_memory' : ExecutionDeviceVideoMemory, + 'temperature': ExecutionDeviceTemperature, + 'utilization' : ExecutionDeviceUtilization +}) + +DownloadProvider = Literal['github', 'huggingface'] +DownloadProviderValue = TypedDict('DownloadProviderValue', +{ + 'urls' : List[str], + 'path' : str +}) +DownloadProviderSet : TypeAlias = Dict[DownloadProvider, DownloadProviderValue] +DownloadScope = Literal['lite', 'full'] +Download = TypedDict('Download', +{ + 'url' : str, + 'path' : str +}) +DownloadSet : TypeAlias = Dict[str, Download] + +VideoMemoryStrategy = Literal['strict', 'moderate', 'tolerant'] +AppContext = Literal['cli', 'ui'] + +InferencePool : TypeAlias = Dict[str, InferenceSession] +InferencePoolSet : TypeAlias = Dict[AppContext, Dict[str, InferencePool]] + +UiWorkflow = Literal['instant_runner', 'job_runner', 'job_manager'] + +JobStore = TypedDict('JobStore', +{ + 'job_keys' : List[str], + 'step_keys' : List[str] +}) +JobOutputSet : TypeAlias = Dict[str, List[str]] +JobStatus = Literal['drafted', 'queued', 'completed', 'failed'] +JobStepStatus = Literal['drafted', 'queued', 'started', 'completed', 'failed'] +JobStep = TypedDict('JobStep', +{ + 'args' : Args, + 'status' : JobStepStatus +}) +Job = TypedDict('Job', +{ + 'version' : str, + 'date_created' : str, + 'date_updated' : Optional[str], + 'steps' : List[JobStep] +}) +JobSet : TypeAlias = Dict[str, Job] + +StateKey = Literal\ +[ + 'command', + 'config_path', + 'temp_path', + 'jobs_path', + 'source_paths', + 'target_path', + 'output_path', + 'source_pattern', + 'target_pattern', + 'output_pattern', + 'download_providers', + 'download_scope', + 'benchmark_resolutions', + 'benchmark_cycle_count', + 'face_detector_model', + 'face_detector_size', + 'face_detector_angles', + 'face_detector_score', + 'face_landmarker_model', + 'face_landmarker_score', + 'face_selector_mode', + 'face_selector_order', + 'face_selector_gender', + 'face_selector_race', + 'face_selector_age_start', + 'face_selector_age_end', + 'reference_face_position', + 'reference_face_distance', + 
'reference_frame_number', + 'face_occluder_model', + 'face_parser_model', + 'face_mask_types', + 'face_mask_areas', + 'face_mask_regions', + 'face_mask_blur', + 'face_mask_padding', + 'trim_frame_start', + 'trim_frame_end', + 'temp_frame_format', + 'keep_temp', + 'output_image_quality', + 'output_image_resolution', + 'output_audio_encoder', + 'output_audio_quality', + 'output_audio_volume', + 'output_video_encoder', + 'output_video_preset', + 'output_video_quality', + 'output_video_resolution', + 'output_video_fps', + 'processors', + 'open_browser', + 'ui_layouts', + 'ui_workflow', + 'execution_device_id', + 'execution_providers', + 'execution_thread_count', + 'execution_queue_count', + 'video_memory_strategy', + 'system_memory_limit', + 'log_level', + 'halt_on_error', + 'job_id', + 'job_status', + 'step_index' +] +State = TypedDict('State', +{ + 'command' : str, + 'config_path' : str, + 'temp_path' : str, + 'jobs_path' : str, + 'source_paths' : List[str], + 'target_path' : str, + 'output_path' : str, + 'source_pattern' : str, + 'target_pattern' : str, + 'output_pattern' : str, + 'download_providers': List[DownloadProvider], + 'download_scope': DownloadScope, + 'benchmark_resolutions': List[BenchmarkResolution], + 'benchmark_cycle_count': int, + 'face_detector_model' : FaceDetectorModel, + 'face_detector_size' : str, + 'face_detector_angles' : List[Angle], + 'face_detector_score' : Score, + 'face_landmarker_model' : FaceLandmarkerModel, + 'face_landmarker_score' : Score, + 'face_selector_mode' : FaceSelectorMode, + 'face_selector_order' : FaceSelectorOrder, + 'face_selector_race' : Race, + 'face_selector_gender' : Gender, + 'face_selector_age_start' : int, + 'face_selector_age_end' : int, + 'reference_face_position' : int, + 'reference_face_distance' : float, + 'reference_frame_number' : int, + 'face_occluder_model' : FaceOccluderModel, + 'face_parser_model' : FaceParserModel, + 'face_mask_types' : List[FaceMaskType], + 'face_mask_areas' : List[FaceMaskArea], + 'face_mask_regions' : List[FaceMaskRegion], + 'face_mask_blur' : float, + 'face_mask_padding' : Padding, + 'trim_frame_start' : int, + 'trim_frame_end' : int, + 'temp_frame_format' : TempFrameFormat, + 'keep_temp' : bool, + 'output_image_quality' : int, + 'output_image_resolution' : str, + 'output_audio_encoder' : AudioEncoder, + 'output_audio_quality' : int, + 'output_audio_volume' : int, + 'output_video_encoder' : VideoEncoder, + 'output_video_preset' : VideoPreset, + 'output_video_quality' : int, + 'output_video_resolution' : str, + 'output_video_fps' : float, + 'processors' : List[str], + 'open_browser' : bool, + 'ui_layouts' : List[str], + 'ui_workflow' : UiWorkflow, + 'execution_device_id' : str, + 'execution_providers' : List[ExecutionProvider], + 'execution_thread_count' : int, + 'execution_queue_count' : int, + 'video_memory_strategy' : VideoMemoryStrategy, + 'system_memory_limit' : int, + 'log_level' : LogLevel, + 'halt_on_error' : bool, + 'job_id' : str, + 'job_status' : JobStatus, + 'step_index' : int +}) +ApplyStateItem : TypeAlias = Callable[[Any, Any], None] +StateSet : TypeAlias = Dict[AppContext, State] + diff --git a/facefusion/uis/__init__.py b/facefusion/uis/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/facefusion/uis/assets/overrides.css b/facefusion/uis/assets/overrides.css new file mode 100644 index 0000000000000000000000000000000000000000..d6f06b2617fe064f2341b8c50fced794a47effff --- /dev/null +++ 
b/facefusion/uis/assets/overrides.css @@ -0,0 +1,156 @@ +:root:root:root:root .gradio-container +{ + overflow: unset; +} + +:root:root:root:root main +{ + max-width: 110em; +} + +:root:root:root:root input[type="number"] +{ + appearance: textfield; + border-radius: unset; + text-align: center; + order: 1; + padding: unset +} + +:root:root:root:root input[type="number"]::-webkit-inner-spin-button +{ + appearance: none; +} + +:root:root:root:root input[type="number"]:focus +{ + outline: unset; +} + +:root:root:root:root .reset-button +{ + background: var(--background-fill-secondary); + border: unset; + font-size: unset; + padding: unset; +} + +:root:root:root:root [type="checkbox"], +:root:root:root:root [type="radio"] +{ + border-radius: 50%; + height: 1.125rem; + width: 1.125rem; +} + +:root:root:root:root input[type="range"] +{ + background: transparent; +} + +:root:root:root:root input[type="range"]::-moz-range-thumb, +:root:root:root:root input[type="range"]::-webkit-slider-thumb +{ + background: var(--neutral-300); + box-shadow: unset; + border-radius: 50%; + height: 1.125rem; + width: 1.125rem; +} + +:root:root:root:root .thumbnail-item +{ + border: unset; + box-shadow: unset; +} + +:root:root:root:root .grid-wrap.fixed-height +{ + min-height: unset; +} + +:root:root:root:root .box-face-selector .empty, +:root:root:root:root .box-face-selector .gallery-container +{ + min-height: 7.375rem; +} + +:root:root:root:root .tab-wrapper +{ + padding: 0 0.625rem; +} + +:root:root:root:root .tab-container +{ + gap: 0.5em; +} + +:root:root:root:root .tab-container button +{ + background: unset; + border-bottom: 0.125rem solid; +} + +:root:root:root:root .tab-container button.selected +{ + color: var(--primary-500) +} + +:root:root:root:root .toast-body +{ + background: white; + color: var(--primary-500); + border: unset; + border-radius: unset; +} + +:root:root:root:root .dark .toast-body +{ + background: var(--neutral-900); + color: var(--primary-600); +} + +:root:root:root:root .toast-icon, +:root:root:root:root .toast-title, +:root:root:root:root .toast-text, +:root:root:root:root .toast-close +{ + color: unset; +} + +:root:root:root:root .toast-body .timer +{ + background: currentColor; +} + +:root:root:root:root .slider_input_container > span, +:root:root:root:root .feather-upload, +:root:root:root:root footer +{ + display: none; +} + +:root:root:root:root .image-frame +{ + width: 100%; +} + +:root:root:root:root .image-frame > img +{ + object-fit: cover; +} + +:root:root:root:root .image-preview.is-landscape +{ + position: sticky; + top: 0; + z-index: 100; +} + +:root:root:root:root .block .error +{ + border: 0.125rem solid; + padding: 0.375rem 0.75rem; + font-size: 0.75rem; + text-transform: uppercase; +} diff --git a/facefusion/uis/choices.py b/facefusion/uis/choices.py new file mode 100644 index 0000000000000000000000000000000000000000..228a67b59e9f84bc13cb5bf404ac56552c15af9b --- /dev/null +++ b/facefusion/uis/choices.py @@ -0,0 +1,9 @@ +from typing import List + +from facefusion.uis.types import JobManagerAction, JobRunnerAction + +job_manager_actions : List[JobManagerAction] = [ 'job-create', 'job-submit', 'job-delete', 'job-add-step', 'job-remix-step', 'job-insert-step', 'job-remove-step' ] +job_runner_actions : List[JobRunnerAction] = [ 'job-run', 'job-run-all', 'job-retry', 'job-retry-all' ] + +common_options : List[str] = [ 'keep-temp' ] + diff --git a/facefusion/uis/components/__init__.py b/facefusion/uis/components/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/facefusion/uis/components/about.py b/facefusion/uis/components/about.py new file mode 100644 index 0000000000000000000000000000000000000000..fedaba726e52e1276edc86a69ae32888f7fe4d1e --- /dev/null +++ b/facefusion/uis/components/about.py @@ -0,0 +1,41 @@ +import random +from typing import Optional + +import gradio + +from facefusion import metadata, wording + +METADATA_BUTTON : Optional[gradio.Button] = None +ACTION_BUTTON : Optional[gradio.Button] = None + + +def render() -> None: + global METADATA_BUTTON + global ACTION_BUTTON + + action = random.choice( + [ + { + 'wording': wording.get('about.become_a_member'), + 'url': 'https://subscribe.facefusion.io' + }, + { + 'wording': wording.get('about.join_our_community'), + 'url': 'https://join.facefusion.io' + }, + { + 'wording': wording.get('about.read_the_documentation'), + 'url': 'https://docs.facefusion.io' + } + ]) + + METADATA_BUTTON = gradio.Button( + value = metadata.get('name') + ' ' + metadata.get('version'), + variant = 'primary', + link = metadata.get('url') + ) + ACTION_BUTTON = gradio.Button( + value = action.get('wording'), + link = action.get('url'), + size = 'sm' + ) diff --git a/facefusion/uis/components/age_modifier_options.py b/facefusion/uis/components/age_modifier_options.py new file mode 100644 index 0000000000000000000000000000000000000000..e42065e5b0a5cca320ec9bb20058e2399b3601ed --- /dev/null +++ b/facefusion/uis/components/age_modifier_options.py @@ -0,0 +1,64 @@ +from typing import List, Optional, Tuple + +import gradio + +from facefusion import state_manager, wording +from facefusion.common_helper import calc_float_step +from facefusion.processors import choices as processors_choices +from facefusion.processors.core import load_processor_module +from facefusion.processors.types import AgeModifierModel +from facefusion.uis.core import get_ui_component, register_ui_component + +AGE_MODIFIER_MODEL_DROPDOWN : Optional[gradio.Dropdown] = None +AGE_MODIFIER_DIRECTION_SLIDER : Optional[gradio.Slider] = None + + +def render() -> None: + global AGE_MODIFIER_MODEL_DROPDOWN + global AGE_MODIFIER_DIRECTION_SLIDER + + has_age_modifier = 'age_modifier' in state_manager.get_item('processors') + AGE_MODIFIER_MODEL_DROPDOWN = gradio.Dropdown( + label = wording.get('uis.age_modifier_model_dropdown'), + choices = processors_choices.age_modifier_models, + value = state_manager.get_item('age_modifier_model'), + visible = has_age_modifier + ) + AGE_MODIFIER_DIRECTION_SLIDER = gradio.Slider( + label = wording.get('uis.age_modifier_direction_slider'), + value = state_manager.get_item('age_modifier_direction'), + step = calc_float_step(processors_choices.age_modifier_direction_range), + minimum = processors_choices.age_modifier_direction_range[0], + maximum = processors_choices.age_modifier_direction_range[-1], + visible = has_age_modifier + ) + register_ui_component('age_modifier_model_dropdown', AGE_MODIFIER_MODEL_DROPDOWN) + register_ui_component('age_modifier_direction_slider', AGE_MODIFIER_DIRECTION_SLIDER) + + +def listen() -> None: + AGE_MODIFIER_MODEL_DROPDOWN.change(update_age_modifier_model, inputs = AGE_MODIFIER_MODEL_DROPDOWN, outputs = AGE_MODIFIER_MODEL_DROPDOWN) + AGE_MODIFIER_DIRECTION_SLIDER.release(update_age_modifier_direction, inputs = AGE_MODIFIER_DIRECTION_SLIDER) + + processors_checkbox_group = get_ui_component('processors_checkbox_group') + if processors_checkbox_group: + processors_checkbox_group.change(remote_update, 
inputs = processors_checkbox_group, outputs = [ AGE_MODIFIER_MODEL_DROPDOWN, AGE_MODIFIER_DIRECTION_SLIDER ]) + + +def remote_update(processors : List[str]) -> Tuple[gradio.Dropdown, gradio.Slider]: + has_age_modifier = 'age_modifier' in processors + return gradio.Dropdown(visible = has_age_modifier), gradio.Slider(visible = has_age_modifier) + + +def update_age_modifier_model(age_modifier_model : AgeModifierModel) -> gradio.Dropdown: + age_modifier_module = load_processor_module('age_modifier') + age_modifier_module.clear_inference_pool() + state_manager.set_item('age_modifier_model', age_modifier_model) + + if age_modifier_module.pre_check(): + return gradio.Dropdown(value = state_manager.get_item('age_modifier_model')) + return gradio.Dropdown() + + +def update_age_modifier_direction(age_modifier_direction : float) -> None: + state_manager.set_item('age_modifier_direction', int(age_modifier_direction)) diff --git a/facefusion/uis/components/benchmark.py b/facefusion/uis/components/benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..d002920b3a7e9a412e05326c7a87c2497a204c5a --- /dev/null +++ b/facefusion/uis/components/benchmark.py @@ -0,0 +1,61 @@ +from typing import Any, Generator, List, Optional + +import gradio + +from facefusion import benchmarker, state_manager, wording +from facefusion.types import BenchmarkResolution +from facefusion.uis.core import get_ui_component + +BENCHMARK_BENCHMARKS_DATAFRAME : Optional[gradio.Dataframe] = None +BENCHMARK_START_BUTTON : Optional[gradio.Button] = None + + +def render() -> None: + global BENCHMARK_BENCHMARKS_DATAFRAME + global BENCHMARK_START_BUTTON + + BENCHMARK_BENCHMARKS_DATAFRAME = gradio.Dataframe( + headers = + [ + 'target_path', + 'cycle_count', + 'average_run', + 'fastest_run', + 'slowest_run', + 'relative_fps' + ], + datatype = + [ + 'str', + 'number', + 'number', + 'number', + 'number', + 'number' + ], + show_label = False + ) + BENCHMARK_START_BUTTON = gradio.Button( + value = wording.get('uis.start_button'), + variant = 'primary', + size = 'sm' + ) + + +def listen() -> None: + benchmark_resolutions_checkbox_group = get_ui_component('benchmark_resolutions_checkbox_group') + benchmark_cycle_count_slider = get_ui_component('benchmark_cycle_count_slider') + + if benchmark_resolutions_checkbox_group and benchmark_cycle_count_slider: + BENCHMARK_START_BUTTON.click(start, inputs = [ benchmark_resolutions_checkbox_group, benchmark_cycle_count_slider ], outputs = BENCHMARK_BENCHMARKS_DATAFRAME) + + +def start(benchmark_resolutions : List[BenchmarkResolution], benchmark_cycle_count : int) -> Generator[List[Any], None, None]: + state_manager.set_item('benchmark_resolutions', benchmark_resolutions) + state_manager.set_item('benchmark_cycle_count', benchmark_cycle_count) + state_manager.sync_item('execution_providers') + state_manager.sync_item('execution_thread_count') + state_manager.sync_item('execution_queue_count') + + for benchmark in benchmarker.run(): + yield [ list(benchmark_set.values()) for benchmark_set in benchmark ] diff --git a/facefusion/uis/components/benchmark_options.py b/facefusion/uis/components/benchmark_options.py new file mode 100644 index 0000000000000000000000000000000000000000..549084a8d259885cd8dda0d889f78aaa30c5f898 --- /dev/null +++ b/facefusion/uis/components/benchmark_options.py @@ -0,0 +1,30 @@ +from typing import Optional + +import gradio + +import facefusion.choices +from facefusion import wording +from facefusion.uis.core import register_ui_component + 
+BENCHMARK_RESOLUTIONS_CHECKBOX_GROUP : Optional[gradio.CheckboxGroup] = None +BENCHMARK_CYCLE_COUNT_SLIDER : Optional[gradio.Slider] = None + + +def render() -> None: + global BENCHMARK_RESOLUTIONS_CHECKBOX_GROUP + global BENCHMARK_CYCLE_COUNT_SLIDER + + BENCHMARK_RESOLUTIONS_CHECKBOX_GROUP = gradio.CheckboxGroup( + label = wording.get('uis.benchmark_resolutions_checkbox_group'), + choices = facefusion.choices.benchmark_resolutions, + value = facefusion.choices.benchmark_resolutions + ) + BENCHMARK_CYCLE_COUNT_SLIDER = gradio.Slider( + label = wording.get('uis.benchmark_cycle_count_slider'), + value = 5, + step = 1, + minimum = min(facefusion.choices.benchmark_cycle_count_range), + maximum = max(facefusion.choices.benchmark_cycle_count_range) + ) + register_ui_component('benchmark_resolutions_checkbox_group', BENCHMARK_RESOLUTIONS_CHECKBOX_GROUP) + register_ui_component('benchmark_cycle_count_slider', BENCHMARK_CYCLE_COUNT_SLIDER) diff --git a/facefusion/uis/components/common_options.py b/facefusion/uis/components/common_options.py new file mode 100644 index 0000000000000000000000000000000000000000..1cf96fa9e41bd75e2dedd3a4d2ab57c2c48bffb7 --- /dev/null +++ b/facefusion/uis/components/common_options.py @@ -0,0 +1,32 @@ +from typing import List, Optional + +import gradio + +from facefusion import state_manager, wording +from facefusion.uis import choices as uis_choices + +COMMON_OPTIONS_CHECKBOX_GROUP : Optional[gradio.CheckboxGroup] = None + + +def render() -> None: + global COMMON_OPTIONS_CHECKBOX_GROUP + + common_options = [] + + if state_manager.get_item('keep_temp'): + common_options.append('keep-temp') + + COMMON_OPTIONS_CHECKBOX_GROUP = gradio.CheckboxGroup( + label = wording.get('uis.common_options_checkbox_group'), + choices = uis_choices.common_options, + value = common_options + ) + + +def listen() -> None: + COMMON_OPTIONS_CHECKBOX_GROUP.change(update, inputs = COMMON_OPTIONS_CHECKBOX_GROUP) + + +def update(common_options : List[str]) -> None: + keep_temp = 'keep-temp' in common_options + state_manager.set_item('keep_temp', keep_temp) diff --git a/facefusion/uis/components/deep_swapper_options.py b/facefusion/uis/components/deep_swapper_options.py new file mode 100644 index 0000000000000000000000000000000000000000..210193d5ac9894cf19dcef6e21f8e9009869516c --- /dev/null +++ b/facefusion/uis/components/deep_swapper_options.py @@ -0,0 +1,64 @@ +from typing import List, Optional, Tuple + +import gradio + +from facefusion import state_manager, wording +from facefusion.common_helper import calc_int_step +from facefusion.processors import choices as processors_choices +from facefusion.processors.core import load_processor_module +from facefusion.processors.types import DeepSwapperModel +from facefusion.uis.core import get_ui_component, register_ui_component + +DEEP_SWAPPER_MODEL_DROPDOWN : Optional[gradio.Dropdown] = None +DEEP_SWAPPER_MORPH_SLIDER : Optional[gradio.Slider] = None + + +def render() -> None: + global DEEP_SWAPPER_MODEL_DROPDOWN + global DEEP_SWAPPER_MORPH_SLIDER + + has_deep_swapper = 'deep_swapper' in state_manager.get_item('processors') + DEEP_SWAPPER_MODEL_DROPDOWN = gradio.Dropdown( + label = wording.get('uis.deep_swapper_model_dropdown'), + choices = processors_choices.deep_swapper_models, + value = state_manager.get_item('deep_swapper_model'), + visible = has_deep_swapper + ) + DEEP_SWAPPER_MORPH_SLIDER = gradio.Slider( + label = wording.get('uis.deep_swapper_morph_slider'), + value = state_manager.get_item('deep_swapper_morph'), + step = 
calc_int_step(processors_choices.deep_swapper_morph_range), + minimum = processors_choices.deep_swapper_morph_range[0], + maximum = processors_choices.deep_swapper_morph_range[-1], + visible = has_deep_swapper and load_processor_module('deep_swapper').get_inference_pool() and load_processor_module('deep_swapper').has_morph_input() + ) + register_ui_component('deep_swapper_model_dropdown', DEEP_SWAPPER_MODEL_DROPDOWN) + register_ui_component('deep_swapper_morph_slider', DEEP_SWAPPER_MORPH_SLIDER) + + +def listen() -> None: + DEEP_SWAPPER_MODEL_DROPDOWN.change(update_deep_swapper_model, inputs = DEEP_SWAPPER_MODEL_DROPDOWN, outputs = [ DEEP_SWAPPER_MODEL_DROPDOWN, DEEP_SWAPPER_MORPH_SLIDER ]) + DEEP_SWAPPER_MORPH_SLIDER.release(update_deep_swapper_morph, inputs = DEEP_SWAPPER_MORPH_SLIDER) + + processors_checkbox_group = get_ui_component('processors_checkbox_group') + if processors_checkbox_group: + processors_checkbox_group.change(remote_update, inputs = processors_checkbox_group, outputs = [ DEEP_SWAPPER_MODEL_DROPDOWN, DEEP_SWAPPER_MORPH_SLIDER ]) + + +def remote_update(processors : List[str]) -> Tuple[gradio.Dropdown, gradio.Slider]: + has_deep_swapper = 'deep_swapper' in processors + return gradio.Dropdown(visible = has_deep_swapper), gradio.Slider(visible = has_deep_swapper and load_processor_module('deep_swapper').get_inference_pool() and load_processor_module('deep_swapper').has_morph_input()) + + +def update_deep_swapper_model(deep_swapper_model : DeepSwapperModel) -> Tuple[gradio.Dropdown, gradio.Slider]: + deep_swapper_module = load_processor_module('deep_swapper') + deep_swapper_module.clear_inference_pool() + state_manager.set_item('deep_swapper_model', deep_swapper_model) + + if deep_swapper_module.pre_check(): + return gradio.Dropdown(value = state_manager.get_item('deep_swapper_model')), gradio.Slider(visible = deep_swapper_module.has_morph_input()) + return gradio.Dropdown(), gradio.Slider() + + +def update_deep_swapper_morph(deep_swapper_morph : int) -> None: + state_manager.set_item('deep_swapper_morph', deep_swapper_morph) diff --git a/facefusion/uis/components/download.py b/facefusion/uis/components/download.py new file mode 100644 index 0000000000000000000000000000000000000000..547e2ba39a8c09ec861de32e3ab14fa879be55c1 --- /dev/null +++ b/facefusion/uis/components/download.py @@ -0,0 +1,48 @@ +from typing import List, Optional + +import gradio + +import facefusion.choices +from facefusion import content_analyser, face_classifier, face_detector, face_landmarker, face_masker, face_recognizer, state_manager, voice_extractor, wording +from facefusion.filesystem import get_file_name, resolve_file_paths +from facefusion.processors.core import get_processors_modules +from facefusion.types import DownloadProvider + +DOWNLOAD_PROVIDERS_CHECKBOX_GROUP : Optional[gradio.CheckboxGroup] = None + + +def render() -> None: + global DOWNLOAD_PROVIDERS_CHECKBOX_GROUP + + DOWNLOAD_PROVIDERS_CHECKBOX_GROUP = gradio.CheckboxGroup( + label = wording.get('uis.download_providers_checkbox_group'), + choices = facefusion.choices.download_providers, + value = state_manager.get_item('download_providers') + ) + + +def listen() -> None: + DOWNLOAD_PROVIDERS_CHECKBOX_GROUP.change(update_download_providers, inputs = DOWNLOAD_PROVIDERS_CHECKBOX_GROUP, outputs = DOWNLOAD_PROVIDERS_CHECKBOX_GROUP) + + +def update_download_providers(download_providers : List[DownloadProvider]) -> gradio.CheckboxGroup: + common_modules =\ + [ + content_analyser, + face_classifier, + face_detector, + face_landmarker, + 
face_recognizer, + face_masker, + voice_extractor + ] + available_processors = [ get_file_name(file_path) for file_path in resolve_file_paths('facefusion/processors/modules') ] + processor_modules = get_processors_modules(available_processors) + + for module in common_modules + processor_modules: + if hasattr(module, 'create_static_model_set'): + module.create_static_model_set.cache_clear() + + download_providers = download_providers or facefusion.choices.download_providers + state_manager.set_item('download_providers', download_providers) + return gradio.CheckboxGroup(value = state_manager.get_item('download_providers')) diff --git a/facefusion/uis/components/execution.py b/facefusion/uis/components/execution.py new file mode 100644 index 0000000000000000000000000000000000000000..4be6eafcce6e8fb42e8b4700f49f1b59c2ad0ab8 --- /dev/null +++ b/facefusion/uis/components/execution.py @@ -0,0 +1,48 @@ +from typing import List, Optional + +import gradio + +from facefusion import content_analyser, face_classifier, face_detector, face_landmarker, face_masker, face_recognizer, state_manager, voice_extractor, wording +from facefusion.execution import get_available_execution_providers +from facefusion.filesystem import get_file_name, resolve_file_paths +from facefusion.processors.core import get_processors_modules +from facefusion.types import ExecutionProvider + +EXECUTION_PROVIDERS_CHECKBOX_GROUP : Optional[gradio.CheckboxGroup] = None + + +def render() -> None: + global EXECUTION_PROVIDERS_CHECKBOX_GROUP + + EXECUTION_PROVIDERS_CHECKBOX_GROUP = gradio.CheckboxGroup( + label = wording.get('uis.execution_providers_checkbox_group'), + choices = get_available_execution_providers(), + value = state_manager.get_item('execution_providers') + ) + + +def listen() -> None: + EXECUTION_PROVIDERS_CHECKBOX_GROUP.change(update_execution_providers, inputs = EXECUTION_PROVIDERS_CHECKBOX_GROUP, outputs = EXECUTION_PROVIDERS_CHECKBOX_GROUP) + + +def update_execution_providers(execution_providers : List[ExecutionProvider]) -> gradio.CheckboxGroup: + common_modules =\ + [ + content_analyser, + face_classifier, + face_detector, + face_landmarker, + face_masker, + face_recognizer, + voice_extractor + ] + available_processors = [ get_file_name(file_path) for file_path in resolve_file_paths('facefusion/processors/modules') ] + processor_modules = get_processors_modules(available_processors) + + for module in common_modules + processor_modules: + if hasattr(module, 'clear_inference_pool'): + module.clear_inference_pool() + + execution_providers = execution_providers or get_available_execution_providers() + state_manager.set_item('execution_providers', execution_providers) + return gradio.CheckboxGroup(value = state_manager.get_item('execution_providers')) diff --git a/facefusion/uis/components/execution_queue_count.py b/facefusion/uis/components/execution_queue_count.py new file mode 100644 index 0000000000000000000000000000000000000000..b5ab5dadf998c7d98d88cf6148b3a0e63c09b977 --- /dev/null +++ b/facefusion/uis/components/execution_queue_count.py @@ -0,0 +1,29 @@ +from typing import Optional + +import gradio + +import facefusion.choices +from facefusion import state_manager, wording +from facefusion.common_helper import calc_int_step + +EXECUTION_QUEUE_COUNT_SLIDER : Optional[gradio.Slider] = None + + +def render() -> None: + global EXECUTION_QUEUE_COUNT_SLIDER + + EXECUTION_QUEUE_COUNT_SLIDER = gradio.Slider( + label = wording.get('uis.execution_queue_count_slider'), + value = 
state_manager.get_item('execution_queue_count'), + step = calc_int_step(facefusion.choices.execution_queue_count_range), + minimum = facefusion.choices.execution_queue_count_range[0], + maximum = facefusion.choices.execution_queue_count_range[-1] + ) + + +def listen() -> None: + EXECUTION_QUEUE_COUNT_SLIDER.release(update_execution_queue_count, inputs = EXECUTION_QUEUE_COUNT_SLIDER) + + +def update_execution_queue_count(execution_queue_count : float) -> None: + state_manager.set_item('execution_queue_count', int(execution_queue_count)) diff --git a/facefusion/uis/components/execution_thread_count.py b/facefusion/uis/components/execution_thread_count.py new file mode 100644 index 0000000000000000000000000000000000000000..f5716a99f55a90730e3b4faefc7b6e2a4783ba89 --- /dev/null +++ b/facefusion/uis/components/execution_thread_count.py @@ -0,0 +1,29 @@ +from typing import Optional + +import gradio + +import facefusion.choices +from facefusion import state_manager, wording +from facefusion.common_helper import calc_int_step + +EXECUTION_THREAD_COUNT_SLIDER : Optional[gradio.Slider] = None + + +def render() -> None: + global EXECUTION_THREAD_COUNT_SLIDER + + EXECUTION_THREAD_COUNT_SLIDER = gradio.Slider( + label = wording.get('uis.execution_thread_count_slider'), + value = state_manager.get_item('execution_thread_count'), + step = calc_int_step(facefusion.choices.execution_thread_count_range), + minimum = facefusion.choices.execution_thread_count_range[0], + maximum = facefusion.choices.execution_thread_count_range[-1] + ) + + +def listen() -> None: + EXECUTION_THREAD_COUNT_SLIDER.release(update_execution_thread_count, inputs = EXECUTION_THREAD_COUNT_SLIDER) + + +def update_execution_thread_count(execution_thread_count : float) -> None: + state_manager.set_item('execution_thread_count', int(execution_thread_count)) diff --git a/facefusion/uis/components/expression_restorer_options.py b/facefusion/uis/components/expression_restorer_options.py new file mode 100644 index 0000000000000000000000000000000000000000..bf5eec4e29fcddb4287aa416170d5eee43bbabb7 --- /dev/null +++ b/facefusion/uis/components/expression_restorer_options.py @@ -0,0 +1,64 @@ +from typing import List, Optional, Tuple + +import gradio + +from facefusion import state_manager, wording +from facefusion.common_helper import calc_float_step +from facefusion.processors import choices as processors_choices +from facefusion.processors.core import load_processor_module +from facefusion.processors.types import ExpressionRestorerModel +from facefusion.uis.core import get_ui_component, register_ui_component + +EXPRESSION_RESTORER_MODEL_DROPDOWN : Optional[gradio.Dropdown] = None +EXPRESSION_RESTORER_FACTOR_SLIDER : Optional[gradio.Slider] = None + + +def render() -> None: + global EXPRESSION_RESTORER_MODEL_DROPDOWN + global EXPRESSION_RESTORER_FACTOR_SLIDER + + has_expression_restorer = 'expression_restorer' in state_manager.get_item('processors') + EXPRESSION_RESTORER_MODEL_DROPDOWN = gradio.Dropdown( + label = wording.get('uis.expression_restorer_model_dropdown'), + choices = processors_choices.expression_restorer_models, + value = state_manager.get_item('expression_restorer_model'), + visible = has_expression_restorer + ) + EXPRESSION_RESTORER_FACTOR_SLIDER = gradio.Slider( + label = wording.get('uis.expression_restorer_factor_slider'), + value = state_manager.get_item('expression_restorer_factor'), + step = calc_float_step(processors_choices.expression_restorer_factor_range), + minimum = 
processors_choices.expression_restorer_factor_range[0], + maximum = processors_choices.expression_restorer_factor_range[-1], + visible = has_expression_restorer + ) + register_ui_component('expression_restorer_model_dropdown', EXPRESSION_RESTORER_MODEL_DROPDOWN) + register_ui_component('expression_restorer_factor_slider', EXPRESSION_RESTORER_FACTOR_SLIDER) + + +def listen() -> None: + EXPRESSION_RESTORER_MODEL_DROPDOWN.change(update_expression_restorer_model, inputs = EXPRESSION_RESTORER_MODEL_DROPDOWN, outputs = EXPRESSION_RESTORER_MODEL_DROPDOWN) + EXPRESSION_RESTORER_FACTOR_SLIDER.release(update_expression_restorer_factor, inputs = EXPRESSION_RESTORER_FACTOR_SLIDER) + + processors_checkbox_group = get_ui_component('processors_checkbox_group') + if processors_checkbox_group: + processors_checkbox_group.change(remote_update, inputs = processors_checkbox_group, outputs = [ EXPRESSION_RESTORER_MODEL_DROPDOWN, EXPRESSION_RESTORER_FACTOR_SLIDER ]) + + +def remote_update(processors : List[str]) -> Tuple[gradio.Dropdown, gradio.Slider]: + has_expression_restorer = 'expression_restorer' in processors + return gradio.Dropdown(visible = has_expression_restorer), gradio.Slider(visible = has_expression_restorer) + + +def update_expression_restorer_model(expression_restorer_model : ExpressionRestorerModel) -> gradio.Dropdown: + expression_restorer_module = load_processor_module('expression_restorer') + expression_restorer_module.clear_inference_pool() + state_manager.set_item('expression_restorer_model', expression_restorer_model) + + if expression_restorer_module.pre_check(): + return gradio.Dropdown(value = state_manager.get_item('expression_restorer_model')) + return gradio.Dropdown() + + +def update_expression_restorer_factor(expression_restorer_factor : float) -> None: + state_manager.set_item('expression_restorer_factor', int(expression_restorer_factor)) diff --git a/facefusion/uis/components/face_debugger_options.py b/facefusion/uis/components/face_debugger_options.py new file mode 100644 index 0000000000000000000000000000000000000000..032eb24ea01fc5e25add699cbccb69c6798c5a47 --- /dev/null +++ b/facefusion/uis/components/face_debugger_options.py @@ -0,0 +1,40 @@ +from typing import List, Optional + +import gradio + +from facefusion import state_manager, wording +from facefusion.processors import choices as processors_choices +from facefusion.processors.types import FaceDebuggerItem +from facefusion.uis.core import get_ui_component, register_ui_component + +FACE_DEBUGGER_ITEMS_CHECKBOX_GROUP : Optional[gradio.CheckboxGroup] = None + + +def render() -> None: + global FACE_DEBUGGER_ITEMS_CHECKBOX_GROUP + + has_face_debugger = 'face_debugger' in state_manager.get_item('processors') + FACE_DEBUGGER_ITEMS_CHECKBOX_GROUP = gradio.CheckboxGroup( + label = wording.get('uis.face_debugger_items_checkbox_group'), + choices = processors_choices.face_debugger_items, + value = state_manager.get_item('face_debugger_items'), + visible = has_face_debugger + ) + register_ui_component('face_debugger_items_checkbox_group', FACE_DEBUGGER_ITEMS_CHECKBOX_GROUP) + + +def listen() -> None: + FACE_DEBUGGER_ITEMS_CHECKBOX_GROUP.change(update_face_debugger_items, inputs = FACE_DEBUGGER_ITEMS_CHECKBOX_GROUP) + + processors_checkbox_group = get_ui_component('processors_checkbox_group') + if processors_checkbox_group: + processors_checkbox_group.change(remote_update, inputs = processors_checkbox_group, outputs = FACE_DEBUGGER_ITEMS_CHECKBOX_GROUP) + + +def remote_update(processors : List[str]) -> gradio.CheckboxGroup: + 
has_face_debugger = 'face_debugger' in processors + return gradio.CheckboxGroup(visible = has_face_debugger) + + +def update_face_debugger_items(face_debugger_items : List[FaceDebuggerItem]) -> None: + state_manager.set_item('face_debugger_items', face_debugger_items) diff --git a/facefusion/uis/components/face_detector.py b/facefusion/uis/components/face_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..756154d9e6ce9303f12ad2a9d31a5d581785d564 --- /dev/null +++ b/facefusion/uis/components/face_detector.py @@ -0,0 +1,85 @@ +from typing import Optional, Sequence, Tuple + +import gradio + +import facefusion.choices +from facefusion import face_detector, state_manager, wording +from facefusion.common_helper import calc_float_step, get_last +from facefusion.types import Angle, FaceDetectorModel, Score +from facefusion.uis.core import register_ui_component +from facefusion.uis.types import ComponentOptions + +FACE_DETECTOR_MODEL_DROPDOWN : Optional[gradio.Dropdown] = None +FACE_DETECTOR_SIZE_DROPDOWN : Optional[gradio.Dropdown] = None +FACE_DETECTOR_ANGLES_CHECKBOX_GROUP : Optional[gradio.CheckboxGroup] = None +FACE_DETECTOR_SCORE_SLIDER : Optional[gradio.Slider] = None + + +def render() -> None: + global FACE_DETECTOR_MODEL_DROPDOWN + global FACE_DETECTOR_SIZE_DROPDOWN + global FACE_DETECTOR_ANGLES_CHECKBOX_GROUP + global FACE_DETECTOR_SCORE_SLIDER + + face_detector_size_dropdown_options : ComponentOptions =\ + { + 'label': wording.get('uis.face_detector_size_dropdown'), + 'value': state_manager.get_item('face_detector_size') + } + if state_manager.get_item('face_detector_size') in facefusion.choices.face_detector_set[state_manager.get_item('face_detector_model')]: + face_detector_size_dropdown_options['choices'] = facefusion.choices.face_detector_set[state_manager.get_item('face_detector_model')] + with gradio.Row(): + FACE_DETECTOR_MODEL_DROPDOWN = gradio.Dropdown( + label = wording.get('uis.face_detector_model_dropdown'), + choices = facefusion.choices.face_detector_models, + value = state_manager.get_item('face_detector_model') + ) + FACE_DETECTOR_SIZE_DROPDOWN = gradio.Dropdown(**face_detector_size_dropdown_options) + FACE_DETECTOR_ANGLES_CHECKBOX_GROUP = gradio.CheckboxGroup( + label = wording.get('uis.face_detector_angles_checkbox_group'), + choices = facefusion.choices.face_detector_angles, + value = state_manager.get_item('face_detector_angles') + ) + FACE_DETECTOR_SCORE_SLIDER = gradio.Slider( + label = wording.get('uis.face_detector_score_slider'), + value = state_manager.get_item('face_detector_score'), + step = calc_float_step(facefusion.choices.face_detector_score_range), + minimum = facefusion.choices.face_detector_score_range[0], + maximum = facefusion.choices.face_detector_score_range[-1] + ) + register_ui_component('face_detector_model_dropdown', FACE_DETECTOR_MODEL_DROPDOWN) + register_ui_component('face_detector_size_dropdown', FACE_DETECTOR_SIZE_DROPDOWN) + register_ui_component('face_detector_angles_checkbox_group', FACE_DETECTOR_ANGLES_CHECKBOX_GROUP) + register_ui_component('face_detector_score_slider', FACE_DETECTOR_SCORE_SLIDER) + + +def listen() -> None: + FACE_DETECTOR_MODEL_DROPDOWN.change(update_face_detector_model, inputs = FACE_DETECTOR_MODEL_DROPDOWN, outputs = [ FACE_DETECTOR_MODEL_DROPDOWN, FACE_DETECTOR_SIZE_DROPDOWN ]) + FACE_DETECTOR_SIZE_DROPDOWN.change(update_face_detector_size, inputs = FACE_DETECTOR_SIZE_DROPDOWN) + FACE_DETECTOR_ANGLES_CHECKBOX_GROUP.change(update_face_detector_angles, inputs = 
FACE_DETECTOR_ANGLES_CHECKBOX_GROUP, outputs = FACE_DETECTOR_ANGLES_CHECKBOX_GROUP) + FACE_DETECTOR_SCORE_SLIDER.release(update_face_detector_score, inputs = FACE_DETECTOR_SCORE_SLIDER) + + +def update_face_detector_model(face_detector_model : FaceDetectorModel) -> Tuple[gradio.Dropdown, gradio.Dropdown]: + face_detector.clear_inference_pool() + state_manager.set_item('face_detector_model', face_detector_model) + + if face_detector.pre_check(): + face_detector_size_choices = facefusion.choices.face_detector_set.get(state_manager.get_item('face_detector_model')) + state_manager.set_item('face_detector_size', get_last(face_detector_size_choices)) + return gradio.Dropdown(value = state_manager.get_item('face_detector_model')), gradio.Dropdown(value = state_manager.get_item('face_detector_size'), choices = face_detector_size_choices) + return gradio.Dropdown(), gradio.Dropdown() + + +def update_face_detector_size(face_detector_size : str) -> None: + state_manager.set_item('face_detector_size', face_detector_size) + + +def update_face_detector_angles(face_detector_angles : Sequence[Angle]) -> gradio.CheckboxGroup: + face_detector_angles = face_detector_angles or facefusion.choices.face_detector_angles + state_manager.set_item('face_detector_angles', face_detector_angles) + return gradio.CheckboxGroup(value = state_manager.get_item('face_detector_angles')) + + +def update_face_detector_score(face_detector_score : Score) -> None: + state_manager.set_item('face_detector_score', face_detector_score) diff --git a/facefusion/uis/components/face_editor_options.py b/facefusion/uis/components/face_editor_options.py new file mode 100644 index 0000000000000000000000000000000000000000..978b12d81f3d14f1fbc716c67b382fba8c17b935 --- /dev/null +++ b/facefusion/uis/components/face_editor_options.py @@ -0,0 +1,272 @@ +from typing import List, Optional, Tuple + +import gradio + +from facefusion import state_manager, wording +from facefusion.common_helper import calc_float_step +from facefusion.processors import choices as processors_choices +from facefusion.processors.core import load_processor_module +from facefusion.processors.types import FaceEditorModel +from facefusion.uis.core import get_ui_component, register_ui_component + +FACE_EDITOR_MODEL_DROPDOWN : Optional[gradio.Dropdown] = None +FACE_EDITOR_EYEBROW_DIRECTION_SLIDER : Optional[gradio.Slider] = None +FACE_EDITOR_EYE_GAZE_HORIZONTAL_SLIDER : Optional[gradio.Slider] = None +FACE_EDITOR_EYE_GAZE_VERTICAL_SLIDER : Optional[gradio.Slider] = None +FACE_EDITOR_EYE_OPEN_RATIO_SLIDER : Optional[gradio.Slider] = None +FACE_EDITOR_LIP_OPEN_RATIO_SLIDER : Optional[gradio.Slider] = None +FACE_EDITOR_MOUTH_GRIM_SLIDER : Optional[gradio.Slider] = None +FACE_EDITOR_MOUTH_POUT_SLIDER : Optional[gradio.Slider] = None +FACE_EDITOR_MOUTH_PURSE_SLIDER : Optional[gradio.Slider] = None +FACE_EDITOR_MOUTH_SMILE_SLIDER : Optional[gradio.Slider] = None +FACE_EDITOR_MOUTH_POSITION_HORIZONTAL_SLIDER : Optional[gradio.Slider] = None +FACE_EDITOR_MOUTH_POSITION_VERTICAL_SLIDER : Optional[gradio.Slider] = None +FACE_EDITOR_HEAD_PITCH_SLIDER : Optional[gradio.Slider] = None +FACE_EDITOR_HEAD_YAW_SLIDER : Optional[gradio.Slider] = None +FACE_EDITOR_HEAD_ROLL_SLIDER : Optional[gradio.Slider] = None + + +def render() -> None: + global FACE_EDITOR_MODEL_DROPDOWN + global FACE_EDITOR_EYEBROW_DIRECTION_SLIDER + global FACE_EDITOR_EYE_GAZE_HORIZONTAL_SLIDER + global FACE_EDITOR_EYE_GAZE_VERTICAL_SLIDER + global FACE_EDITOR_EYE_OPEN_RATIO_SLIDER + global FACE_EDITOR_LIP_OPEN_RATIO_SLIDER + 
global FACE_EDITOR_MOUTH_GRIM_SLIDER + global FACE_EDITOR_MOUTH_POUT_SLIDER + global FACE_EDITOR_MOUTH_PURSE_SLIDER + global FACE_EDITOR_MOUTH_SMILE_SLIDER + global FACE_EDITOR_MOUTH_POSITION_HORIZONTAL_SLIDER + global FACE_EDITOR_MOUTH_POSITION_VERTICAL_SLIDER + global FACE_EDITOR_HEAD_PITCH_SLIDER + global FACE_EDITOR_HEAD_YAW_SLIDER + global FACE_EDITOR_HEAD_ROLL_SLIDER + + has_face_editor = 'face_editor' in state_manager.get_item('processors') + FACE_EDITOR_MODEL_DROPDOWN = gradio.Dropdown( + label = wording.get('uis.face_editor_model_dropdown'), + choices = processors_choices.face_editor_models, + value = state_manager.get_item('face_editor_model'), + visible = has_face_editor + ) + FACE_EDITOR_EYEBROW_DIRECTION_SLIDER = gradio.Slider( + label = wording.get('uis.face_editor_eyebrow_direction_slider'), + value = state_manager.get_item('face_editor_eyebrow_direction'), + step = calc_float_step(processors_choices.face_editor_eyebrow_direction_range), + minimum = processors_choices.face_editor_eyebrow_direction_range[0], + maximum = processors_choices.face_editor_eyebrow_direction_range[-1], + visible = has_face_editor + ) + FACE_EDITOR_EYE_GAZE_HORIZONTAL_SLIDER = gradio.Slider( + label = wording.get('uis.face_editor_eye_gaze_horizontal_slider'), + value = state_manager.get_item('face_editor_eye_gaze_horizontal'), + step = calc_float_step(processors_choices.face_editor_eye_gaze_horizontal_range), + minimum = processors_choices.face_editor_eye_gaze_horizontal_range[0], + maximum = processors_choices.face_editor_eye_gaze_horizontal_range[-1], + visible = has_face_editor + ) + FACE_EDITOR_EYE_GAZE_VERTICAL_SLIDER = gradio.Slider( + label = wording.get('uis.face_editor_eye_gaze_vertical_slider'), + value = state_manager.get_item('face_editor_eye_gaze_vertical'), + step = calc_float_step(processors_choices.face_editor_eye_gaze_vertical_range), + minimum = processors_choices.face_editor_eye_gaze_vertical_range[0], + maximum = processors_choices.face_editor_eye_gaze_vertical_range[-1], + visible = has_face_editor + ) + FACE_EDITOR_EYE_OPEN_RATIO_SLIDER = gradio.Slider( + label = wording.get('uis.face_editor_eye_open_ratio_slider'), + value = state_manager.get_item('face_editor_eye_open_ratio'), + step = calc_float_step(processors_choices.face_editor_eye_open_ratio_range), + minimum = processors_choices.face_editor_eye_open_ratio_range[0], + maximum = processors_choices.face_editor_eye_open_ratio_range[-1], + visible = has_face_editor + ) + FACE_EDITOR_LIP_OPEN_RATIO_SLIDER = gradio.Slider( + label = wording.get('uis.face_editor_lip_open_ratio_slider'), + value = state_manager.get_item('face_editor_lip_open_ratio'), + step = calc_float_step(processors_choices.face_editor_lip_open_ratio_range), + minimum = processors_choices.face_editor_lip_open_ratio_range[0], + maximum = processors_choices.face_editor_lip_open_ratio_range[-1], + visible = has_face_editor + ) + FACE_EDITOR_MOUTH_GRIM_SLIDER = gradio.Slider( + label = wording.get('uis.face_editor_mouth_grim_slider'), + value = state_manager.get_item('face_editor_mouth_grim'), + step = calc_float_step(processors_choices.face_editor_mouth_grim_range), + minimum = processors_choices.face_editor_mouth_grim_range[0], + maximum = processors_choices.face_editor_mouth_grim_range[-1], + visible = has_face_editor + ) + FACE_EDITOR_MOUTH_POUT_SLIDER = gradio.Slider( + label = wording.get('uis.face_editor_mouth_pout_slider'), + value = state_manager.get_item('face_editor_mouth_pout'), + step = calc_float_step(processors_choices.face_editor_mouth_pout_range), 
+ minimum = processors_choices.face_editor_mouth_pout_range[0], + maximum = processors_choices.face_editor_mouth_pout_range[-1], + visible = has_face_editor + ) + FACE_EDITOR_MOUTH_PURSE_SLIDER = gradio.Slider( + label = wording.get('uis.face_editor_mouth_purse_slider'), + value = state_manager.get_item('face_editor_mouth_purse'), + step = calc_float_step(processors_choices.face_editor_mouth_purse_range), + minimum = processors_choices.face_editor_mouth_purse_range[0], + maximum = processors_choices.face_editor_mouth_purse_range[-1], + visible = has_face_editor + ) + FACE_EDITOR_MOUTH_SMILE_SLIDER = gradio.Slider( + label = wording.get('uis.face_editor_mouth_smile_slider'), + value = state_manager.get_item('face_editor_mouth_smile'), + step = calc_float_step(processors_choices.face_editor_mouth_smile_range), + minimum = processors_choices.face_editor_mouth_smile_range[0], + maximum = processors_choices.face_editor_mouth_smile_range[-1], + visible = has_face_editor + ) + FACE_EDITOR_MOUTH_POSITION_HORIZONTAL_SLIDER = gradio.Slider( + label = wording.get('uis.face_editor_mouth_position_horizontal_slider'), + value = state_manager.get_item('face_editor_mouth_position_horizontal'), + step = calc_float_step(processors_choices.face_editor_mouth_position_horizontal_range), + minimum = processors_choices.face_editor_mouth_position_horizontal_range[0], + maximum = processors_choices.face_editor_mouth_position_horizontal_range[-1], + visible = has_face_editor + ) + FACE_EDITOR_MOUTH_POSITION_VERTICAL_SLIDER = gradio.Slider( + label = wording.get('uis.face_editor_mouth_position_vertical_slider'), + value = state_manager.get_item('face_editor_mouth_position_vertical'), + step = calc_float_step(processors_choices.face_editor_mouth_position_vertical_range), + minimum = processors_choices.face_editor_mouth_position_vertical_range[0], + maximum = processors_choices.face_editor_mouth_position_vertical_range[-1], + visible = has_face_editor + ) + FACE_EDITOR_HEAD_PITCH_SLIDER = gradio.Slider( + label = wording.get('uis.face_editor_head_pitch_slider'), + value = state_manager.get_item('face_editor_head_pitch'), + step = calc_float_step(processors_choices.face_editor_head_pitch_range), + minimum = processors_choices.face_editor_head_pitch_range[0], + maximum = processors_choices.face_editor_head_pitch_range[-1], + visible = has_face_editor + ) + FACE_EDITOR_HEAD_YAW_SLIDER = gradio.Slider( + label = wording.get('uis.face_editor_head_yaw_slider'), + value = state_manager.get_item('face_editor_head_yaw'), + step = calc_float_step(processors_choices.face_editor_head_yaw_range), + minimum = processors_choices.face_editor_head_yaw_range[0], + maximum = processors_choices.face_editor_head_yaw_range[-1], + visible = has_face_editor + ) + FACE_EDITOR_HEAD_ROLL_SLIDER = gradio.Slider( + label = wording.get('uis.face_editor_head_roll_slider'), + value = state_manager.get_item('face_editor_head_roll'), + step = calc_float_step(processors_choices.face_editor_head_roll_range), + minimum = processors_choices.face_editor_head_roll_range[0], + maximum = processors_choices.face_editor_head_roll_range[-1], + visible = has_face_editor + ) + register_ui_component('face_editor_model_dropdown', FACE_EDITOR_MODEL_DROPDOWN) + register_ui_component('face_editor_eyebrow_direction_slider', FACE_EDITOR_EYEBROW_DIRECTION_SLIDER) + register_ui_component('face_editor_eye_gaze_horizontal_slider', FACE_EDITOR_EYE_GAZE_HORIZONTAL_SLIDER) + register_ui_component('face_editor_eye_gaze_vertical_slider', FACE_EDITOR_EYE_GAZE_VERTICAL_SLIDER) + 
register_ui_component('face_editor_eye_open_ratio_slider', FACE_EDITOR_EYE_OPEN_RATIO_SLIDER) + register_ui_component('face_editor_lip_open_ratio_slider', FACE_EDITOR_LIP_OPEN_RATIO_SLIDER) + register_ui_component('face_editor_mouth_grim_slider', FACE_EDITOR_MOUTH_GRIM_SLIDER) + register_ui_component('face_editor_mouth_pout_slider', FACE_EDITOR_MOUTH_POUT_SLIDER) + register_ui_component('face_editor_mouth_purse_slider', FACE_EDITOR_MOUTH_PURSE_SLIDER) + register_ui_component('face_editor_mouth_smile_slider', FACE_EDITOR_MOUTH_SMILE_SLIDER) + register_ui_component('face_editor_mouth_position_horizontal_slider', FACE_EDITOR_MOUTH_POSITION_HORIZONTAL_SLIDER) + register_ui_component('face_editor_mouth_position_vertical_slider', FACE_EDITOR_MOUTH_POSITION_VERTICAL_SLIDER) + register_ui_component('face_editor_head_pitch_slider', FACE_EDITOR_HEAD_PITCH_SLIDER) + register_ui_component('face_editor_head_yaw_slider', FACE_EDITOR_HEAD_YAW_SLIDER) + register_ui_component('face_editor_head_roll_slider', FACE_EDITOR_HEAD_ROLL_SLIDER) + + +def listen() -> None: + FACE_EDITOR_MODEL_DROPDOWN.change(update_face_editor_model, inputs = FACE_EDITOR_MODEL_DROPDOWN, outputs = FACE_EDITOR_MODEL_DROPDOWN) + FACE_EDITOR_EYEBROW_DIRECTION_SLIDER.release(update_face_editor_eyebrow_direction, inputs = FACE_EDITOR_EYEBROW_DIRECTION_SLIDER) + FACE_EDITOR_EYE_GAZE_HORIZONTAL_SLIDER.release(update_face_editor_eye_gaze_horizontal, inputs = FACE_EDITOR_EYE_GAZE_HORIZONTAL_SLIDER) + FACE_EDITOR_EYE_GAZE_VERTICAL_SLIDER.release(update_face_editor_eye_gaze_vertical, inputs = FACE_EDITOR_EYE_GAZE_VERTICAL_SLIDER) + FACE_EDITOR_EYE_OPEN_RATIO_SLIDER.release(update_face_editor_eye_open_ratio, inputs = FACE_EDITOR_EYE_OPEN_RATIO_SLIDER) + FACE_EDITOR_LIP_OPEN_RATIO_SLIDER.release(update_face_editor_lip_open_ratio, inputs = FACE_EDITOR_LIP_OPEN_RATIO_SLIDER) + FACE_EDITOR_MOUTH_GRIM_SLIDER.release(update_face_editor_mouth_grim, inputs = FACE_EDITOR_MOUTH_GRIM_SLIDER) + FACE_EDITOR_MOUTH_POUT_SLIDER.release(update_face_editor_mouth_pout, inputs = FACE_EDITOR_MOUTH_POUT_SLIDER) + FACE_EDITOR_MOUTH_PURSE_SLIDER.release(update_face_editor_mouth_purse, inputs = FACE_EDITOR_MOUTH_PURSE_SLIDER) + FACE_EDITOR_MOUTH_SMILE_SLIDER.release(update_face_editor_mouth_smile, inputs = FACE_EDITOR_MOUTH_SMILE_SLIDER) + FACE_EDITOR_MOUTH_POSITION_HORIZONTAL_SLIDER.release(update_face_editor_mouth_position_horizontal, inputs = FACE_EDITOR_MOUTH_POSITION_HORIZONTAL_SLIDER) + FACE_EDITOR_MOUTH_POSITION_VERTICAL_SLIDER.release(update_face_editor_mouth_position_vertical, inputs = FACE_EDITOR_MOUTH_POSITION_VERTICAL_SLIDER) + FACE_EDITOR_HEAD_PITCH_SLIDER.release(update_face_editor_head_pitch, inputs = FACE_EDITOR_HEAD_PITCH_SLIDER) + FACE_EDITOR_HEAD_YAW_SLIDER.release(update_face_editor_head_yaw, inputs = FACE_EDITOR_HEAD_YAW_SLIDER) + FACE_EDITOR_HEAD_ROLL_SLIDER.release(update_face_editor_head_roll, inputs = FACE_EDITOR_HEAD_ROLL_SLIDER) + + processors_checkbox_group = get_ui_component('processors_checkbox_group') + if processors_checkbox_group: + processors_checkbox_group.change(remote_update, inputs = processors_checkbox_group, outputs = [ FACE_EDITOR_MODEL_DROPDOWN, FACE_EDITOR_EYEBROW_DIRECTION_SLIDER, FACE_EDITOR_EYE_GAZE_HORIZONTAL_SLIDER, FACE_EDITOR_EYE_GAZE_VERTICAL_SLIDER, FACE_EDITOR_EYE_OPEN_RATIO_SLIDER, FACE_EDITOR_LIP_OPEN_RATIO_SLIDER, FACE_EDITOR_MOUTH_GRIM_SLIDER, FACE_EDITOR_MOUTH_POUT_SLIDER, FACE_EDITOR_MOUTH_PURSE_SLIDER, FACE_EDITOR_MOUTH_SMILE_SLIDER, FACE_EDITOR_MOUTH_POSITION_HORIZONTAL_SLIDER, 
FACE_EDITOR_MOUTH_POSITION_VERTICAL_SLIDER, FACE_EDITOR_HEAD_PITCH_SLIDER, FACE_EDITOR_HEAD_YAW_SLIDER, FACE_EDITOR_HEAD_ROLL_SLIDER ]) + + +def remote_update(processors : List[str]) -> Tuple[gradio.Dropdown, gradio.Slider, gradio.Slider, gradio.Slider, gradio.Slider, gradio.Slider, gradio.Slider, gradio.Slider, gradio.Slider, gradio.Slider, gradio.Slider, gradio.Slider, gradio.Slider, gradio.Slider, gradio.Slider]: + has_face_editor = 'face_editor' in processors + return gradio.Dropdown(visible = has_face_editor), gradio.Slider(visible = has_face_editor), gradio.Slider(visible = has_face_editor), gradio.Slider(visible = has_face_editor), gradio.Slider(visible = has_face_editor), gradio.Slider(visible = has_face_editor), gradio.Slider(visible = has_face_editor), gradio.Slider(visible = has_face_editor), gradio.Slider(visible = has_face_editor), gradio.Slider(visible = has_face_editor), gradio.Slider(visible = has_face_editor), gradio.Slider(visible = has_face_editor), gradio.Slider(visible = has_face_editor), gradio.Slider(visible = has_face_editor), gradio.Slider(visible = has_face_editor) + + +def update_face_editor_model(face_editor_model : FaceEditorModel) -> gradio.Dropdown: + face_editor_module = load_processor_module('face_editor') + face_editor_module.clear_inference_pool() + state_manager.set_item('face_editor_model', face_editor_model) + + if face_editor_module.pre_check(): + return gradio.Dropdown(value = state_manager.get_item('face_editor_model')) + return gradio.Dropdown() + + +def update_face_editor_eyebrow_direction(face_editor_eyebrow_direction : float) -> None: + state_manager.set_item('face_editor_eyebrow_direction', face_editor_eyebrow_direction) + + +def update_face_editor_eye_gaze_horizontal(face_editor_eye_gaze_horizontal : float) -> None: + state_manager.set_item('face_editor_eye_gaze_horizontal', face_editor_eye_gaze_horizontal) + + +def update_face_editor_eye_gaze_vertical(face_editor_eye_gaze_vertical : float) -> None: + state_manager.set_item('face_editor_eye_gaze_vertical', face_editor_eye_gaze_vertical) + + +def update_face_editor_eye_open_ratio(face_editor_eye_open_ratio : float) -> None: + state_manager.set_item('face_editor_eye_open_ratio', face_editor_eye_open_ratio) + + +def update_face_editor_lip_open_ratio(face_editor_lip_open_ratio : float) -> None: + state_manager.set_item('face_editor_lip_open_ratio', face_editor_lip_open_ratio) + + +def update_face_editor_mouth_grim(face_editor_mouth_grim : float) -> None: + state_manager.set_item('face_editor_mouth_grim', face_editor_mouth_grim) + + +def update_face_editor_mouth_pout(face_editor_mouth_pout : float) -> None: + state_manager.set_item('face_editor_mouth_pout', face_editor_mouth_pout) + + +def update_face_editor_mouth_purse(face_editor_mouth_purse : float) -> None: + state_manager.set_item('face_editor_mouth_purse', face_editor_mouth_purse) + + +def update_face_editor_mouth_smile(face_editor_mouth_smile : float) -> None: + state_manager.set_item('face_editor_mouth_smile', face_editor_mouth_smile) + + +def update_face_editor_mouth_position_horizontal(face_editor_mouth_position_horizontal : float) -> None: + state_manager.set_item('face_editor_mouth_position_horizontal', face_editor_mouth_position_horizontal) + + +def update_face_editor_mouth_position_vertical(face_editor_mouth_position_vertical : float) -> None: + state_manager.set_item('face_editor_mouth_position_vertical', face_editor_mouth_position_vertical) + + +def update_face_editor_head_pitch(face_editor_head_pitch : float) -> None: + 
state_manager.set_item('face_editor_head_pitch', face_editor_head_pitch) + + +def update_face_editor_head_yaw(face_editor_head_yaw : float) -> None: + state_manager.set_item('face_editor_head_yaw', face_editor_head_yaw) + + +def update_face_editor_head_roll(face_editor_head_roll : float) -> None: + state_manager.set_item('face_editor_head_roll', face_editor_head_roll) diff --git a/facefusion/uis/components/face_enhancer_options.py b/facefusion/uis/components/face_enhancer_options.py new file mode 100644 index 0000000000000000000000000000000000000000..0e02d865a79071f9abd9691dd47321cbaa92d9e6 --- /dev/null +++ b/facefusion/uis/components/face_enhancer_options.py @@ -0,0 +1,81 @@ +from typing import List, Optional, Tuple + +import gradio + +from facefusion import state_manager, wording +from facefusion.common_helper import calc_float_step, calc_int_step +from facefusion.processors import choices as processors_choices +from facefusion.processors.core import load_processor_module +from facefusion.processors.types import FaceEnhancerModel +from facefusion.uis.core import get_ui_component, register_ui_component + +FACE_ENHANCER_MODEL_DROPDOWN : Optional[gradio.Dropdown] = None +FACE_ENHANCER_BLEND_SLIDER : Optional[gradio.Slider] = None +FACE_ENHANCER_WEIGHT_SLIDER : Optional[gradio.Slider] = None + + +def render() -> None: + global FACE_ENHANCER_MODEL_DROPDOWN + global FACE_ENHANCER_BLEND_SLIDER + global FACE_ENHANCER_WEIGHT_SLIDER + + has_face_enhancer = 'face_enhancer' in state_manager.get_item('processors') + FACE_ENHANCER_MODEL_DROPDOWN = gradio.Dropdown( + label = wording.get('uis.face_enhancer_model_dropdown'), + choices = processors_choices.face_enhancer_models, + value = state_manager.get_item('face_enhancer_model'), + visible = has_face_enhancer + ) + FACE_ENHANCER_BLEND_SLIDER = gradio.Slider( + label = wording.get('uis.face_enhancer_blend_slider'), + value = state_manager.get_item('face_enhancer_blend'), + step = calc_int_step(processors_choices.face_enhancer_blend_range), + minimum = processors_choices.face_enhancer_blend_range[0], + maximum = processors_choices.face_enhancer_blend_range[-1], + visible = has_face_enhancer + ) + FACE_ENHANCER_WEIGHT_SLIDER = gradio.Slider( + label = wording.get('uis.face_enhancer_weight_slider'), + value = state_manager.get_item('face_enhancer_weight'), + step = calc_float_step(processors_choices.face_enhancer_weight_range), + minimum = processors_choices.face_enhancer_weight_range[0], + maximum = processors_choices.face_enhancer_weight_range[-1], + visible = has_face_enhancer and load_processor_module('face_enhancer').get_inference_pool() and load_processor_module('face_enhancer').has_weight_input() + ) + register_ui_component('face_enhancer_model_dropdown', FACE_ENHANCER_MODEL_DROPDOWN) + register_ui_component('face_enhancer_blend_slider', FACE_ENHANCER_BLEND_SLIDER) + register_ui_component('face_enhancer_weight_slider', FACE_ENHANCER_WEIGHT_SLIDER) + + +def listen() -> None: + FACE_ENHANCER_MODEL_DROPDOWN.change(update_face_enhancer_model, inputs = FACE_ENHANCER_MODEL_DROPDOWN, outputs = [ FACE_ENHANCER_MODEL_DROPDOWN, FACE_ENHANCER_WEIGHT_SLIDER ]) + FACE_ENHANCER_BLEND_SLIDER.release(update_face_enhancer_blend, inputs = FACE_ENHANCER_BLEND_SLIDER) + FACE_ENHANCER_WEIGHT_SLIDER.release(update_face_enhancer_weight, inputs = FACE_ENHANCER_WEIGHT_SLIDER) + + processors_checkbox_group = get_ui_component('processors_checkbox_group') + if processors_checkbox_group: + processors_checkbox_group.change(remote_update, inputs = processors_checkbox_group, 
outputs = [ FACE_ENHANCER_MODEL_DROPDOWN, FACE_ENHANCER_BLEND_SLIDER, FACE_ENHANCER_WEIGHT_SLIDER ]) + + +def remote_update(processors : List[str]) -> Tuple[gradio.Dropdown, gradio.Slider, gradio.Slider]: + has_face_enhancer = 'face_enhancer' in processors + return gradio.Dropdown(visible = has_face_enhancer), gradio.Slider(visible = has_face_enhancer), gradio.Slider(visible = has_face_enhancer and load_processor_module('face_enhancer').get_inference_pool() and load_processor_module('face_enhancer').has_weight_input()) + + +def update_face_enhancer_model(face_enhancer_model : FaceEnhancerModel) -> Tuple[gradio.Dropdown, gradio.Slider]: + face_enhancer_module = load_processor_module('face_enhancer') + face_enhancer_module.clear_inference_pool() + state_manager.set_item('face_enhancer_model', face_enhancer_model) + + if face_enhancer_module.pre_check(): + return gradio.Dropdown(value = state_manager.get_item('face_enhancer_model')), gradio.Slider(visible = face_enhancer_module.has_weight_input()) + return gradio.Dropdown(), gradio.Slider() + + +def update_face_enhancer_blend(face_enhancer_blend : float) -> None: + state_manager.set_item('face_enhancer_blend', int(face_enhancer_blend)) + + +def update_face_enhancer_weight(face_enhancer_weight : float) -> None: + state_manager.set_item('face_enhancer_weight', face_enhancer_weight) + diff --git a/facefusion/uis/components/face_landmarker.py b/facefusion/uis/components/face_landmarker.py new file mode 100644 index 0000000000000000000000000000000000000000..7fab429db32abd82735f229d67ca87a9442f87b1 --- /dev/null +++ b/facefusion/uis/components/face_landmarker.py @@ -0,0 +1,50 @@ +from typing import Optional + +import gradio + +import facefusion.choices +from facefusion import face_landmarker, state_manager, wording +from facefusion.common_helper import calc_float_step +from facefusion.types import FaceLandmarkerModel, Score +from facefusion.uis.core import register_ui_component + +FACE_LANDMARKER_MODEL_DROPDOWN : Optional[gradio.Dropdown] = None +FACE_LANDMARKER_SCORE_SLIDER : Optional[gradio.Slider] = None + + +def render() -> None: + global FACE_LANDMARKER_MODEL_DROPDOWN + global FACE_LANDMARKER_SCORE_SLIDER + + FACE_LANDMARKER_MODEL_DROPDOWN = gradio.Dropdown( + label = wording.get('uis.face_landmarker_model_dropdown'), + choices = facefusion.choices.face_landmarker_models, + value = state_manager.get_item('face_landmarker_model') + ) + FACE_LANDMARKER_SCORE_SLIDER = gradio.Slider( + label = wording.get('uis.face_landmarker_score_slider'), + value = state_manager.get_item('face_landmarker_score'), + step = calc_float_step(facefusion.choices.face_landmarker_score_range), + minimum = facefusion.choices.face_landmarker_score_range[0], + maximum = facefusion.choices.face_landmarker_score_range[-1] + ) + register_ui_component('face_landmarker_model_dropdown', FACE_LANDMARKER_MODEL_DROPDOWN) + register_ui_component('face_landmarker_score_slider', FACE_LANDMARKER_SCORE_SLIDER) + + +def listen() -> None: + FACE_LANDMARKER_MODEL_DROPDOWN.change(update_face_landmarker_model, inputs = FACE_LANDMARKER_MODEL_DROPDOWN, outputs = FACE_LANDMARKER_MODEL_DROPDOWN) + FACE_LANDMARKER_SCORE_SLIDER.release(update_face_landmarker_score, inputs = FACE_LANDMARKER_SCORE_SLIDER) + + +def update_face_landmarker_model(face_landmarker_model : FaceLandmarkerModel) -> gradio.Dropdown: + face_landmarker.clear_inference_pool() + state_manager.set_item('face_landmarker_model', face_landmarker_model) + + if face_landmarker.pre_check(): + return gradio.Dropdown(value =
state_manager.get_item('face_landmarker_model')) + return gradio.Dropdown() + + +def update_face_landmarker_score(face_landmarker_score : Score) -> None: + state_manager.set_item('face_landmarker_score', face_landmarker_score) diff --git a/facefusion/uis/components/face_masker.py b/facefusion/uis/components/face_masker.py new file mode 100644 index 0000000000000000000000000000000000000000..e01a5cd772fb6a5f48798b3c681967cc049a7147 --- /dev/null +++ b/facefusion/uis/components/face_masker.py @@ -0,0 +1,179 @@ +from typing import List, Optional, Tuple + +import gradio + +import facefusion.choices +from facefusion import face_masker, state_manager, wording +from facefusion.common_helper import calc_float_step, calc_int_step +from facefusion.types import FaceMaskArea, FaceMaskRegion, FaceMaskType, FaceOccluderModel, FaceParserModel +from facefusion.uis.core import register_ui_component + +FACE_OCCLUDER_MODEL_DROPDOWN : Optional[gradio.Dropdown] = None +FACE_PARSER_MODEL_DROPDOWN : Optional[gradio.Dropdown] = None +FACE_MASK_TYPES_CHECKBOX_GROUP : Optional[gradio.CheckboxGroup] = None +FACE_MASK_AREAS_CHECKBOX_GROUP : Optional[gradio.CheckboxGroup] = None +FACE_MASK_REGIONS_CHECKBOX_GROUP : Optional[gradio.CheckboxGroup] = None +FACE_MASK_BLUR_SLIDER : Optional[gradio.Slider] = None +FACE_MASK_PADDING_TOP_SLIDER : Optional[gradio.Slider] = None +FACE_MASK_PADDING_RIGHT_SLIDER : Optional[gradio.Slider] = None +FACE_MASK_PADDING_BOTTOM_SLIDER : Optional[gradio.Slider] = None +FACE_MASK_PADDING_LEFT_SLIDER : Optional[gradio.Slider] = None + + +def render() -> None: + global FACE_OCCLUDER_MODEL_DROPDOWN + global FACE_PARSER_MODEL_DROPDOWN + global FACE_MASK_TYPES_CHECKBOX_GROUP + global FACE_MASK_AREAS_CHECKBOX_GROUP + global FACE_MASK_REGIONS_CHECKBOX_GROUP + global FACE_MASK_BLUR_SLIDER + global FACE_MASK_PADDING_TOP_SLIDER + global FACE_MASK_PADDING_RIGHT_SLIDER + global FACE_MASK_PADDING_BOTTOM_SLIDER + global FACE_MASK_PADDING_LEFT_SLIDER + + has_box_mask = 'box' in state_manager.get_item('face_mask_types') + has_region_mask = 'region' in state_manager.get_item('face_mask_types') + has_area_mask = 'area' in state_manager.get_item('face_mask_types') + with gradio.Row(): + FACE_OCCLUDER_MODEL_DROPDOWN = gradio.Dropdown( + label = wording.get('uis.face_occluder_model_dropdown'), + choices = facefusion.choices.face_occluder_models, + value = state_manager.get_item('face_occluder_model') + ) + FACE_PARSER_MODEL_DROPDOWN = gradio.Dropdown( + label = wording.get('uis.face_parser_model_dropdown'), + choices = facefusion.choices.face_parser_models, + value = state_manager.get_item('face_parser_model') + ) + FACE_MASK_TYPES_CHECKBOX_GROUP = gradio.CheckboxGroup( + label = wording.get('uis.face_mask_types_checkbox_group'), + choices = facefusion.choices.face_mask_types, + value = state_manager.get_item('face_mask_types') + ) + FACE_MASK_AREAS_CHECKBOX_GROUP = gradio.CheckboxGroup( + label = wording.get('uis.face_mask_areas_checkbox_group'), + choices = facefusion.choices.face_mask_areas, + value = state_manager.get_item('face_mask_areas'), + visible = has_area_mask + ) + FACE_MASK_REGIONS_CHECKBOX_GROUP = gradio.CheckboxGroup( + label = wording.get('uis.face_mask_regions_checkbox_group'), + choices = facefusion.choices.face_mask_regions, + value = state_manager.get_item('face_mask_regions'), + visible = has_region_mask + ) + FACE_MASK_BLUR_SLIDER = gradio.Slider( + label = wording.get('uis.face_mask_blur_slider'), + step = calc_float_step(facefusion.choices.face_mask_blur_range), + minimum = 
facefusion.choices.face_mask_blur_range[0], + maximum = facefusion.choices.face_mask_blur_range[-1], + value = state_manager.get_item('face_mask_blur'), + visible = has_box_mask + ) + with gradio.Group(): + with gradio.Row(): + FACE_MASK_PADDING_TOP_SLIDER = gradio.Slider( + label = wording.get('uis.face_mask_padding_top_slider'), + step = calc_int_step(facefusion.choices.face_mask_padding_range), + minimum = facefusion.choices.face_mask_padding_range[0], + maximum = facefusion.choices.face_mask_padding_range[-1], + value = state_manager.get_item('face_mask_padding')[0], + visible = has_box_mask + ) + FACE_MASK_PADDING_RIGHT_SLIDER = gradio.Slider( + label = wording.get('uis.face_mask_padding_right_slider'), + step = calc_int_step(facefusion.choices.face_mask_padding_range), + minimum = facefusion.choices.face_mask_padding_range[0], + maximum = facefusion.choices.face_mask_padding_range[-1], + value = state_manager.get_item('face_mask_padding')[1], + visible = has_box_mask + ) + with gradio.Row(): + FACE_MASK_PADDING_BOTTOM_SLIDER = gradio.Slider( + label = wording.get('uis.face_mask_padding_bottom_slider'), + step = calc_int_step(facefusion.choices.face_mask_padding_range), + minimum = facefusion.choices.face_mask_padding_range[0], + maximum = facefusion.choices.face_mask_padding_range[-1], + value = state_manager.get_item('face_mask_padding')[2], + visible = has_box_mask + ) + FACE_MASK_PADDING_LEFT_SLIDER = gradio.Slider( + label = wording.get('uis.face_mask_padding_left_slider'), + step = calc_int_step(facefusion.choices.face_mask_padding_range), + minimum = facefusion.choices.face_mask_padding_range[0], + maximum = facefusion.choices.face_mask_padding_range[-1], + value = state_manager.get_item('face_mask_padding')[3], + visible = has_box_mask + ) + register_ui_component('face_occluder_model_dropdown', FACE_OCCLUDER_MODEL_DROPDOWN) + register_ui_component('face_parser_model_dropdown', FACE_PARSER_MODEL_DROPDOWN) + register_ui_component('face_mask_types_checkbox_group', FACE_MASK_TYPES_CHECKBOX_GROUP) + register_ui_component('face_mask_areas_checkbox_group', FACE_MASK_AREAS_CHECKBOX_GROUP) + register_ui_component('face_mask_regions_checkbox_group', FACE_MASK_REGIONS_CHECKBOX_GROUP) + register_ui_component('face_mask_blur_slider', FACE_MASK_BLUR_SLIDER) + register_ui_component('face_mask_padding_top_slider', FACE_MASK_PADDING_TOP_SLIDER) + register_ui_component('face_mask_padding_right_slider', FACE_MASK_PADDING_RIGHT_SLIDER) + register_ui_component('face_mask_padding_bottom_slider', FACE_MASK_PADDING_BOTTOM_SLIDER) + register_ui_component('face_mask_padding_left_slider', FACE_MASK_PADDING_LEFT_SLIDER) + + +def listen() -> None: + FACE_OCCLUDER_MODEL_DROPDOWN.change(update_face_occluder_model, inputs = FACE_OCCLUDER_MODEL_DROPDOWN) + FACE_PARSER_MODEL_DROPDOWN.change(update_face_parser_model, inputs = FACE_PARSER_MODEL_DROPDOWN) + FACE_MASK_TYPES_CHECKBOX_GROUP.change(update_face_mask_types, inputs = FACE_MASK_TYPES_CHECKBOX_GROUP, outputs = [ FACE_MASK_TYPES_CHECKBOX_GROUP, FACE_MASK_AREAS_CHECKBOX_GROUP, FACE_MASK_REGIONS_CHECKBOX_GROUP, FACE_MASK_BLUR_SLIDER, FACE_MASK_PADDING_TOP_SLIDER, FACE_MASK_PADDING_RIGHT_SLIDER, FACE_MASK_PADDING_BOTTOM_SLIDER, FACE_MASK_PADDING_LEFT_SLIDER ]) + FACE_MASK_AREAS_CHECKBOX_GROUP.change(update_face_mask_areas, inputs = FACE_MASK_AREAS_CHECKBOX_GROUP, outputs = FACE_MASK_AREAS_CHECKBOX_GROUP) + FACE_MASK_REGIONS_CHECKBOX_GROUP.change(update_face_mask_regions, inputs = FACE_MASK_REGIONS_CHECKBOX_GROUP, outputs = FACE_MASK_REGIONS_CHECKBOX_GROUP) + 
FACE_MASK_BLUR_SLIDER.release(update_face_mask_blur, inputs = FACE_MASK_BLUR_SLIDER) + + face_mask_padding_sliders = [ FACE_MASK_PADDING_TOP_SLIDER, FACE_MASK_PADDING_RIGHT_SLIDER, FACE_MASK_PADDING_BOTTOM_SLIDER, FACE_MASK_PADDING_LEFT_SLIDER ] + for face_mask_padding_slider in face_mask_padding_sliders: + face_mask_padding_slider.release(update_face_mask_padding, inputs = face_mask_padding_sliders) + + +def update_face_occluder_model(face_occluder_model : FaceOccluderModel) -> gradio.Dropdown: + face_masker.clear_inference_pool() + state_manager.set_item('face_occluder_model', face_occluder_model) + + if face_masker.pre_check(): + return gradio.Dropdown(value = state_manager.get_item('face_occluder_model')) + return gradio.Dropdown() + + +def update_face_parser_model(face_parser_model : FaceParserModel) -> gradio.Dropdown: + face_masker.clear_inference_pool() + state_manager.set_item('face_parser_model', face_parser_model) + + if face_masker.pre_check(): + return gradio.Dropdown(value = state_manager.get_item('face_parser_model')) + return gradio.Dropdown() + + +def update_face_mask_types(face_mask_types : List[FaceMaskType]) -> Tuple[gradio.CheckboxGroup, gradio.CheckboxGroup, gradio.CheckboxGroup, gradio.Slider, gradio.Slider, gradio.Slider, gradio.Slider, gradio.Slider]: + face_mask_types = face_mask_types or facefusion.choices.face_mask_types + state_manager.set_item('face_mask_types', face_mask_types) + has_box_mask = 'box' in face_mask_types + has_area_mask = 'area' in face_mask_types + has_region_mask = 'region' in face_mask_types + return gradio.CheckboxGroup(value = state_manager.get_item('face_mask_types')), gradio.CheckboxGroup(visible = has_area_mask), gradio.CheckboxGroup(visible = has_region_mask), gradio.Slider(visible = has_box_mask), gradio.Slider(visible = has_box_mask), gradio.Slider(visible = has_box_mask), gradio.Slider(visible = has_box_mask), gradio.Slider(visible = has_box_mask) + + +def update_face_mask_areas(face_mask_areas : List[FaceMaskArea]) -> gradio.CheckboxGroup: + face_mask_areas = face_mask_areas or facefusion.choices.face_mask_areas + state_manager.set_item('face_mask_areas', face_mask_areas) + return gradio.CheckboxGroup(value = state_manager.get_item('face_mask_areas')) + + +def update_face_mask_regions(face_mask_regions : List[FaceMaskRegion]) -> gradio.CheckboxGroup: + face_mask_regions = face_mask_regions or facefusion.choices.face_mask_regions + state_manager.set_item('face_mask_regions', face_mask_regions) + return gradio.CheckboxGroup(value = state_manager.get_item('face_mask_regions')) + + +def update_face_mask_blur(face_mask_blur : float) -> None: + state_manager.set_item('face_mask_blur', face_mask_blur) + + +def update_face_mask_padding(face_mask_padding_top : float, face_mask_padding_right : float, face_mask_padding_bottom : float, face_mask_padding_left : float) -> None: + face_mask_padding = (int(face_mask_padding_top), int(face_mask_padding_right), int(face_mask_padding_bottom), int(face_mask_padding_left)) + state_manager.set_item('face_mask_padding', face_mask_padding) diff --git a/facefusion/uis/components/face_selector.py b/facefusion/uis/components/face_selector.py new file mode 100644 index 0000000000000000000000000000000000000000..58f21c845f1dc05f11135a19d746164559570c7d --- /dev/null +++ b/facefusion/uis/components/face_selector.py @@ -0,0 +1,224 @@ +from typing import List, Optional, Tuple + +import gradio +from gradio_rangeslider import RangeSlider + +import facefusion.choices +from facefusion import state_manager, wording 
+from facefusion.common_helper import calc_float_step, calc_int_step +from facefusion.face_analyser import get_many_faces +from facefusion.face_selector import sort_and_filter_faces +from facefusion.face_store import clear_reference_faces, clear_static_faces +from facefusion.filesystem import is_image, is_video +from facefusion.types import FaceSelectorMode, FaceSelectorOrder, Gender, Race, VisionFrame +from facefusion.uis.core import get_ui_component, get_ui_components, register_ui_component +from facefusion.uis.types import ComponentOptions +from facefusion.uis.ui_helper import convert_str_none +from facefusion.vision import normalize_frame_color, read_static_image, read_video_frame + +FACE_SELECTOR_MODE_DROPDOWN : Optional[gradio.Dropdown] = None +FACE_SELECTOR_ORDER_DROPDOWN : Optional[gradio.Dropdown] = None +FACE_SELECTOR_GENDER_DROPDOWN : Optional[gradio.Dropdown] = None +FACE_SELECTOR_RACE_DROPDOWN : Optional[gradio.Dropdown] = None +FACE_SELECTOR_AGE_RANGE_SLIDER : Optional[RangeSlider] = None +REFERENCE_FACE_POSITION_GALLERY : Optional[gradio.Gallery] = None +REFERENCE_FACE_DISTANCE_SLIDER : Optional[gradio.Slider] = None + + +def render() -> None: + global FACE_SELECTOR_MODE_DROPDOWN + global FACE_SELECTOR_ORDER_DROPDOWN + global FACE_SELECTOR_GENDER_DROPDOWN + global FACE_SELECTOR_RACE_DROPDOWN + global FACE_SELECTOR_AGE_RANGE_SLIDER + global REFERENCE_FACE_POSITION_GALLERY + global REFERENCE_FACE_DISTANCE_SLIDER + + reference_face_gallery_options : ComponentOptions =\ + { + 'label': wording.get('uis.reference_face_gallery'), + 'object_fit': 'cover', + 'columns': 7, + 'allow_preview': False, + 'elem_classes': 'box-face-selector', + 'visible': 'reference' in state_manager.get_item('face_selector_mode') + } + if is_image(state_manager.get_item('target_path')): + reference_frame = read_static_image(state_manager.get_item('target_path')) + reference_face_gallery_options['value'] = extract_gallery_frames(reference_frame) + if is_video(state_manager.get_item('target_path')): + reference_frame = read_video_frame(state_manager.get_item('target_path'), state_manager.get_item('reference_frame_number')) + reference_face_gallery_options['value'] = extract_gallery_frames(reference_frame) + FACE_SELECTOR_MODE_DROPDOWN = gradio.Dropdown( + label = wording.get('uis.face_selector_mode_dropdown'), + choices = facefusion.choices.face_selector_modes, + value = state_manager.get_item('face_selector_mode') + ) + REFERENCE_FACE_POSITION_GALLERY = gradio.Gallery(**reference_face_gallery_options) + with gradio.Group(): + with gradio.Row(): + FACE_SELECTOR_ORDER_DROPDOWN = gradio.Dropdown( + label = wording.get('uis.face_selector_order_dropdown'), + choices = facefusion.choices.face_selector_orders, + value = state_manager.get_item('face_selector_order') + ) + FACE_SELECTOR_GENDER_DROPDOWN = gradio.Dropdown( + label = wording.get('uis.face_selector_gender_dropdown'), + choices = [ 'none' ] + facefusion.choices.face_selector_genders, + value = state_manager.get_item('face_selector_gender') or 'none' + ) + FACE_SELECTOR_RACE_DROPDOWN = gradio.Dropdown( + label = wording.get('uis.face_selector_race_dropdown'), + choices = ['none'] + facefusion.choices.face_selector_races, + value = state_manager.get_item('face_selector_race') or 'none' + ) + with gradio.Row(): + face_selector_age_start = state_manager.get_item('face_selector_age_start') or facefusion.choices.face_selector_age_range[0] + face_selector_age_end = state_manager.get_item('face_selector_age_end') or facefusion.choices.face_selector_age_range[-1] 
+ FACE_SELECTOR_AGE_RANGE_SLIDER = RangeSlider( + label = wording.get('uis.face_selector_age_range_slider'), + minimum = facefusion.choices.face_selector_age_range[0], + maximum = facefusion.choices.face_selector_age_range[-1], + value = (face_selector_age_start, face_selector_age_end), + step = calc_int_step(facefusion.choices.face_selector_age_range) + ) + REFERENCE_FACE_DISTANCE_SLIDER = gradio.Slider( + label = wording.get('uis.reference_face_distance_slider'), + value = state_manager.get_item('reference_face_distance'), + step = calc_float_step(facefusion.choices.reference_face_distance_range), + minimum = facefusion.choices.reference_face_distance_range[0], + maximum = facefusion.choices.reference_face_distance_range[-1], + visible = 'reference' in state_manager.get_item('face_selector_mode') + ) + register_ui_component('face_selector_mode_dropdown', FACE_SELECTOR_MODE_DROPDOWN) + register_ui_component('face_selector_order_dropdown', FACE_SELECTOR_ORDER_DROPDOWN) + register_ui_component('face_selector_gender_dropdown', FACE_SELECTOR_GENDER_DROPDOWN) + register_ui_component('face_selector_race_dropdown', FACE_SELECTOR_RACE_DROPDOWN) + register_ui_component('face_selector_age_range_slider', FACE_SELECTOR_AGE_RANGE_SLIDER) + register_ui_component('reference_face_position_gallery', REFERENCE_FACE_POSITION_GALLERY) + register_ui_component('reference_face_distance_slider', REFERENCE_FACE_DISTANCE_SLIDER) + + +def listen() -> None: + FACE_SELECTOR_MODE_DROPDOWN.change(update_face_selector_mode, inputs = FACE_SELECTOR_MODE_DROPDOWN, outputs = [ REFERENCE_FACE_POSITION_GALLERY, REFERENCE_FACE_DISTANCE_SLIDER ]) + FACE_SELECTOR_ORDER_DROPDOWN.change(update_face_selector_order, inputs = FACE_SELECTOR_ORDER_DROPDOWN, outputs = REFERENCE_FACE_POSITION_GALLERY) + FACE_SELECTOR_GENDER_DROPDOWN.change(update_face_selector_gender, inputs = FACE_SELECTOR_GENDER_DROPDOWN, outputs = REFERENCE_FACE_POSITION_GALLERY) + FACE_SELECTOR_RACE_DROPDOWN.change(update_face_selector_race, inputs = FACE_SELECTOR_RACE_DROPDOWN, outputs = REFERENCE_FACE_POSITION_GALLERY) + FACE_SELECTOR_AGE_RANGE_SLIDER.release(update_face_selector_age_range, inputs = FACE_SELECTOR_AGE_RANGE_SLIDER, outputs = REFERENCE_FACE_POSITION_GALLERY) + REFERENCE_FACE_POSITION_GALLERY.select(clear_and_update_reference_face_position) + REFERENCE_FACE_DISTANCE_SLIDER.release(update_reference_face_distance, inputs = REFERENCE_FACE_DISTANCE_SLIDER) + + for ui_component in get_ui_components( + [ + 'target_image', + 'target_video' + ]): + for method in [ 'change', 'clear' ]: + getattr(ui_component, method)(update_reference_face_position) + getattr(ui_component, method)(update_reference_position_gallery, outputs = REFERENCE_FACE_POSITION_GALLERY) + + for ui_component in get_ui_components( + [ + 'face_detector_model_dropdown', + 'face_detector_size_dropdown', + 'face_detector_angles_checkbox_group' + ]): + ui_component.change(clear_and_update_reference_position_gallery, outputs = REFERENCE_FACE_POSITION_GALLERY) + + face_detector_score_slider = get_ui_component('face_detector_score_slider') + if face_detector_score_slider: + face_detector_score_slider.release(clear_and_update_reference_position_gallery, outputs = REFERENCE_FACE_POSITION_GALLERY) + + preview_frame_slider = get_ui_component('preview_frame_slider') + if preview_frame_slider: + for method in [ 'change', 'release' ]: + getattr(preview_frame_slider, method)(update_reference_frame_number, inputs = preview_frame_slider, show_progress = 'hidden') + getattr(preview_frame_slider, 
method)(update_reference_position_gallery, outputs = REFERENCE_FACE_POSITION_GALLERY, show_progress = 'hidden') + + +def update_face_selector_mode(face_selector_mode : FaceSelectorMode) -> Tuple[gradio.Gallery, gradio.Slider]: + state_manager.set_item('face_selector_mode', face_selector_mode) + if face_selector_mode == 'many': + return gradio.Gallery(visible = False), gradio.Slider(visible = False) + if face_selector_mode == 'one': + return gradio.Gallery(visible = False), gradio.Slider(visible = False) + if face_selector_mode == 'reference': + return gradio.Gallery(visible = True), gradio.Slider(visible = True) + + +def update_face_selector_order(face_analyser_order : FaceSelectorOrder) -> gradio.Gallery: + state_manager.set_item('face_selector_order', convert_str_none(face_analyser_order)) + return update_reference_position_gallery() + + +def update_face_selector_gender(face_selector_gender : Gender) -> gradio.Gallery: + state_manager.set_item('face_selector_gender', convert_str_none(face_selector_gender)) + return update_reference_position_gallery() + + +def update_face_selector_race(face_selector_race : Race) -> gradio.Gallery: + state_manager.set_item('face_selector_race', convert_str_none(face_selector_race)) + return update_reference_position_gallery() + + +def update_face_selector_age_range(face_selector_age_range : Tuple[float, float]) -> gradio.Gallery: + face_selector_age_start, face_selector_age_end = face_selector_age_range + state_manager.set_item('face_selector_age_start', int(face_selector_age_start)) + state_manager.set_item('face_selector_age_end', int(face_selector_age_end)) + return update_reference_position_gallery() + + +def clear_and_update_reference_face_position(event : gradio.SelectData) -> gradio.Gallery: + clear_reference_faces() + clear_static_faces() + update_reference_face_position(event.index) + return update_reference_position_gallery() + + +def update_reference_face_position(reference_face_position : int = 0) -> None: + state_manager.set_item('reference_face_position', reference_face_position) + + +def update_reference_face_distance(reference_face_distance : float) -> None: + state_manager.set_item('reference_face_distance', reference_face_distance) + + +def update_reference_frame_number(reference_frame_number : int) -> None: + state_manager.set_item('reference_frame_number', reference_frame_number) + + +def clear_and_update_reference_position_gallery() -> gradio.Gallery: + clear_reference_faces() + clear_static_faces() + return update_reference_position_gallery() + + +def update_reference_position_gallery() -> gradio.Gallery: + gallery_vision_frames = [] + if is_image(state_manager.get_item('target_path')): + temp_vision_frame = read_static_image(state_manager.get_item('target_path')) + gallery_vision_frames = extract_gallery_frames(temp_vision_frame) + if is_video(state_manager.get_item('target_path')): + temp_vision_frame = read_video_frame(state_manager.get_item('target_path'), state_manager.get_item('reference_frame_number')) + gallery_vision_frames = extract_gallery_frames(temp_vision_frame) + if gallery_vision_frames: + return gradio.Gallery(value = gallery_vision_frames) + return gradio.Gallery(value = None) + + +def extract_gallery_frames(temp_vision_frame : VisionFrame) -> List[VisionFrame]: + gallery_vision_frames = [] + faces = sort_and_filter_faces(get_many_faces([ temp_vision_frame ])) + + for face in faces: + start_x, start_y, end_x, end_y = map(int, face.bounding_box) + padding_x = int((end_x - start_x) * 0.25) + padding_y = int((end_y - 
start_y) * 0.25) + start_x = max(0, start_x - padding_x) + start_y = max(0, start_y - padding_y) + end_x = max(0, end_x + padding_x) + end_y = max(0, end_y + padding_y) + crop_vision_frame = temp_vision_frame[start_y:end_y, start_x:end_x] + crop_vision_frame = normalize_frame_color(crop_vision_frame) + gallery_vision_frames.append(crop_vision_frame) + return gallery_vision_frames diff --git a/facefusion/uis/components/face_swapper_options.py b/facefusion/uis/components/face_swapper_options.py new file mode 100644 index 0000000000000000000000000000000000000000..92f08dc3c4b96aa02851185212aa3fa46db224bf --- /dev/null +++ b/facefusion/uis/components/face_swapper_options.py @@ -0,0 +1,64 @@ +from typing import List, Optional, Tuple + +import gradio + +from facefusion import state_manager, wording +from facefusion.common_helper import get_first +from facefusion.processors import choices as processors_choices +from facefusion.processors.core import load_processor_module +from facefusion.processors.types import FaceSwapperModel +from facefusion.uis.core import get_ui_component, register_ui_component + +FACE_SWAPPER_MODEL_DROPDOWN : Optional[gradio.Dropdown] = None +FACE_SWAPPER_PIXEL_BOOST_DROPDOWN : Optional[gradio.Dropdown] = None + + +def render() -> None: + global FACE_SWAPPER_MODEL_DROPDOWN + global FACE_SWAPPER_PIXEL_BOOST_DROPDOWN + + has_face_swapper = 'face_swapper' in state_manager.get_item('processors') + FACE_SWAPPER_MODEL_DROPDOWN = gradio.Dropdown( + label = wording.get('uis.face_swapper_model_dropdown'), + choices = processors_choices.face_swapper_models, + value = state_manager.get_item('face_swapper_model'), + visible = has_face_swapper + ) + FACE_SWAPPER_PIXEL_BOOST_DROPDOWN = gradio.Dropdown( + label = wording.get('uis.face_swapper_pixel_boost_dropdown'), + choices = processors_choices.face_swapper_set.get(state_manager.get_item('face_swapper_model')), + value = state_manager.get_item('face_swapper_pixel_boost'), + visible = has_face_swapper + ) + register_ui_component('face_swapper_model_dropdown', FACE_SWAPPER_MODEL_DROPDOWN) + register_ui_component('face_swapper_pixel_boost_dropdown', FACE_SWAPPER_PIXEL_BOOST_DROPDOWN) + + +def listen() -> None: + FACE_SWAPPER_MODEL_DROPDOWN.change(update_face_swapper_model, inputs = FACE_SWAPPER_MODEL_DROPDOWN, outputs = [ FACE_SWAPPER_MODEL_DROPDOWN, FACE_SWAPPER_PIXEL_BOOST_DROPDOWN ]) + FACE_SWAPPER_PIXEL_BOOST_DROPDOWN.change(update_face_swapper_pixel_boost, inputs = FACE_SWAPPER_PIXEL_BOOST_DROPDOWN) + + processors_checkbox_group = get_ui_component('processors_checkbox_group') + if processors_checkbox_group: + processors_checkbox_group.change(remote_update, inputs = processors_checkbox_group, outputs = [ FACE_SWAPPER_MODEL_DROPDOWN, FACE_SWAPPER_PIXEL_BOOST_DROPDOWN ]) + + +def remote_update(processors : List[str]) -> Tuple[gradio.Dropdown, gradio.Dropdown]: + has_face_swapper = 'face_swapper' in processors + return gradio.Dropdown(visible = has_face_swapper), gradio.Dropdown(visible = has_face_swapper) + + +def update_face_swapper_model(face_swapper_model : FaceSwapperModel) -> Tuple[gradio.Dropdown, gradio.Dropdown]: + face_swapper_module = load_processor_module('face_swapper') + face_swapper_module.clear_inference_pool() + state_manager.set_item('face_swapper_model', face_swapper_model) + + if face_swapper_module.pre_check(): + face_swapper_pixel_boost_choices = processors_choices.face_swapper_set.get(state_manager.get_item('face_swapper_model')) + state_manager.set_item('face_swapper_pixel_boost', 
get_first(face_swapper_pixel_boost_choices)) + return gradio.Dropdown(value = state_manager.get_item('face_swapper_model')), gradio.Dropdown(value = state_manager.get_item('face_swapper_pixel_boost'), choices = face_swapper_pixel_boost_choices) + return gradio.Dropdown(), gradio.Dropdown() + + +def update_face_swapper_pixel_boost(face_swapper_pixel_boost : str) -> None: + state_manager.set_item('face_swapper_pixel_boost', face_swapper_pixel_boost) diff --git a/facefusion/uis/components/frame_colorizer_options.py b/facefusion/uis/components/frame_colorizer_options.py new file mode 100644 index 0000000000000000000000000000000000000000..1ef4a47ea83de4b4bed17e14fcfce77a49e0b004 --- /dev/null +++ b/facefusion/uis/components/frame_colorizer_options.py @@ -0,0 +1,81 @@ +from typing import List, Optional, Tuple + +import gradio + +from facefusion import state_manager, wording +from facefusion.common_helper import calc_int_step +from facefusion.processors import choices as processors_choices +from facefusion.processors.core import load_processor_module +from facefusion.processors.types import FrameColorizerModel +from facefusion.uis.core import get_ui_component, register_ui_component + +FRAME_COLORIZER_MODEL_DROPDOWN : Optional[gradio.Dropdown] = None +FRAME_COLORIZER_SIZE_DROPDOWN : Optional[gradio.Dropdown] = None +FRAME_COLORIZER_BLEND_SLIDER : Optional[gradio.Slider] = None + + +def render() -> None: + global FRAME_COLORIZER_MODEL_DROPDOWN + global FRAME_COLORIZER_SIZE_DROPDOWN + global FRAME_COLORIZER_BLEND_SLIDER + + has_frame_colorizer = 'frame_colorizer' in state_manager.get_item('processors') + FRAME_COLORIZER_MODEL_DROPDOWN = gradio.Dropdown( + label = wording.get('uis.frame_colorizer_model_dropdown'), + choices = processors_choices.frame_colorizer_models, + value = state_manager.get_item('frame_colorizer_model'), + visible = has_frame_colorizer + ) + FRAME_COLORIZER_SIZE_DROPDOWN = gradio.Dropdown( + label = wording.get('uis.frame_colorizer_size_dropdown'), + choices = processors_choices.frame_colorizer_sizes, + value = state_manager.get_item('frame_colorizer_size'), + visible = has_frame_colorizer + ) + FRAME_COLORIZER_BLEND_SLIDER = gradio.Slider( + label = wording.get('uis.frame_colorizer_blend_slider'), + value = state_manager.get_item('frame_colorizer_blend'), + step = calc_int_step(processors_choices.frame_colorizer_blend_range), + minimum = processors_choices.frame_colorizer_blend_range[0], + maximum = processors_choices.frame_colorizer_blend_range[-1], + visible = has_frame_colorizer + ) + register_ui_component('frame_colorizer_model_dropdown', FRAME_COLORIZER_MODEL_DROPDOWN) + register_ui_component('frame_colorizer_size_dropdown', FRAME_COLORIZER_SIZE_DROPDOWN) + register_ui_component('frame_colorizer_blend_slider', FRAME_COLORIZER_BLEND_SLIDER) + + +def listen() -> None: + FRAME_COLORIZER_MODEL_DROPDOWN.change(update_frame_colorizer_model, inputs = FRAME_COLORIZER_MODEL_DROPDOWN, outputs = FRAME_COLORIZER_MODEL_DROPDOWN) + FRAME_COLORIZER_SIZE_DROPDOWN.change(update_frame_colorizer_size, inputs = FRAME_COLORIZER_SIZE_DROPDOWN) + FRAME_COLORIZER_BLEND_SLIDER.release(update_frame_colorizer_blend, inputs = FRAME_COLORIZER_BLEND_SLIDER) + + processors_checkbox_group = get_ui_component('processors_checkbox_group') + if processors_checkbox_group: + processors_checkbox_group.change(remote_update, inputs = processors_checkbox_group, outputs = [ FRAME_COLORIZER_MODEL_DROPDOWN, FRAME_COLORIZER_BLEND_SLIDER, FRAME_COLORIZER_SIZE_DROPDOWN ]) + + +def remote_update(processors : List[str]) 
-> Tuple[gradio.Dropdown, gradio.Slider, gradio.Dropdown]: + has_frame_colorizer = 'frame_colorizer' in processors + return gradio.Dropdown(visible = has_frame_colorizer), gradio.Slider(visible = has_frame_colorizer), gradio.Dropdown(visible = has_frame_colorizer) + + +def update_frame_colorizer_model(frame_colorizer_model : FrameColorizerModel) -> gradio.Dropdown: + frame_colorizer_module = load_processor_module('frame_colorizer') + frame_colorizer_module.clear_inference_pool() + state_manager.set_item('frame_colorizer_model', frame_colorizer_model) + + if frame_colorizer_module.pre_check(): + return gradio.Dropdown(value = state_manager.get_item('frame_colorizer_model')) + return gradio.Dropdown() + + +def update_frame_colorizer_size(frame_colorizer_size : str) -> None: + state_manager.set_item('frame_colorizer_size', frame_colorizer_size) + + +def update_frame_colorizer_blend(frame_colorizer_blend : float) -> None: + state_manager.set_item('frame_colorizer_blend', int(frame_colorizer_blend)) + + + diff --git a/facefusion/uis/components/frame_enhancer_options.py b/facefusion/uis/components/frame_enhancer_options.py new file mode 100644 index 0000000000000000000000000000000000000000..db0df537d4078120d17d1adede3c46420f3dfb7c --- /dev/null +++ b/facefusion/uis/components/frame_enhancer_options.py @@ -0,0 +1,64 @@ +from typing import List, Optional, Tuple + +import gradio + +from facefusion import state_manager, wording +from facefusion.common_helper import calc_int_step +from facefusion.processors import choices as processors_choices +from facefusion.processors.core import load_processor_module +from facefusion.processors.types import FrameEnhancerModel +from facefusion.uis.core import get_ui_component, register_ui_component + +FRAME_ENHANCER_MODEL_DROPDOWN : Optional[gradio.Dropdown] = None +FRAME_ENHANCER_BLEND_SLIDER : Optional[gradio.Slider] = None + + +def render() -> None: + global FRAME_ENHANCER_MODEL_DROPDOWN + global FRAME_ENHANCER_BLEND_SLIDER + + has_frame_enhancer = 'frame_enhancer' in state_manager.get_item('processors') + FRAME_ENHANCER_MODEL_DROPDOWN = gradio.Dropdown( + label = wording.get('uis.frame_enhancer_model_dropdown'), + choices = processors_choices.frame_enhancer_models, + value = state_manager.get_item('frame_enhancer_model'), + visible = has_frame_enhancer + ) + FRAME_ENHANCER_BLEND_SLIDER = gradio.Slider( + label = wording.get('uis.frame_enhancer_blend_slider'), + value = state_manager.get_item('frame_enhancer_blend'), + step = calc_int_step(processors_choices.frame_enhancer_blend_range), + minimum = processors_choices.frame_enhancer_blend_range[0], + maximum = processors_choices.frame_enhancer_blend_range[-1], + visible = has_frame_enhancer + ) + register_ui_component('frame_enhancer_model_dropdown', FRAME_ENHANCER_MODEL_DROPDOWN) + register_ui_component('frame_enhancer_blend_slider', FRAME_ENHANCER_BLEND_SLIDER) + + +def listen() -> None: + FRAME_ENHANCER_MODEL_DROPDOWN.change(update_frame_enhancer_model, inputs = FRAME_ENHANCER_MODEL_DROPDOWN, outputs = FRAME_ENHANCER_MODEL_DROPDOWN) + FRAME_ENHANCER_BLEND_SLIDER.release(update_frame_enhancer_blend, inputs = FRAME_ENHANCER_BLEND_SLIDER) + + processors_checkbox_group = get_ui_component('processors_checkbox_group') + if processors_checkbox_group: + processors_checkbox_group.change(remote_update, inputs = processors_checkbox_group, outputs = [ FRAME_ENHANCER_MODEL_DROPDOWN, FRAME_ENHANCER_BLEND_SLIDER ]) + + +def remote_update(processors : List[str]) -> Tuple[gradio.Dropdown, gradio.Slider]: + has_frame_enhancer 
= 'frame_enhancer' in processors + return gradio.Dropdown(visible = has_frame_enhancer), gradio.Slider(visible = has_frame_enhancer) + + +def update_frame_enhancer_model(frame_enhancer_model : FrameEnhancerModel) -> gradio.Dropdown: + frame_enhancer_module = load_processor_module('frame_enhancer') + frame_enhancer_module.clear_inference_pool() + state_manager.set_item('frame_enhancer_model', frame_enhancer_model) + + if frame_enhancer_module.pre_check(): + return gradio.Dropdown(value = state_manager.get_item('frame_enhancer_model')) + return gradio.Dropdown() + + +def update_frame_enhancer_blend(frame_enhancer_blend : float) -> None: + state_manager.set_item('frame_enhancer_blend', int(frame_enhancer_blend)) diff --git a/facefusion/uis/components/instant_runner.py b/facefusion/uis/components/instant_runner.py new file mode 100644 index 0000000000000000000000000000000000000000..71a3f7ab8e9e253425822d9c80134cfd11ca4bd7 --- /dev/null +++ b/facefusion/uis/components/instant_runner.py @@ -0,0 +1,110 @@ +from time import sleep +from typing import Optional, Tuple + +import gradio + +from facefusion import process_manager, state_manager, wording +from facefusion.args import collect_step_args +from facefusion.core import process_step +from facefusion.filesystem import is_directory, is_image, is_video +from facefusion.jobs import job_helper, job_manager, job_runner, job_store +from facefusion.temp_helper import clear_temp_directory +from facefusion.types import Args, UiWorkflow +from facefusion.uis.core import get_ui_component +from facefusion.uis.ui_helper import suggest_output_path + +INSTANT_RUNNER_WRAPPER : Optional[gradio.Row] = None +INSTANT_RUNNER_START_BUTTON : Optional[gradio.Button] = None +INSTANT_RUNNER_STOP_BUTTON : Optional[gradio.Button] = None +INSTANT_RUNNER_CLEAR_BUTTON : Optional[gradio.Button] = None + + +def render() -> None: + global INSTANT_RUNNER_WRAPPER + global INSTANT_RUNNER_START_BUTTON + global INSTANT_RUNNER_STOP_BUTTON + global INSTANT_RUNNER_CLEAR_BUTTON + + if job_manager.init_jobs(state_manager.get_item('jobs_path')): + is_instant_runner = state_manager.get_item('ui_workflow') == 'instant_runner' + + with gradio.Row(visible = is_instant_runner) as INSTANT_RUNNER_WRAPPER: + INSTANT_RUNNER_START_BUTTON = gradio.Button( + value = wording.get('uis.start_button'), + variant = 'primary', + size = 'sm' + ) + INSTANT_RUNNER_STOP_BUTTON = gradio.Button( + value = wording.get('uis.stop_button'), + variant = 'primary', + size = 'sm', + visible = False + ) + INSTANT_RUNNER_CLEAR_BUTTON = gradio.Button( + value = wording.get('uis.clear_button'), + size = 'sm' + ) + + +def listen() -> None: + output_image = get_ui_component('output_image') + output_video = get_ui_component('output_video') + ui_workflow_dropdown = get_ui_component('ui_workflow_dropdown') + + if output_image and output_video: + INSTANT_RUNNER_START_BUTTON.click(start, outputs = [ INSTANT_RUNNER_START_BUTTON, INSTANT_RUNNER_STOP_BUTTON ]) + INSTANT_RUNNER_START_BUTTON.click(run, outputs = [ INSTANT_RUNNER_START_BUTTON, INSTANT_RUNNER_STOP_BUTTON, output_image, output_video ]) + INSTANT_RUNNER_STOP_BUTTON.click(stop, outputs = [ INSTANT_RUNNER_START_BUTTON, INSTANT_RUNNER_STOP_BUTTON ]) + INSTANT_RUNNER_CLEAR_BUTTON.click(clear, outputs = [ output_image, output_video ]) + if ui_workflow_dropdown: + ui_workflow_dropdown.change(remote_update, inputs = ui_workflow_dropdown, outputs = INSTANT_RUNNER_WRAPPER) + + +def remote_update(ui_workflow : UiWorkflow) -> gradio.Row: + is_instant_runner = ui_workflow == 
'instant_runner' + + return gradio.Row(visible = is_instant_runner) + + +def start() -> Tuple[gradio.Button, gradio.Button]: + while not process_manager.is_processing(): + sleep(0.5) + return gradio.Button(visible = False), gradio.Button(visible = True) + + +def run() -> Tuple[gradio.Button, gradio.Button, gradio.Image, gradio.Video]: + step_args = collect_step_args() + output_path = step_args.get('output_path') + + if is_directory(step_args.get('output_path')): + step_args['output_path'] = suggest_output_path(step_args.get('output_path'), state_manager.get_item('target_path')) + if job_manager.init_jobs(state_manager.get_item('jobs_path')): + create_and_run_job(step_args) + state_manager.set_item('output_path', output_path) + if is_image(step_args.get('output_path')): + return gradio.Button(visible = True), gradio.Button(visible = False), gradio.Image(value = step_args.get('output_path'), visible = True), gradio.Video(value = None, visible = False) + if is_video(step_args.get('output_path')): + return gradio.Button(visible = True), gradio.Button(visible = False), gradio.Image(value = None, visible = False), gradio.Video(value = step_args.get('output_path'), visible = True) + return gradio.Button(visible = True), gradio.Button(visible = False), gradio.Image(value = None), gradio.Video(value = None) + + +def create_and_run_job(step_args : Args) -> bool: + job_id = job_helper.suggest_job_id('ui') + + for key in job_store.get_job_keys(): + state_manager.sync_item(key) #type:ignore[arg-type] + + return job_manager.create_job(job_id) and job_manager.add_step(job_id, step_args) and job_manager.submit_job(job_id) and job_runner.run_job(job_id, process_step) + + +def stop() -> Tuple[gradio.Button, gradio.Button]: + process_manager.stop() + return gradio.Button(visible = True), gradio.Button(visible = False) + + +def clear() -> Tuple[gradio.Image, gradio.Video]: + while process_manager.is_processing(): + sleep(0.5) + if state_manager.get_item('target_path'): + clear_temp_directory(state_manager.get_item('target_path')) + return gradio.Image(value = None), gradio.Video(value = None) diff --git a/facefusion/uis/components/job_list.py b/facefusion/uis/components/job_list.py new file mode 100644 index 0000000000000000000000000000000000000000..bb954cf0a7fb4134be6c9fba35a06f56b2bb1d10 --- /dev/null +++ b/facefusion/uis/components/job_list.py @@ -0,0 +1,50 @@ +from typing import List, Optional + +import gradio + +import facefusion.choices +from facefusion import state_manager, wording +from facefusion.common_helper import get_first +from facefusion.jobs import job_list, job_manager +from facefusion.types import JobStatus +from facefusion.uis.core import get_ui_component + +JOB_LIST_JOBS_DATAFRAME : Optional[gradio.Dataframe] = None +JOB_LIST_REFRESH_BUTTON : Optional[gradio.Button] = None + + +def render() -> None: + global JOB_LIST_JOBS_DATAFRAME + global JOB_LIST_REFRESH_BUTTON + + if job_manager.init_jobs(state_manager.get_item('jobs_path')): + job_status = get_first(facefusion.choices.job_statuses) + job_headers, job_contents = job_list.compose_job_list(job_status) + + JOB_LIST_JOBS_DATAFRAME = gradio.Dataframe( + headers = job_headers, + value = job_contents, + datatype = [ 'str', 'number', 'date', 'date', 'str' ], + show_label = False + ) + JOB_LIST_REFRESH_BUTTON = gradio.Button( + value = wording.get('uis.refresh_button'), + variant = 'primary', + size = 'sm' + ) + + +def listen() -> None: + job_list_job_status_checkbox_group = get_ui_component('job_list_job_status_checkbox_group') + if 
job_list_job_status_checkbox_group: + job_list_job_status_checkbox_group.change(update_job_dataframe, inputs = job_list_job_status_checkbox_group, outputs = JOB_LIST_JOBS_DATAFRAME) + JOB_LIST_REFRESH_BUTTON.click(update_job_dataframe, inputs = job_list_job_status_checkbox_group, outputs = JOB_LIST_JOBS_DATAFRAME) + + +def update_job_dataframe(job_statuses : List[JobStatus]) -> gradio.Dataframe: + all_job_contents = [] + + for job_status in job_statuses: + _, job_contents = job_list.compose_job_list(job_status) + all_job_contents.extend(job_contents) + return gradio.Dataframe(value = all_job_contents) diff --git a/facefusion/uis/components/job_list_options.py b/facefusion/uis/components/job_list_options.py new file mode 100644 index 0000000000000000000000000000000000000000..eae763eba1fd11ee659b7a93a3941b2eae4f90e4 --- /dev/null +++ b/facefusion/uis/components/job_list_options.py @@ -0,0 +1,35 @@ +from typing import List, Optional + +import gradio + +import facefusion.choices +from facefusion import state_manager, wording +from facefusion.common_helper import get_first +from facefusion.jobs import job_manager +from facefusion.types import JobStatus +from facefusion.uis.core import register_ui_component + +JOB_LIST_JOB_STATUS_CHECKBOX_GROUP : Optional[gradio.CheckboxGroup] = None + + +def render() -> None: + global JOB_LIST_JOB_STATUS_CHECKBOX_GROUP + + if job_manager.init_jobs(state_manager.get_item('jobs_path')): + job_status = get_first(facefusion.choices.job_statuses) + + JOB_LIST_JOB_STATUS_CHECKBOX_GROUP = gradio.CheckboxGroup( + label = wording.get('uis.job_list_status_checkbox_group'), + choices = facefusion.choices.job_statuses, + value = job_status + ) + register_ui_component('job_list_job_status_checkbox_group', JOB_LIST_JOB_STATUS_CHECKBOX_GROUP) + + +def listen() -> None: + JOB_LIST_JOB_STATUS_CHECKBOX_GROUP.change(update_job_status_checkbox_group, inputs = JOB_LIST_JOB_STATUS_CHECKBOX_GROUP, outputs = JOB_LIST_JOB_STATUS_CHECKBOX_GROUP) + + +def update_job_status_checkbox_group(job_statuses : List[JobStatus]) -> gradio.CheckboxGroup: + job_statuses = job_statuses or facefusion.choices.job_statuses + return gradio.CheckboxGroup(value = job_statuses) diff --git a/facefusion/uis/components/job_manager.py b/facefusion/uis/components/job_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..618af955c84479ed0c3e4f3a42dd8f2f2fd4e277 --- /dev/null +++ b/facefusion/uis/components/job_manager.py @@ -0,0 +1,194 @@ +from typing import List, Optional, Tuple + +import gradio + +from facefusion import logger, state_manager, wording +from facefusion.args import collect_step_args +from facefusion.common_helper import get_first, get_last +from facefusion.filesystem import is_directory +from facefusion.jobs import job_manager +from facefusion.types import UiWorkflow +from facefusion.uis import choices as uis_choices +from facefusion.uis.core import get_ui_component +from facefusion.uis.types import JobManagerAction +from facefusion.uis.ui_helper import convert_int_none, convert_str_none, suggest_output_path + +JOB_MANAGER_WRAPPER : Optional[gradio.Column] = None +JOB_MANAGER_JOB_ACTION_DROPDOWN : Optional[gradio.Dropdown] = None +JOB_MANAGER_JOB_ID_TEXTBOX : Optional[gradio.Textbox] = None +JOB_MANAGER_JOB_ID_DROPDOWN : Optional[gradio.Dropdown] = None +JOB_MANAGER_STEP_INDEX_DROPDOWN : Optional[gradio.Dropdown] = None +JOB_MANAGER_APPLY_BUTTON : Optional[gradio.Button] = None + + +def render() -> None: + global JOB_MANAGER_WRAPPER + global 
JOB_MANAGER_JOB_ACTION_DROPDOWN + global JOB_MANAGER_JOB_ID_TEXTBOX + global JOB_MANAGER_JOB_ID_DROPDOWN + global JOB_MANAGER_STEP_INDEX_DROPDOWN + global JOB_MANAGER_APPLY_BUTTON + + if job_manager.init_jobs(state_manager.get_item('jobs_path')): + is_job_manager = state_manager.get_item('ui_workflow') == 'job_manager' + drafted_job_ids = job_manager.find_job_ids('drafted') or [ 'none' ] + + with gradio.Column(visible = is_job_manager) as JOB_MANAGER_WRAPPER: + JOB_MANAGER_JOB_ACTION_DROPDOWN = gradio.Dropdown( + label = wording.get('uis.job_manager_job_action_dropdown'), + choices = uis_choices.job_manager_actions, + value = get_first(uis_choices.job_manager_actions) + ) + JOB_MANAGER_JOB_ID_TEXTBOX = gradio.Textbox( + label = wording.get('uis.job_manager_job_id_dropdown'), + max_lines = 1, + interactive = True + ) + JOB_MANAGER_JOB_ID_DROPDOWN = gradio.Dropdown( + label = wording.get('uis.job_manager_job_id_dropdown'), + choices = drafted_job_ids, + value = get_last(drafted_job_ids), + interactive = True, + visible = False + ) + JOB_MANAGER_STEP_INDEX_DROPDOWN = gradio.Dropdown( + label = wording.get('uis.job_manager_step_index_dropdown'), + choices = [ 'none' ], + value = 'none', + interactive = True, + visible = False + ) + JOB_MANAGER_APPLY_BUTTON = gradio.Button( + value = wording.get('uis.apply_button'), + variant = 'primary', + size = 'sm' + ) + + +def listen() -> None: + JOB_MANAGER_JOB_ACTION_DROPDOWN.change(update, inputs = [ JOB_MANAGER_JOB_ACTION_DROPDOWN, JOB_MANAGER_JOB_ID_DROPDOWN ], outputs = [ JOB_MANAGER_JOB_ID_TEXTBOX, JOB_MANAGER_JOB_ID_DROPDOWN, JOB_MANAGER_STEP_INDEX_DROPDOWN ]) + JOB_MANAGER_JOB_ID_DROPDOWN.change(update_step_index, inputs = JOB_MANAGER_JOB_ID_DROPDOWN, outputs = JOB_MANAGER_STEP_INDEX_DROPDOWN) + JOB_MANAGER_APPLY_BUTTON.click(apply, inputs = [ JOB_MANAGER_JOB_ACTION_DROPDOWN, JOB_MANAGER_JOB_ID_TEXTBOX, JOB_MANAGER_JOB_ID_DROPDOWN, JOB_MANAGER_STEP_INDEX_DROPDOWN ], outputs = [ JOB_MANAGER_JOB_ACTION_DROPDOWN, JOB_MANAGER_JOB_ID_TEXTBOX, JOB_MANAGER_JOB_ID_DROPDOWN, JOB_MANAGER_STEP_INDEX_DROPDOWN ]) + + ui_workflow_dropdown = get_ui_component('ui_workflow_dropdown') + if ui_workflow_dropdown: + ui_workflow_dropdown.change(remote_update, inputs = ui_workflow_dropdown, outputs = [ JOB_MANAGER_WRAPPER, JOB_MANAGER_JOB_ACTION_DROPDOWN, JOB_MANAGER_JOB_ID_TEXTBOX, JOB_MANAGER_JOB_ID_DROPDOWN, JOB_MANAGER_STEP_INDEX_DROPDOWN ]) + + +def remote_update(ui_workflow : UiWorkflow) -> Tuple[gradio.Row, gradio.Dropdown, gradio.Textbox, gradio.Dropdown, gradio.Dropdown]: + is_job_manager = ui_workflow == 'job_manager' + return gradio.Row(visible = is_job_manager), gradio.Dropdown(value = get_first(uis_choices.job_manager_actions)), gradio.Textbox(value = None, visible = True), gradio.Dropdown(visible = False), gradio.Dropdown(visible = False) + + +def apply(job_action : JobManagerAction, created_job_id : str, selected_job_id : str, selected_step_index : int) -> Tuple[gradio.Dropdown, gradio.Textbox, gradio.Dropdown, gradio.Dropdown]: + created_job_id = convert_str_none(created_job_id) + selected_job_id = convert_str_none(selected_job_id) + selected_step_index = convert_int_none(selected_step_index) + step_args = collect_step_args() + output_path = step_args.get('output_path') + + if is_directory(step_args.get('output_path')): + step_args['output_path'] = suggest_output_path(step_args.get('output_path'), state_manager.get_item('target_path')) + + if job_action == 'job-create': + if created_job_id and job_manager.create_job(created_job_id): + updated_job_ids = 
job_manager.find_job_ids('drafted') or [ 'none' ] + + logger.info(wording.get('job_created').format(job_id = created_job_id), __name__) + return gradio.Dropdown(value = 'job-add-step'), gradio.Textbox(visible = False), gradio.Dropdown(value = created_job_id, choices = updated_job_ids, visible = True), gradio.Dropdown() + else: + logger.error(wording.get('job_not_created').format(job_id = created_job_id), __name__) + + if job_action == 'job-submit': + if selected_job_id and job_manager.submit_job(selected_job_id): + updated_job_ids = job_manager.find_job_ids('drafted') or [ 'none' ] + + logger.info(wording.get('job_submitted').format(job_id = selected_job_id), __name__) + return gradio.Dropdown(), gradio.Textbox(), gradio.Dropdown(value = get_last(updated_job_ids), choices = updated_job_ids, visible = True), gradio.Dropdown() + else: + logger.error(wording.get('job_not_submitted').format(job_id = selected_job_id), __name__) + + if job_action == 'job-delete': + if selected_job_id and job_manager.delete_job(selected_job_id): + updated_job_ids = job_manager.find_job_ids('drafted') + job_manager.find_job_ids('queued') + job_manager.find_job_ids('failed') + job_manager.find_job_ids('completed') or [ 'none' ] + + logger.info(wording.get('job_deleted').format(job_id = selected_job_id), __name__) + return gradio.Dropdown(), gradio.Textbox(), gradio.Dropdown(value = get_last(updated_job_ids), choices = updated_job_ids, visible = True), gradio.Dropdown() + else: + logger.error(wording.get('job_not_deleted').format(job_id = selected_job_id), __name__) + + if job_action == 'job-add-step': + if selected_job_id and job_manager.add_step(selected_job_id, step_args): + state_manager.set_item('output_path', output_path) + logger.info(wording.get('job_step_added').format(job_id = selected_job_id), __name__) + return gradio.Dropdown(), gradio.Textbox(), gradio.Dropdown(visible = True), gradio.Dropdown(visible = False) + else: + state_manager.set_item('output_path', output_path) + logger.error(wording.get('job_step_not_added').format(job_id = selected_job_id), __name__) + + if job_action == 'job-remix-step': + if selected_job_id and job_manager.has_step(selected_job_id, selected_step_index) and job_manager.remix_step(selected_job_id, selected_step_index, step_args): + updated_step_choices = get_step_choices(selected_job_id) or [ 'none' ] #type:ignore[list-item] + + state_manager.set_item('output_path', output_path) + logger.info(wording.get('job_remix_step_added').format(job_id = selected_job_id, step_index = selected_step_index), __name__) + return gradio.Dropdown(), gradio.Textbox(), gradio.Dropdown(visible = True), gradio.Dropdown(value = get_last(updated_step_choices), choices = updated_step_choices, visible = True) + else: + state_manager.set_item('output_path', output_path) + logger.error(wording.get('job_remix_step_not_added').format(job_id = selected_job_id, step_index = selected_step_index), __name__) + + if job_action == 'job-insert-step': + if selected_job_id and job_manager.has_step(selected_job_id, selected_step_index) and job_manager.insert_step(selected_job_id, selected_step_index, step_args): + updated_step_choices = get_step_choices(selected_job_id) or [ 'none' ] #type:ignore[list-item] + + state_manager.set_item('output_path', output_path) + logger.info(wording.get('job_step_inserted').format(job_id = selected_job_id, step_index = selected_step_index), __name__) + return gradio.Dropdown(), gradio.Textbox(), gradio.Dropdown(visible = True), gradio.Dropdown(value = 
get_last(updated_step_choices), choices = updated_step_choices, visible = True) + else: + state_manager.set_item('output_path', output_path) + logger.error(wording.get('job_step_not_inserted').format(job_id = selected_job_id, step_index = selected_step_index), __name__) + + if job_action == 'job-remove-step': + if selected_job_id and job_manager.has_step(selected_job_id, selected_step_index) and job_manager.remove_step(selected_job_id, selected_step_index): + updated_step_choices = get_step_choices(selected_job_id) or [ 'none' ] #type:ignore[list-item] + + logger.info(wording.get('job_step_removed').format(job_id = selected_job_id, step_index = selected_step_index), __name__) + return gradio.Dropdown(), gradio.Textbox(), gradio.Dropdown(visible = True), gradio.Dropdown(value = get_last(updated_step_choices), choices = updated_step_choices, visible = True) + else: + logger.error(wording.get('job_step_not_removed').format(job_id = selected_job_id, step_index = selected_step_index), __name__) + return gradio.Dropdown(), gradio.Textbox(), gradio.Dropdown(), gradio.Dropdown() + + +def get_step_choices(job_id : str) -> List[int]: + steps = job_manager.get_steps(job_id) + return [ index for index, _ in enumerate(steps) ] + + +def update(job_action : JobManagerAction, selected_job_id : str) -> Tuple[gradio.Textbox, gradio.Dropdown, gradio.Dropdown]: + if job_action == 'job-create': + return gradio.Textbox(value = None, visible = True), gradio.Dropdown(visible = False), gradio.Dropdown(visible = False) + + if job_action == 'job-delete': + updated_job_ids = job_manager.find_job_ids('drafted') + job_manager.find_job_ids('queued') + job_manager.find_job_ids('failed') + job_manager.find_job_ids('completed') or [ 'none' ] + updated_job_id = selected_job_id if selected_job_id in updated_job_ids else get_last(updated_job_ids) + + return gradio.Textbox(visible = False), gradio.Dropdown(value = updated_job_id, choices = updated_job_ids, visible = True), gradio.Dropdown(visible = False) + + if job_action in [ 'job-submit', 'job-add-step' ]: + updated_job_ids = job_manager.find_job_ids('drafted') or [ 'none' ] + updated_job_id = selected_job_id if selected_job_id in updated_job_ids else get_last(updated_job_ids) + + return gradio.Textbox(visible = False), gradio.Dropdown(value = updated_job_id, choices = updated_job_ids, visible = True), gradio.Dropdown(visible = False) + + if job_action in [ 'job-remix-step', 'job-insert-step', 'job-remove-step' ]: + updated_job_ids = job_manager.find_job_ids('drafted') or [ 'none' ] + updated_job_id = selected_job_id if selected_job_id in updated_job_ids else get_last(updated_job_ids) + updated_step_choices = get_step_choices(updated_job_id) or [ 'none' ] #type:ignore[list-item] + + return gradio.Textbox(visible = False), gradio.Dropdown(value = updated_job_id, choices = updated_job_ids, visible = True), gradio.Dropdown(value = get_last(updated_step_choices), choices = updated_step_choices, visible = True) + return gradio.Textbox(visible = False), gradio.Dropdown(visible = False), gradio.Dropdown(visible = False) + + +def update_step_index(job_id : str) -> gradio.Dropdown: + step_choices = get_step_choices(job_id) or [ 'none' ] #type:ignore[list-item] + return gradio.Dropdown(value = get_last(step_choices), choices = step_choices) diff --git a/facefusion/uis/components/job_runner.py b/facefusion/uis/components/job_runner.py new file mode 100644 index 0000000000000000000000000000000000000000..df69eb095d2b9ef57bdfb93c4acf0e46fb396f80 --- /dev/null +++ 
b/facefusion/uis/components/job_runner.py @@ -0,0 +1,142 @@ +from time import sleep +from typing import Optional, Tuple + +import gradio + +from facefusion import logger, process_manager, state_manager, wording +from facefusion.common_helper import get_first, get_last +from facefusion.core import process_step +from facefusion.jobs import job_manager, job_runner, job_store +from facefusion.types import UiWorkflow +from facefusion.uis import choices as uis_choices +from facefusion.uis.core import get_ui_component +from facefusion.uis.types import JobRunnerAction +from facefusion.uis.ui_helper import convert_str_none + +JOB_RUNNER_WRAPPER : Optional[gradio.Column] = None +JOB_RUNNER_JOB_ACTION_DROPDOWN : Optional[gradio.Dropdown] = None +JOB_RUNNER_JOB_ID_DROPDOWN : Optional[gradio.Dropdown] = None +JOB_RUNNER_START_BUTTON : Optional[gradio.Button] = None +JOB_RUNNER_STOP_BUTTON : Optional[gradio.Button] = None + + +def render() -> None: + global JOB_RUNNER_WRAPPER + global JOB_RUNNER_JOB_ACTION_DROPDOWN + global JOB_RUNNER_JOB_ID_DROPDOWN + global JOB_RUNNER_START_BUTTON + global JOB_RUNNER_STOP_BUTTON + + if job_manager.init_jobs(state_manager.get_item('jobs_path')): + is_job_runner = state_manager.get_item('ui_workflow') == 'job_runner' + queued_job_ids = job_manager.find_job_ids('queued') or [ 'none' ] + + with gradio.Column(visible = is_job_runner) as JOB_RUNNER_WRAPPER: + JOB_RUNNER_JOB_ACTION_DROPDOWN = gradio.Dropdown( + label = wording.get('uis.job_runner_job_action_dropdown'), + choices = uis_choices.job_runner_actions, + value = get_first(uis_choices.job_runner_actions) + ) + JOB_RUNNER_JOB_ID_DROPDOWN = gradio.Dropdown( + label = wording.get('uis.job_runner_job_id_dropdown'), + choices = queued_job_ids, + value = get_last(queued_job_ids) + ) + with gradio.Row(): + JOB_RUNNER_START_BUTTON = gradio.Button( + value = wording.get('uis.start_button'), + variant = 'primary', + size = 'sm' + ) + JOB_RUNNER_STOP_BUTTON = gradio.Button( + value = wording.get('uis.stop_button'), + variant = 'primary', + size = 'sm', + visible = False + ) + + +def listen() -> None: + JOB_RUNNER_JOB_ACTION_DROPDOWN.change(update_job_action, inputs = JOB_RUNNER_JOB_ACTION_DROPDOWN, outputs = JOB_RUNNER_JOB_ID_DROPDOWN) + JOB_RUNNER_START_BUTTON.click(start, outputs = [ JOB_RUNNER_START_BUTTON, JOB_RUNNER_STOP_BUTTON ]) + JOB_RUNNER_START_BUTTON.click(run, inputs = [ JOB_RUNNER_JOB_ACTION_DROPDOWN, JOB_RUNNER_JOB_ID_DROPDOWN ], outputs = [ JOB_RUNNER_START_BUTTON, JOB_RUNNER_STOP_BUTTON, JOB_RUNNER_JOB_ID_DROPDOWN ]) + JOB_RUNNER_STOP_BUTTON.click(stop, outputs = [ JOB_RUNNER_START_BUTTON, JOB_RUNNER_STOP_BUTTON ]) + + ui_workflow_dropdown = get_ui_component('ui_workflow_dropdown') + if ui_workflow_dropdown: + ui_workflow_dropdown.change(remote_update, inputs = ui_workflow_dropdown, outputs = [ JOB_RUNNER_WRAPPER, JOB_RUNNER_JOB_ACTION_DROPDOWN, JOB_RUNNER_JOB_ID_DROPDOWN ]) + + +def remote_update(ui_workflow : UiWorkflow) -> Tuple[gradio.Row, gradio.Dropdown, gradio.Dropdown]: + is_job_runner = ui_workflow == 'job_runner' + queued_job_ids = job_manager.find_job_ids('queued') or [ 'none' ] + + return gradio.Row(visible = is_job_runner), gradio.Dropdown(value = get_first(uis_choices.job_runner_actions), choices = uis_choices.job_runner_actions), gradio.Dropdown(value = get_last(queued_job_ids), choices = queued_job_ids) + + +def start() -> Tuple[gradio.Button, gradio.Button]: + while not process_manager.is_processing(): + sleep(0.5) + return gradio.Button(visible = False), gradio.Button(visible = True) + + +def 
run(job_action : JobRunnerAction, job_id : str) -> Tuple[gradio.Button, gradio.Button, gradio.Dropdown]: + job_id = convert_str_none(job_id) + + for key in job_store.get_job_keys(): + state_manager.sync_item(key) #type:ignore[arg-type] + + if job_action == 'job-run': + logger.info(wording.get('running_job').format(job_id = job_id), __name__) + if job_id and job_runner.run_job(job_id, process_step): + logger.info(wording.get('processing_job_succeed').format(job_id = job_id), __name__) + else: + logger.info(wording.get('processing_job_failed').format(job_id = job_id), __name__) + updated_job_ids = job_manager.find_job_ids('queued') or [ 'none' ] + + return gradio.Button(visible = True), gradio.Button(visible = False), gradio.Dropdown(value = get_last(updated_job_ids), choices = updated_job_ids) + + if job_action == 'job-run-all': + logger.info(wording.get('running_jobs'), __name__) + halt_on_error = False + if job_runner.run_jobs(process_step, halt_on_error): + logger.info(wording.get('processing_jobs_succeed'), __name__) + else: + logger.info(wording.get('processing_jobs_failed'), __name__) + + if job_action == 'job-retry': + logger.info(wording.get('retrying_job').format(job_id = job_id), __name__) + if job_id and job_runner.retry_job(job_id, process_step): + logger.info(wording.get('processing_job_succeed').format(job_id = job_id), __name__) + else: + logger.info(wording.get('processing_job_failed').format(job_id = job_id), __name__) + updated_job_ids = job_manager.find_job_ids('failed') or [ 'none' ] + + return gradio.Button(visible = True), gradio.Button(visible = False), gradio.Dropdown(value = get_last(updated_job_ids), choices = updated_job_ids) + + if job_action == 'job-retry-all': + logger.info(wording.get('retrying_jobs'), __name__) + halt_on_error = False + if job_runner.retry_jobs(process_step, halt_on_error): + logger.info(wording.get('processing_jobs_succeed'), __name__) + else: + logger.info(wording.get('processing_jobs_failed'), __name__) + return gradio.Button(visible = True), gradio.Button(visible = False), gradio.Dropdown() + + +def stop() -> Tuple[gradio.Button, gradio.Button]: + process_manager.stop() + return gradio.Button(visible = True), gradio.Button(visible = False) + + +def update_job_action(job_action : JobRunnerAction) -> gradio.Dropdown: + if job_action == 'job-run': + updated_job_ids = job_manager.find_job_ids('queued') or [ 'none' ] + + return gradio.Dropdown(value = get_last(updated_job_ids), choices = updated_job_ids, visible = True) + + if job_action == 'job-retry': + updated_job_ids = job_manager.find_job_ids('failed') or [ 'none' ] + + return gradio.Dropdown(value = get_last(updated_job_ids), choices = updated_job_ids, visible = True) + return gradio.Dropdown(visible = False) diff --git a/facefusion/uis/components/lip_syncer_options.py b/facefusion/uis/components/lip_syncer_options.py new file mode 100644 index 0000000000000000000000000000000000000000..e253ee91d5d149c3962634bea0f87935174aff31 --- /dev/null +++ b/facefusion/uis/components/lip_syncer_options.py @@ -0,0 +1,64 @@ +from typing import List, Optional, Tuple + +import gradio + +from facefusion import state_manager, wording +from facefusion.common_helper import calc_float_step +from facefusion.processors import choices as processors_choices +from facefusion.processors.core import load_processor_module +from facefusion.processors.types import LipSyncerModel +from facefusion.uis.core import get_ui_component, register_ui_component + +LIP_SYNCER_MODEL_DROPDOWN : Optional[gradio.Dropdown] = None 
+LIP_SYNCER_WEIGHT_SLIDER : Optional[gradio.Slider] = None + + +def render() -> None: + global LIP_SYNCER_MODEL_DROPDOWN + global LIP_SYNCER_WEIGHT_SLIDER + + has_lip_syncer = 'lip_syncer' in state_manager.get_item('processors') + LIP_SYNCER_MODEL_DROPDOWN = gradio.Dropdown( + label = wording.get('uis.lip_syncer_model_dropdown'), + choices = processors_choices.lip_syncer_models, + value = state_manager.get_item('lip_syncer_model'), + visible = has_lip_syncer + ) + LIP_SYNCER_WEIGHT_SLIDER = gradio.Slider( + label = wording.get('uis.lip_syncer_weight_slider'), + value = state_manager.get_item('lip_syncer_weight'), + step = calc_float_step(processors_choices.lip_syncer_weight_range), + minimum = processors_choices.lip_syncer_weight_range[0], + maximum = processors_choices.lip_syncer_weight_range[-1], + visible = has_lip_syncer + ) + register_ui_component('lip_syncer_model_dropdown', LIP_SYNCER_MODEL_DROPDOWN) + register_ui_component('lip_syncer_weight_slider', LIP_SYNCER_WEIGHT_SLIDER) + + +def listen() -> None: + LIP_SYNCER_MODEL_DROPDOWN.change(update_lip_syncer_model, inputs = LIP_SYNCER_MODEL_DROPDOWN, outputs = LIP_SYNCER_MODEL_DROPDOWN) + LIP_SYNCER_WEIGHT_SLIDER.release(update_lip_syncer_weight, inputs = LIP_SYNCER_WEIGHT_SLIDER) + + processors_checkbox_group = get_ui_component('processors_checkbox_group') + if processors_checkbox_group: + processors_checkbox_group.change(remote_update, inputs = processors_checkbox_group, outputs = [ LIP_SYNCER_MODEL_DROPDOWN, LIP_SYNCER_WEIGHT_SLIDER ]) + + +def remote_update(processors : List[str]) -> Tuple[gradio.Dropdown, gradio.Slider]: + has_lip_syncer = 'lip_syncer' in processors + return gradio.Dropdown(visible = has_lip_syncer), gradio.Slider(visible = has_lip_syncer) + + +def update_lip_syncer_model(lip_syncer_model : LipSyncerModel) -> gradio.Dropdown: + lip_syncer_module = load_processor_module('lip_syncer') + lip_syncer_module.clear_inference_pool() + state_manager.set_item('lip_syncer_model', lip_syncer_model) + + if lip_syncer_module.pre_check(): + return gradio.Dropdown(value = state_manager.get_item('lip_syncer_model')) + return gradio.Dropdown() + + +def update_lip_syncer_weight(lip_syncer_weight : float) -> None: + state_manager.set_item('lip_syncer_weight', lip_syncer_weight) diff --git a/facefusion/uis/components/memory.py b/facefusion/uis/components/memory.py new file mode 100644 index 0000000000000000000000000000000000000000..81c367a2163463f92601587e3f600711f76fc27f --- /dev/null +++ b/facefusion/uis/components/memory.py @@ -0,0 +1,42 @@ +from typing import Optional + +import gradio + +import facefusion.choices +from facefusion import state_manager, wording +from facefusion.common_helper import calc_int_step +from facefusion.types import VideoMemoryStrategy + +VIDEO_MEMORY_STRATEGY_DROPDOWN : Optional[gradio.Dropdown] = None +SYSTEM_MEMORY_LIMIT_SLIDER : Optional[gradio.Slider] = None + + +def render() -> None: + global VIDEO_MEMORY_STRATEGY_DROPDOWN + global SYSTEM_MEMORY_LIMIT_SLIDER + + VIDEO_MEMORY_STRATEGY_DROPDOWN = gradio.Dropdown( + label = wording.get('uis.video_memory_strategy_dropdown'), + choices = facefusion.choices.video_memory_strategies, + value = state_manager.get_item('video_memory_strategy') + ) + SYSTEM_MEMORY_LIMIT_SLIDER = gradio.Slider( + label = wording.get('uis.system_memory_limit_slider'), + step = calc_int_step(facefusion.choices.system_memory_limit_range), + minimum = facefusion.choices.system_memory_limit_range[0], + maximum = facefusion.choices.system_memory_limit_range[-1], + value = 
state_manager.get_item('system_memory_limit') + ) + + +def listen() -> None: + VIDEO_MEMORY_STRATEGY_DROPDOWN.change(update_video_memory_strategy, inputs = VIDEO_MEMORY_STRATEGY_DROPDOWN) + SYSTEM_MEMORY_LIMIT_SLIDER.release(update_system_memory_limit, inputs = SYSTEM_MEMORY_LIMIT_SLIDER) + + +def update_video_memory_strategy(video_memory_strategy : VideoMemoryStrategy) -> None: + state_manager.set_item('video_memory_strategy', video_memory_strategy) + + +def update_system_memory_limit(system_memory_limit : float) -> None: + state_manager.set_item('system_memory_limit', int(system_memory_limit)) diff --git a/facefusion/uis/components/output.py b/facefusion/uis/components/output.py new file mode 100644 index 0000000000000000000000000000000000000000..84fd08915d3cef41144c876cc64f3011caf55dcf --- /dev/null +++ b/facefusion/uis/components/output.py @@ -0,0 +1,42 @@ +import tempfile +from typing import Optional + +import gradio + +from facefusion import state_manager, wording +from facefusion.uis.core import register_ui_component + +OUTPUT_PATH_TEXTBOX : Optional[gradio.Textbox] = None +OUTPUT_IMAGE : Optional[gradio.Image] = None +OUTPUT_VIDEO : Optional[gradio.Video] = None + + +def render() -> None: + global OUTPUT_PATH_TEXTBOX + global OUTPUT_IMAGE + global OUTPUT_VIDEO + + if not state_manager.get_item('output_path'): + state_manager.set_item('output_path', tempfile.gettempdir()) + OUTPUT_PATH_TEXTBOX = gradio.Textbox( + label = wording.get('uis.output_path_textbox'), + value = state_manager.get_item('output_path'), + max_lines = 1 + ) + OUTPUT_IMAGE = gradio.Image( + label = wording.get('uis.output_image_or_video'), + visible = False + ) + OUTPUT_VIDEO = gradio.Video( + label = wording.get('uis.output_image_or_video') + ) + + +def listen() -> None: + OUTPUT_PATH_TEXTBOX.change(update_output_path, inputs = OUTPUT_PATH_TEXTBOX) + register_ui_component('output_image', OUTPUT_IMAGE) + register_ui_component('output_video', OUTPUT_VIDEO) + + +def update_output_path(output_path : str) -> None: + state_manager.set_item('output_path', output_path) diff --git a/facefusion/uis/components/output_options.py b/facefusion/uis/components/output_options.py new file mode 100644 index 0000000000000000000000000000000000000000..46b875da3bc08c75581f87b7bd00bc30876e08af --- /dev/null +++ b/facefusion/uis/components/output_options.py @@ -0,0 +1,193 @@ +from typing import Optional, Tuple + +import gradio + +import facefusion.choices +from facefusion import state_manager, wording +from facefusion.common_helper import calc_int_step +from facefusion.ffmpeg import get_available_encoder_set +from facefusion.filesystem import is_image, is_video +from facefusion.types import AudioEncoder, Fps, VideoEncoder, VideoPreset +from facefusion.uis.core import get_ui_components, register_ui_component +from facefusion.vision import create_image_resolutions, create_video_resolutions, detect_image_resolution, detect_video_fps, detect_video_resolution, pack_resolution + +OUTPUT_IMAGE_QUALITY_SLIDER : Optional[gradio.Slider] = None +OUTPUT_IMAGE_RESOLUTION_DROPDOWN : Optional[gradio.Dropdown] = None +OUTPUT_AUDIO_ENCODER_DROPDOWN : Optional[gradio.Dropdown] = None +OUTPUT_AUDIO_QUALITY_SLIDER : Optional[gradio.Slider] = None +OUTPUT_AUDIO_VOLUME_SLIDER : Optional[gradio.Slider] = None +OUTPUT_VIDEO_ENCODER_DROPDOWN : Optional[gradio.Dropdown] = None +OUTPUT_VIDEO_PRESET_DROPDOWN : Optional[gradio.Dropdown] = None +OUTPUT_VIDEO_RESOLUTION_DROPDOWN : Optional[gradio.Dropdown] = None +OUTPUT_VIDEO_QUALITY_SLIDER : Optional[gradio.Slider] 
= None +OUTPUT_VIDEO_FPS_SLIDER : Optional[gradio.Slider] = None + + +def render() -> None: + global OUTPUT_IMAGE_QUALITY_SLIDER + global OUTPUT_IMAGE_RESOLUTION_DROPDOWN + global OUTPUT_AUDIO_ENCODER_DROPDOWN + global OUTPUT_AUDIO_QUALITY_SLIDER + global OUTPUT_AUDIO_VOLUME_SLIDER + global OUTPUT_VIDEO_ENCODER_DROPDOWN + global OUTPUT_VIDEO_PRESET_DROPDOWN + global OUTPUT_VIDEO_RESOLUTION_DROPDOWN + global OUTPUT_VIDEO_QUALITY_SLIDER + global OUTPUT_VIDEO_FPS_SLIDER + + output_image_resolutions = [] + output_video_resolutions = [] + available_encoder_set = get_available_encoder_set() + if is_image(state_manager.get_item('target_path')): + output_image_resolution = detect_image_resolution(state_manager.get_item('target_path')) + output_image_resolutions = create_image_resolutions(output_image_resolution) + if is_video(state_manager.get_item('target_path')): + output_video_resolution = detect_video_resolution(state_manager.get_item('target_path')) + output_video_resolutions = create_video_resolutions(output_video_resolution) + OUTPUT_IMAGE_QUALITY_SLIDER = gradio.Slider( + label = wording.get('uis.output_image_quality_slider'), + value = state_manager.get_item('output_image_quality'), + step = calc_int_step(facefusion.choices.output_image_quality_range), + minimum = facefusion.choices.output_image_quality_range[0], + maximum = facefusion.choices.output_image_quality_range[-1], + visible = is_image(state_manager.get_item('target_path')) + ) + OUTPUT_IMAGE_RESOLUTION_DROPDOWN = gradio.Dropdown( + label = wording.get('uis.output_image_resolution_dropdown'), + choices = output_image_resolutions, + value = state_manager.get_item('output_image_resolution'), + visible = is_image(state_manager.get_item('target_path')) + ) + OUTPUT_AUDIO_ENCODER_DROPDOWN = gradio.Dropdown( + label = wording.get('uis.output_audio_encoder_dropdown'), + choices = available_encoder_set.get('audio'), + value = state_manager.get_item('output_audio_encoder'), + visible = is_video(state_manager.get_item('target_path')) + ) + OUTPUT_AUDIO_QUALITY_SLIDER = gradio.Slider( + label = wording.get('uis.output_audio_quality_slider'), + value = state_manager.get_item('output_audio_quality'), + step = calc_int_step(facefusion.choices.output_audio_quality_range), + minimum = facefusion.choices.output_audio_quality_range[0], + maximum = facefusion.choices.output_audio_quality_range[-1], + visible = is_video(state_manager.get_item('target_path')) + ) + OUTPUT_AUDIO_VOLUME_SLIDER = gradio.Slider( + label = wording.get('uis.output_audio_volume_slider'), + value = state_manager.get_item('output_audio_volume'), + step = calc_int_step(facefusion.choices.output_audio_volume_range), + minimum = facefusion.choices.output_audio_volume_range[0], + maximum = facefusion.choices.output_audio_volume_range[-1], + visible = is_video(state_manager.get_item('target_path')) + ) + OUTPUT_VIDEO_ENCODER_DROPDOWN = gradio.Dropdown( + label = wording.get('uis.output_video_encoder_dropdown'), + choices = available_encoder_set.get('video'), + value = state_manager.get_item('output_video_encoder'), + visible = is_video(state_manager.get_item('target_path')) + ) + OUTPUT_VIDEO_PRESET_DROPDOWN = gradio.Dropdown( + label = wording.get('uis.output_video_preset_dropdown'), + choices = facefusion.choices.output_video_presets, + value = state_manager.get_item('output_video_preset'), + visible = is_video(state_manager.get_item('target_path')) + ) + OUTPUT_VIDEO_QUALITY_SLIDER = gradio.Slider( + label = wording.get('uis.output_video_quality_slider'), + value = 
state_manager.get_item('output_video_quality'), + step = calc_int_step(facefusion.choices.output_video_quality_range), + minimum = facefusion.choices.output_video_quality_range[0], + maximum = facefusion.choices.output_video_quality_range[-1], + visible = is_video(state_manager.get_item('target_path')) + ) + OUTPUT_VIDEO_RESOLUTION_DROPDOWN = gradio.Dropdown( + label = wording.get('uis.output_video_resolution_dropdown'), + choices = output_video_resolutions, + value = state_manager.get_item('output_video_resolution'), + visible = is_video(state_manager.get_item('target_path')) + ) + OUTPUT_VIDEO_FPS_SLIDER = gradio.Slider( + label = wording.get('uis.output_video_fps_slider'), + value = state_manager.get_item('output_video_fps'), + step = 0.01, + minimum = 1, + maximum = 60, + visible = is_video(state_manager.get_item('target_path')) + ) + register_ui_component('output_video_fps_slider', OUTPUT_VIDEO_FPS_SLIDER) + + +def listen() -> None: + OUTPUT_IMAGE_QUALITY_SLIDER.release(update_output_image_quality, inputs = OUTPUT_IMAGE_QUALITY_SLIDER) + OUTPUT_IMAGE_RESOLUTION_DROPDOWN.change(update_output_image_resolution, inputs = OUTPUT_IMAGE_RESOLUTION_DROPDOWN) + OUTPUT_AUDIO_ENCODER_DROPDOWN.change(update_output_audio_encoder, inputs = OUTPUT_AUDIO_ENCODER_DROPDOWN) + OUTPUT_AUDIO_QUALITY_SLIDER.release(update_output_audio_quality, inputs = OUTPUT_AUDIO_QUALITY_SLIDER) + OUTPUT_AUDIO_VOLUME_SLIDER.release(update_output_audio_volume, inputs = OUTPUT_AUDIO_VOLUME_SLIDER) + OUTPUT_VIDEO_ENCODER_DROPDOWN.change(update_output_video_encoder, inputs = OUTPUT_VIDEO_ENCODER_DROPDOWN) + OUTPUT_VIDEO_PRESET_DROPDOWN.change(update_output_video_preset, inputs = OUTPUT_VIDEO_PRESET_DROPDOWN) + OUTPUT_VIDEO_QUALITY_SLIDER.release(update_output_video_quality, inputs = OUTPUT_VIDEO_QUALITY_SLIDER) + OUTPUT_VIDEO_RESOLUTION_DROPDOWN.change(update_output_video_resolution, inputs = OUTPUT_VIDEO_RESOLUTION_DROPDOWN) + OUTPUT_VIDEO_FPS_SLIDER.release(update_output_video_fps, inputs = OUTPUT_VIDEO_FPS_SLIDER) + + for ui_component in get_ui_components( + [ + 'target_image', + 'target_video' + ]): + for method in [ 'change', 'clear' ]: + getattr(ui_component, method)(remote_update, outputs = [ OUTPUT_IMAGE_QUALITY_SLIDER, OUTPUT_IMAGE_RESOLUTION_DROPDOWN, OUTPUT_AUDIO_ENCODER_DROPDOWN, OUTPUT_AUDIO_QUALITY_SLIDER, OUTPUT_AUDIO_VOLUME_SLIDER, OUTPUT_VIDEO_ENCODER_DROPDOWN, OUTPUT_VIDEO_PRESET_DROPDOWN, OUTPUT_VIDEO_QUALITY_SLIDER, OUTPUT_VIDEO_RESOLUTION_DROPDOWN, OUTPUT_VIDEO_FPS_SLIDER ]) + + +def remote_update() -> Tuple[gradio.Slider, gradio.Dropdown, gradio.Dropdown, gradio.Slider, gradio.Slider, gradio.Dropdown, gradio.Dropdown, gradio.Slider, gradio.Dropdown, gradio.Slider]: + if is_image(state_manager.get_item('target_path')): + output_image_resolution = detect_image_resolution(state_manager.get_item('target_path')) + output_image_resolutions = create_image_resolutions(output_image_resolution) + state_manager.set_item('output_image_resolution', pack_resolution(output_image_resolution)) + return gradio.Slider(visible = True), gradio.Dropdown(value = state_manager.get_item('output_image_resolution'), choices = output_image_resolutions, visible = True), gradio.Dropdown(visible = False), gradio.Slider(visible = False), gradio.Slider(visible = False), gradio.Dropdown(visible = False), gradio.Dropdown(visible = False), gradio.Slider(visible = False), gradio.Dropdown(visible = False), gradio.Slider(visible = False) + if is_video(state_manager.get_item('target_path')): + output_video_resolution = 
detect_video_resolution(state_manager.get_item('target_path')) + output_video_resolutions = create_video_resolutions(output_video_resolution) + state_manager.set_item('output_video_resolution', pack_resolution(output_video_resolution)) + state_manager.set_item('output_video_fps', detect_video_fps(state_manager.get_item('target_path'))) + return gradio.Slider(visible = False), gradio.Dropdown(visible = False), gradio.Dropdown(visible = True), gradio.Slider(visible = True), gradio.Slider(visible = True), gradio.Dropdown(visible = True), gradio.Dropdown(visible = True), gradio.Slider(visible = True), gradio.Dropdown(value = state_manager.get_item('output_video_resolution'), choices = output_video_resolutions, visible = True), gradio.Slider(value = state_manager.get_item('output_video_fps'), visible = True) + return gradio.Slider(visible = False), gradio.Dropdown(visible = False), gradio.Dropdown(visible = False), gradio.Slider(visible = False), gradio.Slider(visible = False), gradio.Dropdown(visible = False), gradio.Dropdown(visible = False), gradio.Slider(visible = False), gradio.Dropdown(visible = False), gradio.Slider(visible = False) + + +def update_output_image_quality(output_image_quality : float) -> None: + state_manager.set_item('output_image_quality', int(output_image_quality)) + + +def update_output_image_resolution(output_image_resolution : str) -> None: + state_manager.set_item('output_image_resolution', output_image_resolution) + + +def update_output_audio_encoder(output_audio_encoder : AudioEncoder) -> None: + state_manager.set_item('output_audio_encoder', output_audio_encoder) + + +def update_output_audio_quality(output_audio_quality : float) -> None: + state_manager.set_item('output_audio_quality', int(output_audio_quality)) + + +def update_output_audio_volume(output_audio_volume: float) -> None: + state_manager.set_item('output_audio_volume', int(output_audio_volume)) + + +def update_output_video_encoder(output_video_encoder : VideoEncoder) -> None: + state_manager.set_item('output_video_encoder', output_video_encoder) + + +def update_output_video_preset(output_video_preset : VideoPreset) -> None: + state_manager.set_item('output_video_preset', output_video_preset) + + +def update_output_video_quality(output_video_quality : float) -> None: + state_manager.set_item('output_video_quality', int(output_video_quality)) + + +def update_output_video_resolution(output_video_resolution : str) -> None: + state_manager.set_item('output_video_resolution', output_video_resolution) + + +def update_output_video_fps(output_video_fps : Fps) -> None: + state_manager.set_item('output_video_fps', output_video_fps) diff --git a/facefusion/uis/components/preview.py b/facefusion/uis/components/preview.py new file mode 100644 index 0000000000000000000000000000000000000000..d6283eff3967db1c5a2d985d1d11f6ac354fefd0 --- /dev/null +++ b/facefusion/uis/components/preview.py @@ -0,0 +1,259 @@ +from time import sleep +from typing import Optional + +import cv2 +import gradio +import numpy + +from facefusion import logger, process_manager, state_manager, wording +from facefusion.audio import create_empty_audio_frame, get_audio_frame +from facefusion.common_helper import get_first +from facefusion.content_analyser import analyse_frame +from facefusion.core import conditional_append_reference_faces +from facefusion.face_analyser import get_average_face, get_many_faces +from facefusion.face_selector import sort_faces_by_order +from facefusion.face_store import clear_reference_faces, clear_static_faces, 
get_reference_faces +from facefusion.filesystem import filter_audio_paths, is_image, is_video +from facefusion.processors.core import get_processors_modules +from facefusion.types import AudioFrame, Face, FaceSet, VisionFrame +from facefusion.uis.core import get_ui_component, get_ui_components, register_ui_component +from facefusion.uis.types import ComponentOptions +from facefusion.vision import count_video_frame_total, detect_frame_orientation, normalize_frame_color, read_static_image, read_static_images, read_video_frame, restrict_frame + +PREVIEW_IMAGE : Optional[gradio.Image] = None +PREVIEW_FRAME_SLIDER : Optional[gradio.Slider] = None + + +def render() -> None: + global PREVIEW_IMAGE + global PREVIEW_FRAME_SLIDER + + preview_image_options : ComponentOptions =\ + { + 'label': wording.get('uis.preview_image') + } + preview_frame_slider_options : ComponentOptions =\ + { + 'label': wording.get('uis.preview_frame_slider'), + 'step': 1, + 'minimum': 0, + 'maximum': 100, + 'visible': False + } + conditional_append_reference_faces() + reference_faces = get_reference_faces() if 'reference' in state_manager.get_item('face_selector_mode') else None + source_frames = read_static_images(state_manager.get_item('source_paths')) + source_faces = get_many_faces(source_frames) + source_face = get_average_face(source_faces) + source_audio_path = get_first(filter_audio_paths(state_manager.get_item('source_paths'))) + source_audio_frame = create_empty_audio_frame() + + if source_audio_path and state_manager.get_item('output_video_fps') and state_manager.get_item('reference_frame_number'): + temp_audio_frame = get_audio_frame(source_audio_path, state_manager.get_item('output_video_fps'), state_manager.get_item('reference_frame_number')) + if numpy.any(temp_audio_frame): + source_audio_frame = temp_audio_frame + + if is_image(state_manager.get_item('target_path')): + target_vision_frame = read_static_image(state_manager.get_item('target_path')) + preview_vision_frame = process_preview_frame(reference_faces, source_face, source_audio_frame, target_vision_frame) + preview_image_options['value'] = normalize_frame_color(preview_vision_frame) + preview_image_options['elem_classes'] = [ 'image-preview', 'is-' + detect_frame_orientation(preview_vision_frame) ] + + if is_video(state_manager.get_item('target_path')): + temp_vision_frame = read_video_frame(state_manager.get_item('target_path'), state_manager.get_item('reference_frame_number')) + preview_vision_frame = process_preview_frame(reference_faces, source_face, source_audio_frame, temp_vision_frame) + preview_image_options['value'] = normalize_frame_color(preview_vision_frame) + preview_image_options['elem_classes'] = [ 'image-preview', 'is-' + detect_frame_orientation(preview_vision_frame) ] + preview_image_options['visible'] = True + preview_frame_slider_options['value'] = state_manager.get_item('reference_frame_number') + preview_frame_slider_options['maximum'] = count_video_frame_total(state_manager.get_item('target_path')) + preview_frame_slider_options['visible'] = True + PREVIEW_IMAGE = gradio.Image(**preview_image_options) + PREVIEW_FRAME_SLIDER = gradio.Slider(**preview_frame_slider_options) + register_ui_component('preview_frame_slider', PREVIEW_FRAME_SLIDER) + + +def listen() -> None: + PREVIEW_FRAME_SLIDER.release(update_preview_image, inputs = PREVIEW_FRAME_SLIDER, outputs = PREVIEW_IMAGE, show_progress = 'hidden') + PREVIEW_FRAME_SLIDER.change(slide_preview_image, inputs = PREVIEW_FRAME_SLIDER, outputs = PREVIEW_IMAGE, show_progress = 
'hidden', trigger_mode = 'once') + + reference_face_position_gallery = get_ui_component('reference_face_position_gallery') + if reference_face_position_gallery: + reference_face_position_gallery.select(clear_and_update_preview_image, inputs = PREVIEW_FRAME_SLIDER, outputs = PREVIEW_IMAGE) + + for ui_component in get_ui_components( + [ + 'source_audio', + 'source_image', + 'target_image', + 'target_video' + ]): + for method in [ 'change', 'clear' ]: + getattr(ui_component, method)(update_preview_image, inputs = PREVIEW_FRAME_SLIDER, outputs = PREVIEW_IMAGE) + + for ui_component in get_ui_components( + [ + 'target_image', + 'target_video' + ]): + for method in [ 'change', 'clear' ]: + getattr(ui_component, method)(update_preview_frame_slider, outputs = PREVIEW_FRAME_SLIDER) + + for ui_component in get_ui_components( + [ + 'face_debugger_items_checkbox_group', + 'frame_colorizer_size_dropdown', + 'face_mask_types_checkbox_group', + 'face_mask_areas_checkbox_group', + 'face_mask_regions_checkbox_group' + ]): + ui_component.change(update_preview_image, inputs = PREVIEW_FRAME_SLIDER, outputs = PREVIEW_IMAGE) + + for ui_component in get_ui_components( + [ + 'age_modifier_direction_slider', + 'deep_swapper_morph_slider', + 'expression_restorer_factor_slider', + 'face_editor_eyebrow_direction_slider', + 'face_editor_eye_gaze_horizontal_slider', + 'face_editor_eye_gaze_vertical_slider', + 'face_editor_eye_open_ratio_slider', + 'face_editor_lip_open_ratio_slider', + 'face_editor_mouth_grim_slider', + 'face_editor_mouth_pout_slider', + 'face_editor_mouth_purse_slider', + 'face_editor_mouth_smile_slider', + 'face_editor_mouth_position_horizontal_slider', + 'face_editor_mouth_position_vertical_slider', + 'face_editor_head_pitch_slider', + 'face_editor_head_yaw_slider', + 'face_editor_head_roll_slider', + 'face_enhancer_blend_slider', + 'face_enhancer_weight_slider', + 'frame_colorizer_blend_slider', + 'frame_enhancer_blend_slider', + 'lip_syncer_weight_slider', + 'reference_face_distance_slider', + 'face_selector_age_range_slider', + 'face_mask_blur_slider', + 'face_mask_padding_top_slider', + 'face_mask_padding_bottom_slider', + 'face_mask_padding_left_slider', + 'face_mask_padding_right_slider', + 'output_video_fps_slider' + ]): + ui_component.release(update_preview_image, inputs = PREVIEW_FRAME_SLIDER, outputs = PREVIEW_IMAGE) + + for ui_component in get_ui_components( + [ + 'age_modifier_model_dropdown', + 'deep_swapper_model_dropdown', + 'expression_restorer_model_dropdown', + 'processors_checkbox_group', + 'face_editor_model_dropdown', + 'face_enhancer_model_dropdown', + 'face_swapper_model_dropdown', + 'face_swapper_pixel_boost_dropdown', + 'frame_colorizer_model_dropdown', + 'frame_enhancer_model_dropdown', + 'lip_syncer_model_dropdown', + 'face_selector_mode_dropdown', + 'face_selector_order_dropdown', + 'face_selector_gender_dropdown', + 'face_selector_race_dropdown', + 'face_detector_model_dropdown', + 'face_detector_size_dropdown', + 'face_detector_angles_checkbox_group', + 'face_landmarker_model_dropdown', + 'face_occluder_model_dropdown', + 'face_parser_model_dropdown' + ]): + ui_component.change(clear_and_update_preview_image, inputs = PREVIEW_FRAME_SLIDER, outputs = PREVIEW_IMAGE) + + for ui_component in get_ui_components( + [ + 'face_detector_score_slider', + 'face_landmarker_score_slider' + ]): + ui_component.release(clear_and_update_preview_image, inputs = PREVIEW_FRAME_SLIDER, outputs = PREVIEW_IMAGE) + + +def clear_and_update_preview_image(frame_number : int = 0) -> gradio.Image: + 
clear_reference_faces() + clear_static_faces() + return update_preview_image(frame_number) + + +def slide_preview_image(frame_number : int = 0) -> gradio.Image: + if is_video(state_manager.get_item('target_path')): + preview_vision_frame = normalize_frame_color(read_video_frame(state_manager.get_item('target_path'), frame_number)) + preview_vision_frame = restrict_frame(preview_vision_frame, (1024, 1024)) + return gradio.Image(value = preview_vision_frame) + return gradio.Image(value = None) + + +def update_preview_image(frame_number : int = 0) -> gradio.Image: + while process_manager.is_checking(): + sleep(0.5) + conditional_append_reference_faces() + reference_faces = get_reference_faces() if 'reference' in state_manager.get_item('face_selector_mode') else None + source_frames = read_static_images(state_manager.get_item('source_paths')) + source_faces = [] + + for source_frame in source_frames: + temp_faces = get_many_faces([ source_frame ]) + temp_faces = sort_faces_by_order(temp_faces, 'large-small') + if temp_faces: + source_faces.append(get_first(temp_faces)) + source_face = get_average_face(source_faces) + source_audio_path = get_first(filter_audio_paths(state_manager.get_item('source_paths'))) + source_audio_frame = create_empty_audio_frame() + + if source_audio_path and state_manager.get_item('output_video_fps') and state_manager.get_item('reference_frame_number'): + reference_audio_frame_number = state_manager.get_item('reference_frame_number') + if state_manager.get_item('trim_frame_start'): + reference_audio_frame_number -= state_manager.get_item('trim_frame_start') + temp_audio_frame = get_audio_frame(source_audio_path, state_manager.get_item('output_video_fps'), reference_audio_frame_number) + if numpy.any(temp_audio_frame): + source_audio_frame = temp_audio_frame + + if is_image(state_manager.get_item('target_path')): + target_vision_frame = read_static_image(state_manager.get_item('target_path')) + preview_vision_frame = process_preview_frame(reference_faces, source_face, source_audio_frame, target_vision_frame) + preview_vision_frame = normalize_frame_color(preview_vision_frame) + return gradio.Image(value = preview_vision_frame, elem_classes = [ 'image-preview', 'is-' + detect_frame_orientation(preview_vision_frame) ]) + + if is_video(state_manager.get_item('target_path')): + temp_vision_frame = read_video_frame(state_manager.get_item('target_path'), frame_number) + preview_vision_frame = process_preview_frame(reference_faces, source_face, source_audio_frame, temp_vision_frame) + preview_vision_frame = normalize_frame_color(preview_vision_frame) + return gradio.Image(value = preview_vision_frame, elem_classes = [ 'image-preview', 'is-' + detect_frame_orientation(preview_vision_frame) ]) + return gradio.Image(value = None, elem_classes = None) + + +def update_preview_frame_slider() -> gradio.Slider: + if is_video(state_manager.get_item('target_path')): + video_frame_total = count_video_frame_total(state_manager.get_item('target_path')) + return gradio.Slider(maximum = video_frame_total, visible = True) + return gradio.Slider(value = 0, visible = False) + + +def process_preview_frame(reference_faces : FaceSet, source_face : Face, source_audio_frame : AudioFrame, target_vision_frame : VisionFrame) -> VisionFrame: + target_vision_frame = restrict_frame(target_vision_frame, (1024, 1024)) + source_vision_frame = target_vision_frame.copy() + if analyse_frame(target_vision_frame): + return cv2.GaussianBlur(target_vision_frame, (99, 99), 0) + + for processor_module in 
get_processors_modules(state_manager.get_item('processors')): + logger.disable() + if processor_module.pre_process('preview'): + target_vision_frame = processor_module.process_frame( + { + 'reference_faces': reference_faces, + 'source_face': source_face, + 'source_audio_frame': source_audio_frame, + 'source_vision_frame': source_vision_frame, + 'target_vision_frame': target_vision_frame + }) + logger.enable() + return target_vision_frame diff --git a/facefusion/uis/components/processors.py b/facefusion/uis/components/processors.py new file mode 100644 index 0000000000000000000000000000000000000000..f734a0d5c0a5eab5164a667aa697fb273fc07ad6 --- /dev/null +++ b/facefusion/uis/components/processors.py @@ -0,0 +1,49 @@ +from typing import List, Optional + +import gradio + +from facefusion import state_manager, wording +from facefusion.filesystem import get_file_name, resolve_file_paths +from facefusion.processors.core import get_processors_modules +from facefusion.uis.core import register_ui_component + +PROCESSORS_CHECKBOX_GROUP : Optional[gradio.CheckboxGroup] = None + + +def render() -> None: + global PROCESSORS_CHECKBOX_GROUP + + PROCESSORS_CHECKBOX_GROUP = gradio.CheckboxGroup( + label = wording.get('uis.processors_checkbox_group'), + choices = sort_processors(state_manager.get_item('processors')), + value = state_manager.get_item('processors') + ) + register_ui_component('processors_checkbox_group', PROCESSORS_CHECKBOX_GROUP) + + +def listen() -> None: + PROCESSORS_CHECKBOX_GROUP.change(update_processors, inputs = PROCESSORS_CHECKBOX_GROUP, outputs = PROCESSORS_CHECKBOX_GROUP) + + +def update_processors(processors : List[str]) -> gradio.CheckboxGroup: + for processor_module in get_processors_modules(state_manager.get_item('processors')): + if hasattr(processor_module, 'clear_inference_pool'): + processor_module.clear_inference_pool() + + for processor_module in get_processors_modules(processors): + if not processor_module.pre_check(): + return gradio.CheckboxGroup() + + state_manager.set_item('processors', processors) + return gradio.CheckboxGroup(value = state_manager.get_item('processors'), choices = sort_processors(state_manager.get_item('processors'))) + + +def sort_processors(processors : List[str]) -> List[str]: + available_processors = [ get_file_name(file_path) for file_path in resolve_file_paths('facefusion/processors/modules') ] + current_processors = [] + + for processor in processors + available_processors: + if processor in available_processors and processor not in current_processors: + current_processors.append(processor) + + return current_processors diff --git a/facefusion/uis/components/source.py b/facefusion/uis/components/source.py new file mode 100644 index 0000000000000000000000000000000000000000..54ed2f5fec3b53e6d678d9d88887679b58ec4fe9 --- /dev/null +++ b/facefusion/uis/components/source.py @@ -0,0 +1,61 @@ +from typing import List, Optional, Tuple + +import gradio + +from facefusion import state_manager, wording +from facefusion.common_helper import get_first +from facefusion.filesystem import filter_audio_paths, filter_image_paths, has_audio, has_image +from facefusion.uis.core import register_ui_component +from facefusion.uis.types import File + +SOURCE_FILE : Optional[gradio.File] = None +SOURCE_AUDIO : Optional[gradio.Audio] = None +SOURCE_IMAGE : Optional[gradio.Image] = None + + +def render() -> None: + global SOURCE_FILE + global SOURCE_AUDIO + global SOURCE_IMAGE + + has_source_audio = has_audio(state_manager.get_item('source_paths')) + has_source_image = 
has_image(state_manager.get_item('source_paths')) + SOURCE_FILE = gradio.File( + label = wording.get('uis.source_file'), + file_count = 'multiple', + value = state_manager.get_item('source_paths') if has_source_audio or has_source_image else None + ) + source_file_names = [ source_file_value.get('path') for source_file_value in SOURCE_FILE.value ] if SOURCE_FILE.value else None + source_audio_path = get_first(filter_audio_paths(source_file_names)) + source_image_path = get_first(filter_image_paths(source_file_names)) + SOURCE_AUDIO = gradio.Audio( + value = source_audio_path if has_source_audio else None, + visible = has_source_audio, + show_label = False + ) + SOURCE_IMAGE = gradio.Image( + value = source_image_path if has_source_image else None, + visible = has_source_image, + show_label = False + ) + register_ui_component('source_audio', SOURCE_AUDIO) + register_ui_component('source_image', SOURCE_IMAGE) + + +def listen() -> None: + SOURCE_FILE.change(update, inputs = SOURCE_FILE, outputs = [ SOURCE_AUDIO, SOURCE_IMAGE ]) + + +def update(files : List[File]) -> Tuple[gradio.Audio, gradio.Image]: + file_names = [ file.name for file in files ] if files else None + has_source_audio = has_audio(file_names) + has_source_image = has_image(file_names) + + if has_source_audio or has_source_image: + source_audio_path = get_first(filter_audio_paths(file_names)) + source_image_path = get_first(filter_image_paths(file_names)) + state_manager.set_item('source_paths', file_names) + return gradio.Audio(value = source_audio_path, visible = has_source_audio), gradio.Image(value = source_image_path, visible = has_source_image) + + state_manager.clear_item('source_paths') + return gradio.Audio(value = None, visible = False), gradio.Image(value = None, visible = False) diff --git a/facefusion/uis/components/target.py b/facefusion/uis/components/target.py new file mode 100644 index 0000000000000000000000000000000000000000..79e0f382d75ddd5a2c04edb6aa5df1accb336ce3 --- /dev/null +++ b/facefusion/uis/components/target.py @@ -0,0 +1,66 @@ +from typing import Optional, Tuple + +import gradio + +from facefusion import state_manager, wording +from facefusion.face_store import clear_reference_faces, clear_static_faces +from facefusion.filesystem import is_image, is_video +from facefusion.uis.core import register_ui_component +from facefusion.uis.types import ComponentOptions, File + +TARGET_FILE : Optional[gradio.File] = None +TARGET_IMAGE : Optional[gradio.Image] = None +TARGET_VIDEO : Optional[gradio.Video] = None + + +def render() -> None: + global TARGET_FILE + global TARGET_IMAGE + global TARGET_VIDEO + + is_target_image = is_image(state_manager.get_item('target_path')) + is_target_video = is_video(state_manager.get_item('target_path')) + TARGET_FILE = gradio.File( + label = wording.get('uis.target_file'), + value = state_manager.get_item('target_path') if is_target_image or is_target_video else None + ) + target_image_options : ComponentOptions =\ + { + 'show_label': False, + 'visible': False + } + target_video_options : ComponentOptions =\ + { + 'show_label': False, + 'visible': False + } + if is_target_image: + target_image_options['value'] = TARGET_FILE.value.get('path') + target_image_options['visible'] = True + if is_target_video: + target_video_options['value'] = TARGET_FILE.value.get('path') + target_video_options['visible'] = True + TARGET_IMAGE = gradio.Image(**target_image_options) + TARGET_VIDEO = gradio.Video(**target_video_options) + register_ui_component('target_image', TARGET_IMAGE) + 
register_ui_component('target_video', TARGET_VIDEO) + + +def listen() -> None: + TARGET_FILE.change(update, inputs = TARGET_FILE, outputs = [ TARGET_IMAGE, TARGET_VIDEO ]) + + +def update(file : File) -> Tuple[gradio.Image, gradio.Video]: + clear_reference_faces() + clear_static_faces() + + if file and is_image(file.name): + state_manager.set_item('target_path', file.name) + return gradio.Image(value = file.name, visible = True), gradio.Video(value = None, visible = False) + + if file and is_video(file.name): + state_manager.set_item('target_path', file.name) + return gradio.Image(value = None, visible = False), gradio.Video(value = file.name, visible = True) + + state_manager.clear_item('target_path') + return gradio.Image(value = None, visible = False), gradio.Video(value = None, visible = False) diff --git a/facefusion/uis/components/temp_frame.py b/facefusion/uis/components/temp_frame.py new file mode 100644 index 0000000000000000000000000000000000000000..5d6d60f0aa5883b106387843c4da1113a2579acf --- /dev/null +++ b/facefusion/uis/components/temp_frame.py @@ -0,0 +1,42 @@ +from typing import Optional + +import gradio + +import facefusion.choices +from facefusion import state_manager, wording +from facefusion.filesystem import is_video +from facefusion.types import TempFrameFormat +from facefusion.uis.core import get_ui_component + +TEMP_FRAME_FORMAT_DROPDOWN : Optional[gradio.Dropdown] = None + + +def render() -> None: + global TEMP_FRAME_FORMAT_DROPDOWN + + TEMP_FRAME_FORMAT_DROPDOWN = gradio.Dropdown( + label = wording.get('uis.temp_frame_format_dropdown'), + choices = facefusion.choices.temp_frame_formats, + value = state_manager.get_item('temp_frame_format'), + visible = is_video(state_manager.get_item('target_path')) + ) + + +def listen() -> None: + TEMP_FRAME_FORMAT_DROPDOWN.change(update_temp_frame_format, inputs = TEMP_FRAME_FORMAT_DROPDOWN) + + target_video = get_ui_component('target_video') + if target_video: + for method in [ 'change', 'clear' ]: + getattr(target_video, method)(remote_update, outputs = TEMP_FRAME_FORMAT_DROPDOWN) + + +def remote_update() -> gradio.Dropdown: + if is_video(state_manager.get_item('target_path')): + return gradio.Dropdown(visible = True) + return gradio.Dropdown(visible = False) + + +def update_temp_frame_format(temp_frame_format : TempFrameFormat) -> None: + state_manager.set_item('temp_frame_format', temp_frame_format) + diff --git a/facefusion/uis/components/terminal.py b/facefusion/uis/components/terminal.py new file mode 100644 index 0000000000000000000000000000000000000000..8ba8ff855db88029728739a7e0c64eef2eccf250 --- /dev/null +++ b/facefusion/uis/components/terminal.py @@ -0,0 +1,80 @@ +import io +import logging +import math +import os +from typing import Optional + +import gradio +from tqdm import tqdm + +import facefusion.choices +from facefusion import logger, state_manager, wording +from facefusion.types import LogLevel + +LOG_LEVEL_DROPDOWN : Optional[gradio.Dropdown] = None +TERMINAL_TEXTBOX : Optional[gradio.Textbox] = None +LOG_BUFFER = io.StringIO() +LOG_HANDLER = logging.StreamHandler(LOG_BUFFER) +TQDM_UPDATE = tqdm.update + + +def render() -> None: + global LOG_LEVEL_DROPDOWN + global TERMINAL_TEXTBOX + + LOG_LEVEL_DROPDOWN = gradio.Dropdown( + label = wording.get('uis.log_level_dropdown'), + choices = facefusion.choices.log_levels, + value = state_manager.get_item('log_level') + ) + TERMINAL_TEXTBOX = gradio.Textbox( + label = wording.get('uis.terminal_textbox'), + value = read_logs, + lines = 8, + max_lines = 8, + every = 
0.5, + show_copy_button = True + ) + + +def listen() -> None: + LOG_LEVEL_DROPDOWN.change(update_log_level, inputs = LOG_LEVEL_DROPDOWN) + logger.get_package_logger().addHandler(LOG_HANDLER) + tqdm.update = tqdm_update + + +def update_log_level(log_level : LogLevel) -> None: + state_manager.set_item('log_level', log_level) + logger.init(state_manager.get_item('log_level')) + + +def tqdm_update(self : tqdm, n : int = 1) -> None: + TQDM_UPDATE(self, n) + output = create_tqdm_output(self) + + if output: + LOG_BUFFER.seek(0) + log_buffer = LOG_BUFFER.read() + lines = log_buffer.splitlines() + if lines and lines[-1].startswith(self.desc): + position = log_buffer.rfind(lines[-1]) + LOG_BUFFER.seek(position) + else: + LOG_BUFFER.seek(0, os.SEEK_END) + LOG_BUFFER.write(output + os.linesep) + LOG_BUFFER.flush() + + +def create_tqdm_output(self : tqdm) -> Optional[str]: + if not self.disable and self.desc and self.total: + percentage = math.floor(self.n / self.total * 100) + return self.desc + wording.get('colon') + ' ' + str(percentage) + '% (' + str(self.n) + '/' + str(self.total) + ')' + if not self.disable and self.desc and self.unit: + return self.desc + wording.get('colon') + ' ' + str(self.n) + ' ' + self.unit + return None + + +def read_logs() -> str: + LOG_BUFFER.seek(0) + logs = LOG_BUFFER.read().strip() + return logs diff --git a/facefusion/uis/components/trim_frame.py b/facefusion/uis/components/trim_frame.py new file mode 100644 index 0000000000000000000000000000000000000000..4ac384cf5a1639c1f8543a581422a55fc26f1210 --- /dev/null +++ b/facefusion/uis/components/trim_frame.py @@ -0,0 +1,62 @@ +from typing import Optional, Tuple + +from gradio_rangeslider import RangeSlider + +from facefusion import state_manager, wording +from facefusion.face_store import clear_static_faces +from facefusion.filesystem import is_video +from facefusion.uis.core import get_ui_components +from facefusion.uis.types import ComponentOptions +from facefusion.vision import count_video_frame_total + +TRIM_FRAME_RANGE_SLIDER : Optional[RangeSlider] = None + + +def render() -> None: + global TRIM_FRAME_RANGE_SLIDER + + trim_frame_range_slider_options : ComponentOptions =\ + { + 'label': wording.get('uis.trim_frame_slider'), + 'minimum': 0, + 'step': 1, + 'visible': False + } + if is_video(state_manager.get_item('target_path')): + video_frame_total = count_video_frame_total(state_manager.get_item('target_path')) + trim_frame_start = state_manager.get_item('trim_frame_start') or 0 + trim_frame_end = state_manager.get_item('trim_frame_end') or video_frame_total + trim_frame_range_slider_options['maximum'] = video_frame_total + trim_frame_range_slider_options['value'] = (trim_frame_start, trim_frame_end) + trim_frame_range_slider_options['visible'] = True + TRIM_FRAME_RANGE_SLIDER = RangeSlider(**trim_frame_range_slider_options) + + +def listen() -> None: + TRIM_FRAME_RANGE_SLIDER.release(update_trim_frame, inputs = TRIM_FRAME_RANGE_SLIDER) + for ui_component in get_ui_components( + [ + 'target_image', + 'target_video' + ]): + for method in [ 'change', 'clear' ]: + getattr(ui_component, method)(remote_update, outputs = [ TRIM_FRAME_RANGE_SLIDER ]) + + +def remote_update() -> RangeSlider: + if is_video(state_manager.get_item('target_path')): + video_frame_total = count_video_frame_total(state_manager.get_item('target_path')) + state_manager.clear_item('trim_frame_start') + state_manager.clear_item('trim_frame_end') + return RangeSlider(value = (0, video_frame_total), maximum = video_frame_total, visible = True) + return 
RangeSlider(visible = False) + + +def update_trim_frame(trim_frame : Tuple[float, float]) -> None: + clear_static_faces() + trim_frame_start, trim_frame_end = trim_frame + video_frame_total = count_video_frame_total(state_manager.get_item('target_path')) + trim_frame_start = int(trim_frame_start) if trim_frame_start > 0 else None + trim_frame_end = int(trim_frame_end) if trim_frame_end < video_frame_total else None + state_manager.set_item('trim_frame_start', trim_frame_start) + state_manager.set_item('trim_frame_end', trim_frame_end) diff --git a/facefusion/uis/components/ui_workflow.py b/facefusion/uis/components/ui_workflow.py new file mode 100644 index 0000000000000000000000000000000000000000..47711a38c5be4d10bb3d0f22a36f853dd735350a --- /dev/null +++ b/facefusion/uis/components/ui_workflow.py @@ -0,0 +1,21 @@ +from typing import Optional + +import gradio + +import facefusion +from facefusion import state_manager, wording +from facefusion.uis.core import register_ui_component + +UI_WORKFLOW_DROPDOWN : Optional[gradio.Dropdown] = None + + +def render() -> None: + global UI_WORKFLOW_DROPDOWN + + UI_WORKFLOW_DROPDOWN = gradio.Dropdown( + label = wording.get('uis.ui_workflow'), + choices = facefusion.choices.ui_workflows, + value = state_manager.get_item('ui_workflow'), + interactive = True + ) + register_ui_component('ui_workflow_dropdown', UI_WORKFLOW_DROPDOWN) diff --git a/facefusion/uis/components/webcam.py b/facefusion/uis/components/webcam.py new file mode 100644 index 0000000000000000000000000000000000000000..5547919260e1d9c62b9aca6b3cc86e1521af54df --- /dev/null +++ b/facefusion/uis/components/webcam.py @@ -0,0 +1,207 @@ +import os +import subprocess +from collections import deque +from concurrent.futures import ThreadPoolExecutor +from typing import Deque, Generator, List, Optional + +import cv2 +import gradio +from tqdm import tqdm + +from facefusion import ffmpeg_builder, logger, state_manager, wording +from facefusion.audio import create_empty_audio_frame +from facefusion.common_helper import is_windows +from facefusion.content_analyser import analyse_stream +from facefusion.face_analyser import get_average_face, get_many_faces +from facefusion.ffmpeg import open_ffmpeg +from facefusion.filesystem import filter_image_paths, is_directory +from facefusion.processors.core import get_processors_modules +from facefusion.types import Face, Fps, StreamMode, VisionFrame, WebcamMode +from facefusion.uis.core import get_ui_component +from facefusion.vision import normalize_frame_color, read_static_images, unpack_resolution + +WEBCAM_CAPTURE : Optional[cv2.VideoCapture] = None +WEBCAM_IMAGE : Optional[gradio.Image] = None +WEBCAM_START_BUTTON : Optional[gradio.Button] = None +WEBCAM_STOP_BUTTON : Optional[gradio.Button] = None + + +def get_webcam_capture(webcam_device_id : int) -> Optional[cv2.VideoCapture]: + global WEBCAM_CAPTURE + + if WEBCAM_CAPTURE is None: + cv2.setLogLevel(0) + if is_windows(): + webcam_capture = cv2.VideoCapture(webcam_device_id, cv2.CAP_DSHOW) + else: + webcam_capture = cv2.VideoCapture(webcam_device_id) + cv2.setLogLevel(3) + + if webcam_capture and webcam_capture.isOpened(): + WEBCAM_CAPTURE = webcam_capture + return WEBCAM_CAPTURE + + +def clear_webcam_capture() -> None: + global WEBCAM_CAPTURE + + if WEBCAM_CAPTURE and WEBCAM_CAPTURE.isOpened(): + WEBCAM_CAPTURE.release() + WEBCAM_CAPTURE = None + + +def render() -> None: + global WEBCAM_IMAGE + global WEBCAM_START_BUTTON + global WEBCAM_STOP_BUTTON + + WEBCAM_IMAGE = gradio.Image( + label = 
wording.get('uis.webcam_image') + ) + WEBCAM_START_BUTTON = gradio.Button( + value = wording.get('uis.start_button'), + variant = 'primary', + size = 'sm' + ) + WEBCAM_STOP_BUTTON = gradio.Button( + value = wording.get('uis.stop_button'), + size = 'sm' + ) + + +def listen() -> None: + webcam_device_id_dropdown = get_ui_component('webcam_device_id_dropdown') + webcam_mode_radio = get_ui_component('webcam_mode_radio') + webcam_resolution_dropdown = get_ui_component('webcam_resolution_dropdown') + webcam_fps_slider = get_ui_component('webcam_fps_slider') + source_image = get_ui_component('source_image') + + if webcam_device_id_dropdown and webcam_mode_radio and webcam_resolution_dropdown and webcam_fps_slider: + start_event = WEBCAM_START_BUTTON.click(start, inputs = [ webcam_device_id_dropdown, webcam_mode_radio, webcam_resolution_dropdown, webcam_fps_slider ], outputs = WEBCAM_IMAGE) + WEBCAM_STOP_BUTTON.click(stop, cancels = start_event, outputs = WEBCAM_IMAGE) + + if source_image: + source_image.change(stop, cancels = start_event, outputs = WEBCAM_IMAGE) + + +def start(webcam_device_id : int, webcam_mode : WebcamMode, webcam_resolution : str, webcam_fps : Fps) -> Generator[VisionFrame, None, None]: + state_manager.set_item('face_selector_mode', 'one') + source_image_paths = filter_image_paths(state_manager.get_item('source_paths')) + source_frames = read_static_images(source_image_paths) + source_faces = get_many_faces(source_frames) + source_face = get_average_face(source_faces) + stream = None + webcam_capture = None + + if webcam_mode in [ 'udp', 'v4l2' ]: + stream = open_stream(webcam_mode, webcam_resolution, webcam_fps) #type:ignore[arg-type] + webcam_width, webcam_height = unpack_resolution(webcam_resolution) + + if isinstance(webcam_device_id, int): + webcam_capture = get_webcam_capture(webcam_device_id) + + if webcam_capture and webcam_capture.isOpened(): + webcam_capture.set(cv2.CAP_PROP_FOURCC, cv2.VideoWriter_fourcc(*'MJPG')) #type:ignore[attr-defined] + webcam_capture.set(cv2.CAP_PROP_FRAME_WIDTH, webcam_width) + webcam_capture.set(cv2.CAP_PROP_FRAME_HEIGHT, webcam_height) + webcam_capture.set(cv2.CAP_PROP_FPS, webcam_fps) + + for capture_frame in multi_process_capture(source_face, webcam_capture, webcam_fps): + capture_frame = normalize_frame_color(capture_frame) + if webcam_mode == 'inline': + yield capture_frame + else: + try: + stream.stdin.write(capture_frame.tobytes()) + except Exception: + clear_webcam_capture() + yield None + + +def multi_process_capture(source_face : Face, webcam_capture : cv2.VideoCapture, webcam_fps : Fps) -> Generator[VisionFrame, None, None]: + deque_capture_frames: Deque[VisionFrame] = deque() + + with tqdm(desc = wording.get('streaming'), unit = 'frame', disable = state_manager.get_item('log_level') in [ 'warn', 'error' ]) as progress: + with ThreadPoolExecutor(max_workers = state_manager.get_item('execution_thread_count')) as executor: + futures = [] + + while webcam_capture and webcam_capture.isOpened(): + _, capture_frame = webcam_capture.read() + if analyse_stream(capture_frame, webcam_fps): + yield None + future = executor.submit(process_stream_frame, source_face, capture_frame) + futures.append(future) + + for future_done in [ future for future in futures if future.done() ]: + capture_frame = future_done.result() + deque_capture_frames.append(capture_frame) + futures.remove(future_done) + + while deque_capture_frames: + progress.update() + yield deque_capture_frames.popleft() + + +def stop() -> gradio.Image: + clear_webcam_capture() + 
return gradio.Image(value = None) + + +def process_stream_frame(source_face : Face, target_vision_frame : VisionFrame) -> VisionFrame: + source_audio_frame = create_empty_audio_frame() + + for processor_module in get_processors_modules(state_manager.get_item('processors')): + logger.disable() + if processor_module.pre_process('stream'): + target_vision_frame = processor_module.process_frame( + { + 'source_face': source_face, + 'source_audio_frame': source_audio_frame, + 'target_vision_frame': target_vision_frame + }) + logger.enable() + return target_vision_frame + + +def open_stream(stream_mode : StreamMode, stream_resolution : str, stream_fps : Fps) -> subprocess.Popen[bytes]: + commands = ffmpeg_builder.chain( + ffmpeg_builder.capture_video(), + ffmpeg_builder.set_media_resolution(stream_resolution), + ffmpeg_builder.set_input_fps(stream_fps) + ) + + if stream_mode == 'udp': + commands.extend(ffmpeg_builder.set_input('-')) + commands.extend(ffmpeg_builder.set_stream_mode('udp')) + commands.extend(ffmpeg_builder.set_stream_quality(2000)) + commands.extend(ffmpeg_builder.set_output('udp://localhost:27000?pkt_size=1316')) + + if stream_mode == 'v4l2': + device_directory_path = '/sys/devices/virtual/video4linux' + commands.extend(ffmpeg_builder.set_input('-')) + commands.extend(ffmpeg_builder.set_stream_mode('v4l2')) + + if is_directory(device_directory_path): + device_names = os.listdir(device_directory_path) + + for device_name in device_names: + device_path = '/dev/' + device_name + commands.extend(ffmpeg_builder.set_output(device_path)) + + else: + logger.error(wording.get('stream_not_loaded').format(stream_mode = stream_mode), __name__) + + return open_ffmpeg(commands) + + +def get_available_webcam_ids(webcam_id_start : int, webcam_id_end : int) -> List[int]: + available_webcam_ids = [] + + for index in range(webcam_id_start, webcam_id_end): + webcam_capture = get_webcam_capture(index) + + if webcam_capture and webcam_capture.isOpened(): + available_webcam_ids.append(index) + clear_webcam_capture() + + return available_webcam_ids diff --git a/facefusion/uis/components/webcam_options.py b/facefusion/uis/components/webcam_options.py new file mode 100644 index 0000000000000000000000000000000000000000..b7971c283d25e200ecb9632f9583faa049e3d012 --- /dev/null +++ b/facefusion/uis/components/webcam_options.py @@ -0,0 +1,49 @@ +from typing import Optional + +import gradio + +import facefusion.choices +from facefusion import wording +from facefusion.common_helper import get_first +from facefusion.uis.components.webcam import get_available_webcam_ids +from facefusion.uis.core import register_ui_component + +WEBCAM_DEVICE_ID_DROPDOWN : Optional[gradio.Dropdown] = None +WEBCAM_MODE_RADIO : Optional[gradio.Radio] = None +WEBCAM_RESOLUTION_DROPDOWN : Optional[gradio.Dropdown] = None +WEBCAM_FPS_SLIDER : Optional[gradio.Slider] = None + + +def render() -> None: + global WEBCAM_DEVICE_ID_DROPDOWN + global WEBCAM_MODE_RADIO + global WEBCAM_RESOLUTION_DROPDOWN + global WEBCAM_FPS_SLIDER + + available_webcam_ids = get_available_webcam_ids(0, 10) or [ 'none' ] #type:ignore[list-item] + WEBCAM_DEVICE_ID_DROPDOWN = gradio.Dropdown( + value = get_first(available_webcam_ids), + label = wording.get('uis.webcam_device_id_dropdown'), + choices = available_webcam_ids + ) + WEBCAM_MODE_RADIO = gradio.Radio( + label = wording.get('uis.webcam_mode_radio'), + choices = facefusion.choices.webcam_modes, + value = 'inline' + ) + WEBCAM_RESOLUTION_DROPDOWN = gradio.Dropdown( + label = 
wording.get('uis.webcam_resolution_dropdown'), + choices = facefusion.choices.webcam_resolutions, + value = facefusion.choices.webcam_resolutions[0] + ) + WEBCAM_FPS_SLIDER = gradio.Slider( + label = wording.get('uis.webcam_fps_slider'), + value = 25, + step = 1, + minimum = 1, + maximum = 60 + ) + register_ui_component('webcam_device_id_dropdown', WEBCAM_DEVICE_ID_DROPDOWN) + register_ui_component('webcam_mode_radio', WEBCAM_MODE_RADIO) + register_ui_component('webcam_resolution_dropdown', WEBCAM_RESOLUTION_DROPDOWN) + register_ui_component('webcam_fps_slider', WEBCAM_FPS_SLIDER) diff --git a/facefusion/uis/core.py b/facefusion/uis/core.py new file mode 100644 index 0000000000000000000000000000000000000000..22c37b9180dc2859d56abceb1cd7024e74706600 --- /dev/null +++ b/facefusion/uis/core.py @@ -0,0 +1,197 @@ +import importlib +import os +import warnings +from types import ModuleType +from typing import Any, Dict, List, Optional + +import gradio +from gradio.themes import Size + +import facefusion.uis.overrides as uis_overrides +from facefusion import logger, metadata, state_manager, wording +from facefusion.exit_helper import hard_exit +from facefusion.filesystem import resolve_relative_path +from facefusion.uis.types import Component, ComponentName + +UI_COMPONENTS: Dict[ComponentName, Component] = {} +UI_LAYOUT_MODULES : List[ModuleType] = [] +UI_LAYOUT_METHODS =\ +[ + 'pre_check', + 'render', + 'listen', + 'run' +] + + +def load_ui_layout_module(ui_layout : str) -> Any: + try: + ui_layout_module = importlib.import_module('facefusion.uis.layouts.' + ui_layout) + for method_name in UI_LAYOUT_METHODS: + if not hasattr(ui_layout_module, method_name): + raise NotImplementedError + except ModuleNotFoundError as exception: + logger.error(wording.get('ui_layout_not_loaded').format(ui_layout = ui_layout), __name__) + logger.debug(exception.msg, __name__) + hard_exit(1) + except NotImplementedError: + logger.error(wording.get('ui_layout_not_implemented').format(ui_layout = ui_layout), __name__) + hard_exit(1) + return ui_layout_module + + +def get_ui_layouts_modules(ui_layouts : List[str]) -> List[ModuleType]: + if not UI_LAYOUT_MODULES: + for ui_layout in ui_layouts: + ui_layout_module = load_ui_layout_module(ui_layout) + UI_LAYOUT_MODULES.append(ui_layout_module) + return UI_LAYOUT_MODULES + + +def get_ui_component(component_name : ComponentName) -> Optional[Component]: + if component_name in UI_COMPONENTS: + return UI_COMPONENTS[component_name] + return None + + +def get_ui_components(component_names : List[ComponentName]) -> Optional[List[Component]]: + ui_components = [] + + for component_name in component_names: + component = get_ui_component(component_name) + if component: + ui_components.append(component) + return ui_components + + +def register_ui_component(component_name : ComponentName, component: Component) -> None: + UI_COMPONENTS[component_name] = component + + +def init() -> None: + os.environ['GRADIO_ANALYTICS_ENABLED'] = '0' + os.environ['GRADIO_TEMP_DIR'] = os.path.join(state_manager.get_item('temp_path'), 'gradio') + + warnings.filterwarnings('ignore', category = UserWarning, module = 'gradio') + gradio.processing_utils._check_allowed = uis_overrides.check_allowed #type:ignore + gradio.processing_utils.convert_video_to_playable_mp4 = uis_overrides.convert_video_to_playable_mp4 + + +def launch() -> None: + ui_layouts_total = len(state_manager.get_item('ui_layouts')) + with gradio.Blocks(theme = get_theme(), css = get_css(), title = metadata.get('name') + ' ' + 
metadata.get('version'), fill_width = True) as ui: + for ui_layout in state_manager.get_item('ui_layouts'): + ui_layout_module = load_ui_layout_module(ui_layout) + + if ui_layouts_total > 1: + with gradio.Tab(ui_layout): + ui_layout_module.render() + ui_layout_module.listen() + else: + ui_layout_module.render() + ui_layout_module.listen() + + for ui_layout in state_manager.get_item('ui_layouts'): + ui_layout_module = load_ui_layout_module(ui_layout) + ui_layout_module.run(ui) + + +def get_theme() -> gradio.Theme: + return gradio.themes.Base( + primary_hue = gradio.themes.colors.red, + secondary_hue = gradio.themes.Color( + name = 'neutral', + c50 = '#fafafa', + c100 = '#f5f5f5', + c200 = '#e5e5e5', + c300 = '#d4d4d4', + c400 = '#a3a3a3', + c500 = '#737373', + c600 = '#525252', + c700 = '#404040', + c800 = '#262626', + c900 = '#212121', + c950 = '#171717', + ), + radius_size = Size( + xxs = '0.375rem', + xs = '0.375rem', + sm = '0.375rem', + md = '0.375rem', + lg = '0.375rem', + xl = '0.375rem', + xxl = '0.375rem', + ), + font = gradio.themes.GoogleFont('Open Sans') + ).set( + color_accent = 'transparent', + color_accent_soft = 'transparent', + color_accent_soft_dark = 'transparent', + background_fill_primary = '*neutral_100', + background_fill_primary_dark = '*neutral_950', + background_fill_secondary = '*neutral_50', + background_fill_secondary_dark = '*neutral_800', + block_background_fill = 'white', + block_background_fill_dark = '*neutral_900', + block_border_width = '0', + block_label_background_fill = '*neutral_100', + block_label_background_fill_dark = '*neutral_800', + block_label_border_width = 'none', + block_label_margin = '0.5rem', + block_label_radius = '*radius_md', + block_label_text_color = '*neutral_700', + block_label_text_size = '*text_sm', + block_label_text_color_dark = 'white', + block_label_text_weight = '600', + block_title_background_fill = '*neutral_100', + block_title_background_fill_dark = '*neutral_800', + block_title_padding = '*block_label_padding', + block_title_radius = '*block_label_radius', + block_title_text_color = '*neutral_700', + block_title_text_size = '*text_sm', + block_title_text_weight = '600', + block_padding = '0.5rem', + border_color_accent = 'transparent', + border_color_accent_dark = 'transparent', + border_color_accent_subdued = 'transparent', + border_color_accent_subdued_dark = 'transparent', + border_color_primary = 'transparent', + border_color_primary_dark = 'transparent', + button_large_padding = '2rem 0.5rem', + button_large_text_weight = 'normal', + button_primary_background_fill = '*primary_500', + button_primary_background_fill_dark = '*primary_600', + button_primary_text_color = 'white', + button_secondary_background_fill = 'white', + button_secondary_background_fill_dark = '*neutral_800', + button_secondary_background_fill_hover = 'white', + button_secondary_background_fill_hover_dark = '*neutral_800', + button_secondary_text_color = '*neutral_800', + button_small_padding = '0.75rem', + button_small_text_size = '0.875rem', + checkbox_background_color = '*neutral_200', + checkbox_background_color_dark = '*neutral_900', + checkbox_background_color_selected = '*primary_600', + checkbox_background_color_selected_dark = '*primary_700', + checkbox_label_background_fill = '*neutral_50', + checkbox_label_background_fill_dark = '*neutral_800', + checkbox_label_background_fill_hover = '*neutral_50', + checkbox_label_background_fill_hover_dark = '*neutral_800', + checkbox_label_background_fill_selected = '*primary_500', + 
checkbox_label_background_fill_selected_dark = '*primary_600', + checkbox_label_text_color_selected = 'white', + error_background_fill = 'white', + error_background_fill_dark = '*neutral_900', + error_text_color = '*primary_500', + error_text_color_dark = '*primary_600', + input_background_fill = '*neutral_50', + input_background_fill_dark = '*neutral_800', + shadow_drop = 'none', + slider_color = '*primary_500', + slider_color_dark = '*primary_600' + ) + + +def get_css() -> str: + overrides_css_path = resolve_relative_path('uis/assets/overrides.css') + return open(overrides_css_path).read() diff --git a/facefusion/uis/layouts/benchmark.py b/facefusion/uis/layouts/benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..b6d6686af8da4ee4f05b4ce2aa66764060f09240 --- /dev/null +++ b/facefusion/uis/layouts/benchmark.py @@ -0,0 +1,78 @@ +import gradio + +from facefusion import state_manager +from facefusion.benchmarker import pre_check as benchmarker_pre_check +from facefusion.uis.components import about, age_modifier_options, benchmark, benchmark_options, deep_swapper_options, download, execution, execution_queue_count, execution_thread_count, expression_restorer_options, face_debugger_options, face_editor_options, face_enhancer_options, face_swapper_options, frame_colorizer_options, frame_enhancer_options, lip_syncer_options, memory, processors + + +def pre_check() -> bool: + return benchmarker_pre_check() + + +def render() -> gradio.Blocks: + with gradio.Blocks() as layout: + with gradio.Row(): + with gradio.Column(scale = 4): + with gradio.Blocks(): + about.render() + with gradio.Blocks(): + processors.render() + with gradio.Blocks(): + age_modifier_options.render() + with gradio.Blocks(): + deep_swapper_options.render() + with gradio.Blocks(): + expression_restorer_options.render() + with gradio.Blocks(): + face_debugger_options.render() + with gradio.Blocks(): + face_editor_options.render() + with gradio.Blocks(): + face_enhancer_options.render() + with gradio.Blocks(): + face_swapper_options.render() + with gradio.Blocks(): + frame_colorizer_options.render() + with gradio.Blocks(): + frame_enhancer_options.render() + with gradio.Blocks(): + lip_syncer_options.render() + with gradio.Blocks(): + execution.render() + execution_thread_count.render() + execution_queue_count.render() + with gradio.Blocks(): + download.render() + with gradio.Blocks(): + state_manager.set_item('video_memory_strategy', 'tolerant') + memory.render() + with gradio.Blocks(): + benchmark_options.render() + with gradio.Column(scale = 11): + with gradio.Blocks(): + benchmark.render() + return layout + + +def listen() -> None: + processors.listen() + age_modifier_options.listen() + deep_swapper_options.listen() + expression_restorer_options.listen() + download.listen() + face_debugger_options.listen() + face_editor_options.listen() + face_enhancer_options.listen() + face_swapper_options.listen() + frame_colorizer_options.listen() + frame_enhancer_options.listen() + lip_syncer_options.listen() + execution.listen() + execution_thread_count.listen() + execution_queue_count.listen() + memory.listen() + benchmark.listen() + + +def run(ui : gradio.Blocks) -> None: + ui.launch(favicon_path = 'facefusion.ico', inbrowser = state_manager.get_item('open_browser')) diff --git a/facefusion/uis/layouts/default.py b/facefusion/uis/layouts/default.py new file mode 100644 index 0000000000000000000000000000000000000000..96553f8a1981ae472cbbad402292bc547ba9008b --- /dev/null +++ b/facefusion/uis/layouts/default.py 
@@ -0,0 +1,119 @@ +import gradio + +from facefusion import state_manager +from facefusion.uis.components import about, age_modifier_options, common_options, deep_swapper_options, download, execution, execution_queue_count, execution_thread_count, expression_restorer_options, face_debugger_options, face_detector, face_editor_options, face_enhancer_options, face_landmarker, face_masker, face_selector, face_swapper_options, frame_colorizer_options, frame_enhancer_options, instant_runner, job_manager, job_runner, lip_syncer_options, memory, output, output_options, preview, processors, source, target, temp_frame, terminal, trim_frame, ui_workflow + + +def pre_check() -> bool: + return True + + +def render() -> gradio.Blocks: + with gradio.Blocks() as layout: + with gradio.Row(): + with gradio.Column(scale = 4): + with gradio.Blocks(): + about.render() + with gradio.Blocks(): + processors.render() + with gradio.Blocks(): + age_modifier_options.render() + with gradio.Blocks(): + deep_swapper_options.render() + with gradio.Blocks(): + expression_restorer_options.render() + with gradio.Blocks(): + face_debugger_options.render() + with gradio.Blocks(): + face_editor_options.render() + with gradio.Blocks(): + face_enhancer_options.render() + with gradio.Blocks(): + face_swapper_options.render() + with gradio.Blocks(): + frame_colorizer_options.render() + with gradio.Blocks(): + frame_enhancer_options.render() + with gradio.Blocks(): + lip_syncer_options.render() + with gradio.Blocks(): + execution.render() + execution_thread_count.render() + execution_queue_count.render() + with gradio.Blocks(): + download.render() + with gradio.Blocks(): + memory.render() + with gradio.Blocks(): + temp_frame.render() + with gradio.Blocks(): + output_options.render() + with gradio.Column(scale = 4): + with gradio.Blocks(): + source.render() + with gradio.Blocks(): + target.render() + with gradio.Blocks(): + output.render() + with gradio.Blocks(): + terminal.render() + with gradio.Blocks(): + ui_workflow.render() + instant_runner.render() + job_runner.render() + job_manager.render() + with gradio.Column(scale = 7): + with gradio.Blocks(): + preview.render() + with gradio.Blocks(): + trim_frame.render() + with gradio.Blocks(): + face_selector.render() + with gradio.Blocks(): + face_masker.render() + with gradio.Blocks(): + face_detector.render() + with gradio.Blocks(): + face_landmarker.render() + with gradio.Blocks(): + common_options.render() + return layout + + +def listen() -> None: + processors.listen() + age_modifier_options.listen() + deep_swapper_options.listen() + expression_restorer_options.listen() + face_debugger_options.listen() + face_editor_options.listen() + face_enhancer_options.listen() + face_swapper_options.listen() + frame_colorizer_options.listen() + frame_enhancer_options.listen() + lip_syncer_options.listen() + execution.listen() + execution_thread_count.listen() + execution_queue_count.listen() + download.listen() + memory.listen() + temp_frame.listen() + output_options.listen() + source.listen() + target.listen() + output.listen() + instant_runner.listen() + job_runner.listen() + job_manager.listen() + terminal.listen() + preview.listen() + trim_frame.listen() + face_selector.listen() + face_masker.listen() + face_detector.listen() + face_landmarker.listen() + common_options.listen() + + +def run(ui : gradio.Blocks) -> None: + ui.launch(favicon_path = 'facefusion.ico', inbrowser = state_manager.get_item('open_browser')) diff --git a/facefusion/uis/layouts/jobs.py 
b/facefusion/uis/layouts/jobs.py new file mode 100644 index 0000000000000000000000000000000000000000..ce38cc8ffe9b8105fd62231034356fccf01e19d7 --- /dev/null +++ b/facefusion/uis/layouts/jobs.py @@ -0,0 +1,31 @@ +import gradio + +from facefusion import state_manager +from facefusion.uis.components import about, job_list, job_list_options + + +def pre_check() -> bool: + return True + + +def render() -> gradio.Blocks: + with gradio.Blocks() as layout: + with gradio.Row(): + with gradio.Column(scale = 4): + with gradio.Blocks(): + about.render() + with gradio.Blocks(): + job_list_options.render() + with gradio.Column(scale = 11): + with gradio.Blocks(): + job_list.render() + return layout + + +def listen() -> None: + job_list_options.listen() + job_list.listen() + + +def run(ui : gradio.Blocks) -> None: + ui.launch(favicon_path = 'facefusion.ico', inbrowser = state_manager.get_item('open_browser')) diff --git a/facefusion/uis/layouts/webcam.py b/facefusion/uis/layouts/webcam.py new file mode 100644 index 0000000000000000000000000000000000000000..7b1fcd71853dc24483cd8644732bba49f364bff4 --- /dev/null +++ b/facefusion/uis/layouts/webcam.py @@ -0,0 +1,74 @@ +import gradio + +from facefusion import state_manager +from facefusion.uis.components import about, age_modifier_options, deep_swapper_options, download, execution, execution_thread_count, expression_restorer_options, face_debugger_options, face_editor_options, face_enhancer_options, face_swapper_options, frame_colorizer_options, frame_enhancer_options, lip_syncer_options, processors, source, webcam, webcam_options + + +def pre_check() -> bool: + return True + + +def render() -> gradio.Blocks: + with gradio.Blocks() as layout: + with gradio.Row(): + with gradio.Column(scale = 4): + with gradio.Blocks(): + about.render() + with gradio.Blocks(): + processors.render() + with gradio.Blocks(): + age_modifier_options.render() + with gradio.Blocks(): + deep_swapper_options.render() + with gradio.Blocks(): + expression_restorer_options.render() + with gradio.Blocks(): + face_debugger_options.render() + with gradio.Blocks(): + face_editor_options.render() + with gradio.Blocks(): + face_enhancer_options.render() + with gradio.Blocks(): + face_swapper_options.render() + with gradio.Blocks(): + frame_colorizer_options.render() + with gradio.Blocks(): + frame_enhancer_options.render() + with gradio.Blocks(): + lip_syncer_options.render() + with gradio.Blocks(): + execution.render() + execution_thread_count.render() + with gradio.Blocks(): + download.render() + with gradio.Blocks(): + webcam_options.render() + with gradio.Blocks(): + source.render() + with gradio.Column(scale = 11): + with gradio.Blocks(): + webcam.render() + return layout + + +def listen() -> None: + processors.listen() + age_modifier_options.listen() + deep_swapper_options.listen() + expression_restorer_options.listen() + download.listen() + face_debugger_options.listen() + face_editor_options.listen() + face_enhancer_options.listen() + face_swapper_options.listen() + frame_colorizer_options.listen() + frame_enhancer_options.listen() + lip_syncer_options.listen() + execution.listen() + execution_thread_count.listen() + source.listen() + webcam.listen() + + +def run(ui : gradio.Blocks) -> None: + ui.launch(favicon_path = 'facefusion.ico', inbrowser = state_manager.get_item('open_browser')) diff --git a/facefusion/uis/overrides.py b/facefusion/uis/overrides.py new file mode 100644 index 0000000000000000000000000000000000000000..585064637cc5a859cf8afa6ea1497bb4d5852b69 --- /dev/null +++ 
b/facefusion/uis/overrides.py @@ -0,0 +1,30 @@ +from facefusion import ffmpeg_builder +from facefusion.ffmpeg import run_ffmpeg +from facefusion.filesystem import get_file_size +from facefusion.temp_helper import create_temp_directory, get_temp_file_path + + +def convert_video_to_playable_mp4(video_path : str) -> str: + video_file_size = get_file_size(video_path) + max_file_size = 512 * 1024 * 1024 + + create_temp_directory(video_path) + temp_video_path = get_temp_file_path(video_path) + commands = ffmpeg_builder.set_input(video_path) + + if video_file_size > max_file_size: + commands.extend(ffmpeg_builder.set_video_duration(10)) + + commands.extend(ffmpeg_builder.force_output(temp_video_path)) + + process = run_ffmpeg(commands) + process.communicate() + + if process.returncode == 0: + return temp_video_path + + return video_path + + +def check_allowed(path : str, check_in_upload_folder : bool) -> None: + return None diff --git a/facefusion/uis/types.py b/facefusion/uis/types.py new file mode 100644 index 0000000000000000000000000000000000000000..159f389cd32b4156c8238d92c5d1892532dc81b4 --- /dev/null +++ b/facefusion/uis/types.py @@ -0,0 +1,85 @@ +from typing import Any, Dict, IO, Literal, TypeAlias + +File : TypeAlias = IO[Any] +ComponentName = Literal\ +[ + 'age_modifier_direction_slider', + 'age_modifier_model_dropdown', + 'benchmark_cycle_count_slider', + 'benchmark_resolutions_checkbox_group', + 'deep_swapper_model_dropdown', + 'deep_swapper_morph_slider', + 'expression_restorer_factor_slider', + 'expression_restorer_model_dropdown', + 'face_debugger_items_checkbox_group', + 'face_detector_angles_checkbox_group', + 'face_detector_model_dropdown', + 'face_detector_score_slider', + 'face_detector_size_dropdown', + 'face_editor_eyebrow_direction_slider', + 'face_editor_eye_gaze_horizontal_slider', + 'face_editor_eye_gaze_vertical_slider', + 'face_editor_eye_open_ratio_slider', + 'face_editor_head_pitch_slider', + 'face_editor_head_roll_slider', + 'face_editor_head_yaw_slider', + 'face_editor_lip_open_ratio_slider', + 'face_editor_model_dropdown', + 'face_editor_mouth_grim_slider', + 'face_editor_mouth_position_horizontal_slider', + 'face_editor_mouth_position_vertical_slider', + 'face_editor_mouth_pout_slider', + 'face_editor_mouth_purse_slider', + 'face_editor_mouth_smile_slider', + 'face_enhancer_blend_slider', + 'face_enhancer_model_dropdown', + 'face_enhancer_weight_slider', + 'face_landmarker_model_dropdown', + 'face_landmarker_score_slider', + 'face_mask_types_checkbox_group', + 'face_mask_areas_checkbox_group', + 'face_mask_regions_checkbox_group', + 'face_mask_blur_slider', + 'face_mask_padding_bottom_slider', + 'face_mask_padding_left_slider', + 'face_mask_padding_right_slider', + 'face_mask_padding_top_slider', + 'face_selector_age_range_slider', + 'face_selector_gender_dropdown', + 'face_selector_mode_dropdown', + 'face_selector_order_dropdown', + 'face_selector_race_dropdown', + 'face_swapper_model_dropdown', + 'face_swapper_pixel_boost_dropdown', + 'face_occluder_model_dropdown', + 'face_parser_model_dropdown', + 'frame_colorizer_blend_slider', + 'frame_colorizer_model_dropdown', + 'frame_colorizer_size_dropdown', + 'frame_enhancer_blend_slider', + 'frame_enhancer_model_dropdown', + 'job_list_job_status_checkbox_group', + 'lip_syncer_model_dropdown', + 'lip_syncer_weight_slider', + 'output_image', + 'output_video', + 'output_video_fps_slider', + 'preview_frame_slider', + 'processors_checkbox_group', + 'reference_face_distance_slider', + 'reference_face_position_gallery', + 
'source_audio', + 'source_image', + 'target_image', + 'target_video', + 'ui_workflow_dropdown', + 'webcam_device_id_dropdown', + 'webcam_fps_slider', + 'webcam_mode_radio', + 'webcam_resolution_dropdown' +] +Component : TypeAlias = Any +ComponentOptions : TypeAlias = Dict[str, Any] + +JobManagerAction = Literal['job-create', 'job-submit', 'job-delete', 'job-add-step', 'job-remix-step', 'job-insert-step', 'job-remove-step'] +JobRunnerAction = Literal['job-run', 'job-run-all', 'job-retry', 'job-retry-all'] diff --git a/facefusion/uis/ui_helper.py b/facefusion/uis/ui_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..27bc265cdfd6cb163e6a391066051b1961e4a820 --- /dev/null +++ b/facefusion/uis/ui_helper.py @@ -0,0 +1,26 @@ +import hashlib +import os +from typing import Optional + +from facefusion import state_manager +from facefusion.filesystem import get_file_extension, is_image, is_video + + +def convert_int_none(value : int) -> Optional[int]: + if value == 'none': + return None + return value + + +def convert_str_none(value : str) -> Optional[str]: + if value == 'none': + return None + return value + + +def suggest_output_path(output_directory_path : str, target_path : str) -> Optional[str]: + if is_image(target_path) or is_video(target_path): + output_file_name = hashlib.sha1(str(state_manager.get_state()).encode()).hexdigest()[:8] + target_file_extension = get_file_extension(target_path) + return os.path.join(output_directory_path, output_file_name + target_file_extension) + return None diff --git a/facefusion/video_manager.py b/facefusion/video_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..d687caf1bb1aa3dac5884201c055b6ce498f91d9 --- /dev/null +++ b/facefusion/video_manager.py @@ -0,0 +1,19 @@ +import cv2 + +from facefusion.types import VideoPoolSet + +VIDEO_POOL_SET : VideoPoolSet = {} + + +def get_video_capture(video_path : str) -> cv2.VideoCapture: + if video_path not in VIDEO_POOL_SET: + VIDEO_POOL_SET[video_path] = cv2.VideoCapture(video_path) + + return VIDEO_POOL_SET.get(video_path) + + +def clear_video_pool() -> None: + for video_capture in VIDEO_POOL_SET.values(): + video_capture.release() + + VIDEO_POOL_SET.clear() diff --git a/facefusion/vision.py b/facefusion/vision.py new file mode 100644 index 0000000000000000000000000000000000000000..b146170cdf5e74e21664dc5359677ea5e8c26fdd --- /dev/null +++ b/facefusion/vision.py @@ -0,0 +1,345 @@ +import math +from functools import lru_cache +from typing import List, Optional, Tuple + +import cv2 +import numpy +from cv2.typing import Size + +import facefusion.choices +from facefusion.common_helper import is_windows +from facefusion.filesystem import get_file_extension, is_image, is_video +from facefusion.thread_helper import thread_semaphore +from facefusion.types import Duration, Fps, Orientation, Resolution, VisionFrame +from facefusion.video_manager import get_video_capture + + +@lru_cache() +def read_static_image(image_path : str) -> Optional[VisionFrame]: + return read_image(image_path) + + +def read_static_images(image_paths : List[str]) -> List[VisionFrame]: + frames = [] + + if image_paths: + for image_path in image_paths: + frames.append(read_static_image(image_path)) + return frames + + +def read_image(image_path : str) -> Optional[VisionFrame]: + if is_image(image_path): + if is_windows(): + image_buffer = numpy.fromfile(image_path, dtype = numpy.uint8) + return cv2.imdecode(image_buffer, cv2.IMREAD_COLOR) + return cv2.imread(image_path) + return None + + +def 
write_image(image_path : str, vision_frame : VisionFrame) -> bool: + if image_path: + if is_windows(): + image_file_extension = get_file_extension(image_path) + _, vision_frame = cv2.imencode(image_file_extension, vision_frame) + vision_frame.tofile(image_path) + return is_image(image_path) + return cv2.imwrite(image_path, vision_frame) + return False + + +def detect_image_resolution(image_path : str) -> Optional[Resolution]: + if is_image(image_path): + image = read_image(image_path) + height, width = image.shape[:2] + + if width > 0 and height > 0: + return width, height + return None + + +def restrict_image_resolution(image_path : str, resolution : Resolution) -> Resolution: + if is_image(image_path): + image_resolution = detect_image_resolution(image_path) + if image_resolution < resolution: + return image_resolution + return resolution + + +def create_image_resolutions(resolution : Resolution) -> List[str]: + resolutions = [] + temp_resolutions = [] + + if resolution: + width, height = resolution + temp_resolutions.append(normalize_resolution(resolution)) + for image_template_size in facefusion.choices.image_template_sizes: + temp_resolutions.append(normalize_resolution((width * image_template_size, height * image_template_size))) + temp_resolutions = sorted(set(temp_resolutions)) + for temp_resolution in temp_resolutions: + resolutions.append(pack_resolution(temp_resolution)) + return resolutions + + +def read_video_frame(video_path : str, frame_number : int = 0) -> Optional[VisionFrame]: + if is_video(video_path): + video_capture = get_video_capture(video_path) + + if video_capture.isOpened(): + frame_total = video_capture.get(cv2.CAP_PROP_FRAME_COUNT) + + with thread_semaphore(): + video_capture.set(cv2.CAP_PROP_POS_FRAMES, min(frame_total, frame_number - 1)) + has_vision_frame, vision_frame = video_capture.read() + + if has_vision_frame: + return vision_frame + + return None + + +def count_video_frame_total(video_path : str) -> int: + if is_video(video_path): + video_capture = get_video_capture(video_path) + + if video_capture.isOpened(): + with thread_semaphore(): + video_frame_total = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT)) + return video_frame_total + + return 0 + + +def predict_video_frame_total(video_path : str, fps : Fps, trim_frame_start : int, trim_frame_end : int) -> int: + if is_video(video_path): + video_fps = detect_video_fps(video_path) + extract_frame_total = count_trim_frame_total(video_path, trim_frame_start, trim_frame_end) * fps / video_fps + return math.floor(extract_frame_total) + return 0 + + +def detect_video_fps(video_path : str) -> Optional[float]: + if is_video(video_path): + video_capture = get_video_capture(video_path) + + if video_capture.isOpened(): + with thread_semaphore(): + video_fps = video_capture.get(cv2.CAP_PROP_FPS) + return video_fps + + return None + + +def restrict_video_fps(video_path : str, fps : Fps) -> Fps: + if is_video(video_path): + video_fps = detect_video_fps(video_path) + if video_fps < fps: + return video_fps + return fps + + +def detect_video_duration(video_path : str) -> Duration: + video_frame_total = count_video_frame_total(video_path) + video_fps = detect_video_fps(video_path) + + if video_frame_total and video_fps: + return video_frame_total / video_fps + return 0 + + +def count_trim_frame_total(video_path : str, trim_frame_start : Optional[int], trim_frame_end : Optional[int]) -> int: + trim_frame_start, trim_frame_end = restrict_trim_frame(video_path, trim_frame_start, trim_frame_end) + + return trim_frame_end - 
trim_frame_start + + +def restrict_trim_frame(video_path : str, trim_frame_start : Optional[int], trim_frame_end : Optional[int]) -> Tuple[int, int]: + video_frame_total = count_video_frame_total(video_path) + + if isinstance(trim_frame_start, int): + trim_frame_start = max(0, min(trim_frame_start, video_frame_total)) + if isinstance(trim_frame_end, int): + trim_frame_end = max(0, min(trim_frame_end, video_frame_total)) + + if isinstance(trim_frame_start, int) and isinstance(trim_frame_end, int): + return trim_frame_start, trim_frame_end + if isinstance(trim_frame_start, int): + return trim_frame_start, video_frame_total + if isinstance(trim_frame_end, int): + return 0, trim_frame_end + + return 0, video_frame_total + + +def detect_video_resolution(video_path : str) -> Optional[Resolution]: + if is_video(video_path): + video_capture = get_video_capture(video_path) + + if video_capture.isOpened(): + with thread_semaphore(): + width = video_capture.get(cv2.CAP_PROP_FRAME_WIDTH) + height = video_capture.get(cv2.CAP_PROP_FRAME_HEIGHT) + return int(width), int(height) + + return None + + +def restrict_video_resolution(video_path : str, resolution : Resolution) -> Resolution: + if is_video(video_path): + video_resolution = detect_video_resolution(video_path) + if video_resolution < resolution: + return video_resolution + return resolution + + +def create_video_resolutions(resolution : Resolution) -> List[str]: + resolutions = [] + temp_resolutions = [] + + if resolution: + width, height = resolution + temp_resolutions.append(normalize_resolution(resolution)) + for video_template_size in facefusion.choices.video_template_sizes: + if width > height: + temp_resolutions.append(normalize_resolution((video_template_size * width / height, video_template_size))) + else: + temp_resolutions.append(normalize_resolution((video_template_size, video_template_size * height / width))) + temp_resolutions = sorted(set(temp_resolutions)) + for temp_resolution in temp_resolutions: + resolutions.append(pack_resolution(temp_resolution)) + return resolutions + + +def normalize_resolution(resolution : Tuple[float, float]) -> Resolution: + width, height = resolution + + if width > 0 and height > 0: + normalize_width = round(width / 2) * 2 + normalize_height = round(height / 2) * 2 + return normalize_width, normalize_height + return 0, 0 + + +def pack_resolution(resolution : Resolution) -> str: + width, height = normalize_resolution(resolution) + return str(width) + 'x' + str(height) + + +def unpack_resolution(resolution : str) -> Resolution: + width, height = map(int, resolution.split('x')) + return width, height + + +def detect_frame_orientation(vision_frame : VisionFrame) -> Orientation: + height, width = vision_frame.shape[:2] + + if width > height: + return 'landscape' + return 'portrait' + + +def restrict_frame(vision_frame : VisionFrame, resolution : Resolution) -> VisionFrame: + height, width = vision_frame.shape[:2] + restrict_width, restrict_height = resolution + + if height > restrict_height or width > restrict_width: + scale = min(restrict_height / height, restrict_width / width) + new_width = int(width * scale) + new_height = int(height * scale) + return cv2.resize(vision_frame, (new_width, new_height)) + return vision_frame + + +def fit_frame(vision_frame : VisionFrame, resolution: Resolution) -> VisionFrame: + fit_width, fit_height = resolution + height, width = vision_frame.shape[:2] + scale = min(fit_height / height, fit_width / width) + new_width = int(width * scale) + new_height = int(height * scale) + 
paste_vision_frame = cv2.resize(vision_frame, (new_width, new_height)) + x_pad = (fit_width - new_width) // 2 + y_pad = (fit_height - new_height) // 2 + temp_vision_frame = numpy.pad(paste_vision_frame, ((y_pad, fit_height - new_height - y_pad), (x_pad, fit_width - new_width - x_pad), (0, 0))) + return temp_vision_frame + + +def normalize_frame_color(vision_frame : VisionFrame) -> VisionFrame: + return cv2.cvtColor(vision_frame, cv2.COLOR_BGR2RGB) + + +def conditional_match_frame_color(source_vision_frame : VisionFrame, target_vision_frame : VisionFrame) -> VisionFrame: + histogram_factor = calc_histogram_difference(source_vision_frame, target_vision_frame) + target_vision_frame = blend_vision_frames(target_vision_frame, match_frame_color(source_vision_frame, target_vision_frame), histogram_factor) + return target_vision_frame + + +def match_frame_color(source_vision_frame : VisionFrame, target_vision_frame : VisionFrame) -> VisionFrame: + color_difference_sizes = numpy.linspace(16, target_vision_frame.shape[0], 3, endpoint = False) + + for color_difference_size in color_difference_sizes: + source_vision_frame = equalize_frame_color(source_vision_frame, target_vision_frame, normalize_resolution((color_difference_size, color_difference_size))) + target_vision_frame = equalize_frame_color(source_vision_frame, target_vision_frame, target_vision_frame.shape[:2][::-1]) + return target_vision_frame + + +def equalize_frame_color(source_vision_frame : VisionFrame, target_vision_frame : VisionFrame, size : Size) -> VisionFrame: + source_frame_resize = cv2.resize(source_vision_frame, size, interpolation = cv2.INTER_AREA).astype(numpy.float32) + target_frame_resize = cv2.resize(target_vision_frame, size, interpolation = cv2.INTER_AREA).astype(numpy.float32) + color_difference_vision_frame = numpy.subtract(source_frame_resize, target_frame_resize) + color_difference_vision_frame = cv2.resize(color_difference_vision_frame, target_vision_frame.shape[:2][::-1], interpolation = cv2.INTER_CUBIC) + target_vision_frame = numpy.add(target_vision_frame, color_difference_vision_frame).clip(0, 255).astype(numpy.uint8) + return target_vision_frame + + +def calc_histogram_difference(source_vision_frame : VisionFrame, target_vision_frame : VisionFrame) -> float: + histogram_source = cv2.calcHist([cv2.cvtColor(source_vision_frame, cv2.COLOR_BGR2HSV)], [ 0, 1 ], None, [ 50, 60 ], [ 0, 180, 0, 256 ]) + histogram_target = cv2.calcHist([cv2.cvtColor(target_vision_frame, cv2.COLOR_BGR2HSV)], [ 0, 1 ], None, [ 50, 60 ], [ 0, 180, 0, 256 ]) + histogram_difference = float(numpy.interp(cv2.compareHist(histogram_source, histogram_target, cv2.HISTCMP_CORREL), [ -1, 1 ], [ 0, 1 ])) + return histogram_difference + + +def blend_vision_frames(source_vision_frame : VisionFrame, target_vision_frame : VisionFrame, blend_factor : float) -> VisionFrame: + blend_vision_frame = cv2.addWeighted(source_vision_frame, 1 - blend_factor, target_vision_frame, blend_factor, 0) + return blend_vision_frame + + +def create_tile_frames(vision_frame : VisionFrame, size : Size) -> Tuple[List[VisionFrame], int, int]: + vision_frame = numpy.pad(vision_frame, ((size[1], size[1]), (size[1], size[1]), (0, 0))) + tile_width = size[0] - 2 * size[2] + pad_size_bottom = size[2] + tile_width - vision_frame.shape[0] % tile_width + pad_size_right = size[2] + tile_width - vision_frame.shape[1] % tile_width + pad_vision_frame = numpy.pad(vision_frame, ((size[2], pad_size_bottom), (size[2], pad_size_right), (0, 0))) + pad_height, pad_width = 
pad_vision_frame.shape[:2] + row_range = range(size[2], pad_height - size[2], tile_width) + col_range = range(size[2], pad_width - size[2], tile_width) + tile_vision_frames = [] + + for row_vision_frame in row_range: + top = row_vision_frame - size[2] + bottom = row_vision_frame + size[2] + tile_width + + for column_vision_frame in col_range: + left = column_vision_frame - size[2] + right = column_vision_frame + size[2] + tile_width + tile_vision_frames.append(pad_vision_frame[top:bottom, left:right, :]) + + return tile_vision_frames, pad_width, pad_height + + +def merge_tile_frames(tile_vision_frames : List[VisionFrame], temp_width : int, temp_height : int, pad_width : int, pad_height : int, size : Size) -> VisionFrame: + merge_vision_frame = numpy.zeros((pad_height, pad_width, 3)).astype(numpy.uint8) + tile_width = tile_vision_frames[0].shape[1] - 2 * size[2] + tiles_per_row = min(pad_width // tile_width, len(tile_vision_frames)) + + for index, tile_vision_frame in enumerate(tile_vision_frames): + tile_vision_frame = tile_vision_frame[size[2]:-size[2], size[2]:-size[2]] + row_index = index // tiles_per_row + col_index = index % tiles_per_row + top = row_index * tile_vision_frame.shape[0] + bottom = top + tile_vision_frame.shape[0] + left = col_index * tile_vision_frame.shape[1] + right = left + tile_vision_frame.shape[1] + merge_vision_frame[top:bottom, left:right, :] = tile_vision_frame + + merge_vision_frame = merge_vision_frame[size[1] : size[1] + temp_height, size[1]: size[1] + temp_width, :] + return merge_vision_frame diff --git a/facefusion/voice_extractor.py b/facefusion/voice_extractor.py new file mode 100644 index 0000000000000000000000000000000000000000..6fca54a1d4e5415ca24590167f871937c0f6b129 --- /dev/null +++ b/facefusion/voice_extractor.py @@ -0,0 +1,149 @@ +from functools import lru_cache +from typing import Tuple + +import numpy +import scipy + +from facefusion import inference_manager +from facefusion.download import conditional_download_hashes, conditional_download_sources, resolve_download_url +from facefusion.filesystem import resolve_relative_path +from facefusion.thread_helper import thread_semaphore +from facefusion.types import Audio, AudioChunk, DownloadScope, InferencePool, ModelOptions, ModelSet + + +@lru_cache(maxsize = None) +def create_static_model_set(download_scope : DownloadScope) -> ModelSet: + return\ + { + 'kim_vocal_2': + { + 'hashes': + { + 'voice_extractor': + { + 'url': resolve_download_url('models-3.0.0', 'kim_vocal_2.hash'), + 'path': resolve_relative_path('../.assets/models/kim_vocal_2.hash') + } + }, + 'sources': + { + 'voice_extractor': + { + 'url': resolve_download_url('models-3.0.0', 'kim_vocal_2.onnx'), + 'path': resolve_relative_path('../.assets/models/kim_vocal_2.onnx') + } + } + } + } + + +def get_inference_pool() -> InferencePool: + model_names = [ 'kim_vocal_2' ] + model_source_set = get_model_options().get('sources') + + return inference_manager.get_inference_pool(__name__, model_names, model_source_set) + + +def clear_inference_pool() -> None: + model_names = [ 'kim_vocal_2' ] + inference_manager.clear_inference_pool(__name__, model_names) + + +def get_model_options() -> ModelOptions: + return create_static_model_set('full').get('kim_vocal_2') + + +def pre_check() -> bool: + model_hash_set = get_model_options().get('hashes') + model_source_set = get_model_options().get('sources') + + return conditional_download_hashes(model_hash_set) and conditional_download_sources(model_source_set) + + +def batch_extract_voice(audio : Audio, 
chunk_size : int, step_size : int) -> Audio: + temp_audio = numpy.zeros((audio.shape[0], 2)).astype(numpy.float32) + temp_chunk = numpy.zeros((audio.shape[0], 2)).astype(numpy.float32) + + for start in range(0, audio.shape[0], step_size): + end = min(start + chunk_size, audio.shape[0]) + temp_audio[start:end, ...] += extract_voice(audio[start:end, ...]) + temp_chunk[start:end, ...] += 1 + + audio = temp_audio / temp_chunk + return audio + + +def extract_voice(temp_audio_chunk : AudioChunk) -> AudioChunk: + voice_extractor = get_inference_pool().get('voice_extractor') + chunk_size = (voice_extractor.get_inputs()[0].shape[3] - 1) * 1024 + trim_size = 3840 + temp_audio_chunk, pad_size = prepare_audio_chunk(temp_audio_chunk.T, chunk_size, trim_size) + temp_audio_chunk = decompose_audio_chunk(temp_audio_chunk, trim_size) + temp_audio_chunk = forward(temp_audio_chunk) + temp_audio_chunk = compose_audio_chunk(temp_audio_chunk, trim_size) + temp_audio_chunk = normalize_audio_chunk(temp_audio_chunk, chunk_size, trim_size, pad_size) + return temp_audio_chunk + + +def forward(temp_audio_chunk : AudioChunk) -> AudioChunk: + voice_extractor = get_inference_pool().get('voice_extractor') + + with thread_semaphore(): + temp_audio_chunk = voice_extractor.run(None, + { + 'input': temp_audio_chunk + })[0] + + return temp_audio_chunk + + +def prepare_audio_chunk(temp_audio_chunk : AudioChunk, chunk_size : int, trim_size : int) -> Tuple[AudioChunk, int]: + step_size = chunk_size - 2 * trim_size + pad_size = step_size - temp_audio_chunk.shape[1] % step_size + audio_chunk_size = temp_audio_chunk.shape[1] + pad_size + temp_audio_chunk = temp_audio_chunk.astype(numpy.float32) / numpy.iinfo(numpy.int16).max + temp_audio_chunk = numpy.pad(temp_audio_chunk, ((0, 0), (trim_size, trim_size + pad_size))) + temp_audio_chunks = [] + + for index in range(0, audio_chunk_size, step_size): + temp_audio_chunks.append(temp_audio_chunk[:, index:index + chunk_size]) + + temp_audio_chunk = numpy.concatenate(temp_audio_chunks, axis = 0) + temp_audio_chunk = temp_audio_chunk.reshape((-1, chunk_size)) + return temp_audio_chunk, pad_size + + +def decompose_audio_chunk(temp_audio_chunk : AudioChunk, trim_size : int) -> AudioChunk: + frame_size = 7680 + frame_overlap = 6656 + frame_total = 3072 + bin_total = 256 + channel_total = 4 + window = scipy.signal.windows.hann(frame_size) + temp_audio_chunk = scipy.signal.stft(temp_audio_chunk, nperseg = frame_size, noverlap = frame_overlap, window = window)[2] + temp_audio_chunk = numpy.stack((numpy.real(temp_audio_chunk), numpy.imag(temp_audio_chunk)), axis = -1).transpose((0, 3, 1, 2)) + temp_audio_chunk = temp_audio_chunk.reshape(-1, 2, 2, trim_size + 1, bin_total).reshape(-1, channel_total, trim_size + 1, bin_total) + temp_audio_chunk = temp_audio_chunk[:, :, :frame_total] + temp_audio_chunk /= numpy.sqrt(1.0 / window.sum() ** 2) + return temp_audio_chunk + + +def compose_audio_chunk(temp_audio_chunk : AudioChunk, trim_size : int) -> AudioChunk: + frame_size = 7680 + frame_overlap = 6656 + frame_total = 3072 + bin_total = 256 + window = scipy.signal.windows.hann(frame_size) + temp_audio_chunk = numpy.pad(temp_audio_chunk, ((0, 0), (0, 0), (0, trim_size + 1 - frame_total), (0, 0))) + temp_audio_chunk = temp_audio_chunk.reshape(-1, 2, trim_size + 1, bin_total).transpose((0, 2, 3, 1)) + temp_audio_chunk = temp_audio_chunk[:, :, :, 0] + 1j * temp_audio_chunk[:, :, :, 1] + temp_audio_chunk = scipy.signal.istft(temp_audio_chunk, nperseg = frame_size, noverlap = frame_overlap, window = window)[1] 
+ temp_audio_chunk *= numpy.sqrt(1.0 / window.sum() ** 2) + return temp_audio_chunk + + +def normalize_audio_chunk(temp_audio_chunk : AudioChunk, chunk_size : int, trim_size : int, pad_size : int) -> AudioChunk: + temp_audio_chunk = temp_audio_chunk.reshape((-1, 2, chunk_size)) + temp_audio_chunk = temp_audio_chunk[:, :, trim_size:-trim_size].transpose(1, 0, 2) + temp_audio_chunk = temp_audio_chunk.reshape(2, -1)[:, :-pad_size].T + return temp_audio_chunk diff --git a/facefusion/wording.py b/facefusion/wording.py new file mode 100644 index 0000000000000000000000000000000000000000..8094f104b9b00a6ea007d8669b180c20e34684c1 --- /dev/null +++ b/facefusion/wording.py @@ -0,0 +1,361 @@ +from typing import Any, Dict, Optional + +WORDING : Dict[str, Any] =\ +{ + 'conda_not_activated': 'Conda is not activated', + 'python_not_supported': 'Python version is not supported, upgrade to {version} or higher', + 'curl_not_installed': 'cURL is not installed', + 'ffmpeg_not_installed': 'FFMpeg is not installed', + 'creating_temp': 'Creating temporary resources', + 'extracting_frames': 'Extracting frames with a resolution of {resolution} and {fps} frames per second', + 'extracting_frames_succeed': 'Extracting frames succeed', + 'extracting_frames_failed': 'Extracting frames failed', + 'analysing': 'Analysing', + 'extracting': 'Extracting', + 'streaming': 'Streaming', + 'processing': 'Processing', + 'merging': 'Merging', + 'downloading': 'Downloading', + 'temp_frames_not_found': 'Temporary frames not found', + 'copying_image': 'Copying image with a resolution of {resolution}', + 'copying_image_succeed': 'Copying image succeed', + 'copying_image_failed': 'Copying image failed', + 'finalizing_image': 'Finalizing image with a resolution of {resolution}', + 'finalizing_image_succeed': 'Finalizing image succeed', + 'finalizing_image_skipped': 'Finalizing image skipped', + 'merging_video': 'Merging video with a resolution of {resolution} and {fps} frames per second', + 'merging_video_succeed': 'Merging video succeed', + 'merging_video_failed': 'Merging video failed', + 'skipping_audio': 'Skipping audio', + 'replacing_audio_succeed': 'Replacing audio succeed', + 'replacing_audio_skipped': 'Replacing audio skipped', + 'restoring_audio_succeed': 'Restoring audio succeed', + 'restoring_audio_skipped': 'Restoring audio skipped', + 'clearing_temp': 'Clearing temporary resources', + 'processing_stopped': 'Processing stopped', + 'processing_image_succeed': 'Processing to image succeed in {seconds} seconds', + 'processing_image_failed': 'Processing to image failed', + 'processing_video_succeed': 'Processing to video succeed in {seconds} seconds', + 'processing_video_failed': 'Processing to video failed', + 'choose_image_source': 'Choose a image for the source', + 'choose_audio_source': 'Choose a audio for the source', + 'choose_video_target': 'Choose a video for the target', + 'choose_image_or_video_target': 'Choose a image or video for the target', + 'specify_image_or_video_output': 'Specify the output image or video within a directory', + 'match_target_and_output_extension': 'Match the target and output extension', + 'no_source_face_detected': 'No source face detected', + 'processor_not_loaded': 'Processor {processor} could not be loaded', + 'processor_not_implemented': 'Processor {processor} not implemented correctly', + 'ui_layout_not_loaded': 'UI layout {ui_layout} could not be loaded', + 'ui_layout_not_implemented': 'UI layout {ui_layout} not implemented correctly', + 'stream_not_loaded': 'Stream {stream_mode} could 
not be loaded', + 'stream_not_supported': 'Stream not supported', + 'job_created': 'Job {job_id} created', + 'job_not_created': 'Job {job_id} not created', + 'job_submitted': 'Job {job_id} submitted', + 'job_not_submitted': 'Job {job_id} not submitted', + 'job_all_submitted': 'Jobs submitted', + 'job_all_not_submitted': 'Jobs not submitted', + 'job_deleted': 'Job {job_id} deleted', + 'job_not_deleted': 'Job {job_id} not deleted', + 'job_all_deleted': 'Jobs deleted', + 'job_all_not_deleted': 'Jobs not deleted', + 'job_step_added': 'Step added to job {job_id}', + 'job_step_not_added': 'Step not added to job {job_id}', + 'job_remix_step_added': 'Step {step_index} remixed from job {job_id}', + 'job_remix_step_not_added': 'Step {step_index} not remixed from job {job_id}', + 'job_step_inserted': 'Step {step_index} inserted to job {job_id}', + 'job_step_not_inserted': 'Step {step_index} not inserted to job {job_id}', + 'job_step_removed': 'Step {step_index} removed from job {job_id}', + 'job_step_not_removed': 'Step {step_index} not removed from job {job_id}', + 'running_job': 'Running queued job {job_id}', + 'running_jobs': 'Running all queued jobs', + 'retrying_job': 'Retrying failed job {job_id}', + 'retrying_jobs': 'Retrying all failed jobs', + 'processing_job_succeed': 'Processing of job {job_id} succeed', + 'processing_jobs_succeed': 'Processing of all job succeed', + 'processing_job_failed': 'Processing of job {job_id} failed', + 'processing_jobs_failed': 'Processing of all jobs failed', + 'processing_step': 'Processing step {step_current} of {step_total}', + 'validating_hash_succeed': 'Validating hash for {hash_file_name} succeed', + 'validating_hash_failed': 'Validating hash for {hash_file_name} failed', + 'validating_source_succeed': 'Validating source for {source_file_name} succeed', + 'validating_source_failed': 'Validating source for {source_file_name} failed', + 'deleting_corrupt_source': 'Deleting corrupt source for {source_file_name}', + 'time_ago_now': 'just now', + 'time_ago_minutes': '{minutes} minutes ago', + 'time_ago_hours': '{hours} hours and {minutes} minutes ago', + 'time_ago_days': '{days} days, {hours} hours and {minutes} minutes ago', + 'point': '.', + 'comma': ',', + 'colon': ':', + 'question_mark': '?', + 'exclamation_mark': '!', + 'help': + { + # installer + 'install_dependency': 'choose the variant of {dependency} to install', + 'skip_conda': 'skip the conda environment check', + # paths + 'config_path': 'choose the config file to override defaults', + 'temp_path': 'specify the directory for the temporary resources', + 'jobs_path': 'specify the directory to store jobs', + 'source_paths': 'choose the image or audio paths', + 'target_path': 'choose the image or video path', + 'output_path': 'specify the image or video within a directory', + # patterns + 'source_pattern': 'choose the image or audio pattern', + 'target_pattern': 'choose the image or video pattern', + 'output_pattern': 'specify the image or video pattern', + # face detector + 'face_detector_model': 'choose the model responsible for detecting the faces', + 'face_detector_size': 'specify the frame size provided to the face detector', + 'face_detector_angles': 'specify the angles to rotate the frame before detecting faces', + 'face_detector_score': 'filter the detected faces base on the confidence score', + # face landmarker + 'face_landmarker_model': 'choose the model responsible for detecting the face landmarks', + 'face_landmarker_score': 'filter the detected face landmarks base on the confidence 
score', + # face selector + 'face_selector_mode': 'use reference based tracking or simple matching', + 'face_selector_order': 'specify the order of the detected faces', + 'face_selector_age_start': 'filter the detected faces based the starting age', + 'face_selector_age_end': 'filter the detected faces based the ending age', + 'face_selector_gender': 'filter the detected faces based on their gender', + 'face_selector_race': 'filter the detected faces based on their race', + 'reference_face_position': 'specify the position used to create the reference face', + 'reference_face_distance': 'specify the similarity between the reference face and target face', + 'reference_frame_number': 'specify the frame used to create the reference face', + # face masker + 'face_occluder_model': 'choose the model responsible for the occlusion mask', + 'face_parser_model': 'choose the model responsible for the region mask', + 'face_mask_types': 'mix and match different face mask types (choices: {choices})', + 'face_mask_areas': 'choose the items used for the area mask (choices: {choices})', + 'face_mask_regions': 'choose the items used for the region mask (choices: {choices})', + 'face_mask_blur': 'specify the degree of blur applied to the box mask', + 'face_mask_padding': 'apply top, right, bottom and left padding to the box mask', + # frame extraction + 'trim_frame_start': 'specify the starting frame of the target video', + 'trim_frame_end': 'specify the ending frame of the target video', + 'temp_frame_format': 'specify the temporary resources format', + 'keep_temp': 'keep the temporary resources after processing', + # output creation + 'output_image_quality': 'specify the image quality which translates to the image compression', + 'output_image_resolution': 'specify the image resolution based on the target image', + 'output_audio_encoder': 'specify the encoder used for the audio', + 'output_audio_quality': 'specify the audio quality which translates to the audio compression', + 'output_audio_volume': 'specify the audio volume based on the target video', + 'output_video_encoder': 'specify the encoder used for the video', + 'output_video_preset': 'balance fast video processing and video file size', + 'output_video_quality': 'specify the video quality which translates to the video compression', + 'output_video_resolution': 'specify the video resolution based on the target video', + 'output_video_fps': 'specify the video fps based on the target video', + # processors + 'processors': 'load a single or multiple processors (choices: {choices}, ...)', + 'age_modifier_model': 'choose the model responsible for aging the face', + 'age_modifier_direction': 'specify the direction in which the age should be modified', + 'deep_swapper_model': 'choose the model responsible for swapping the face', + 'deep_swapper_morph': 'morph between source face and target faces', + 'expression_restorer_model': 'choose the model responsible for restoring the expression', + 'expression_restorer_factor': 'restore factor of expression from the target face', + 'face_debugger_items': 'load a single or multiple processors (choices: {choices})', + 'face_editor_model': 'choose the model responsible for editing the face', + 'face_editor_eyebrow_direction': 'specify the eyebrow direction', + 'face_editor_eye_gaze_horizontal': 'specify the horizontal eye gaze', + 'face_editor_eye_gaze_vertical': 'specify the vertical eye gaze', + 'face_editor_eye_open_ratio': 'specify the ratio of eye opening', + 'face_editor_lip_open_ratio': 'specify the ratio of 
lip opening', + 'face_editor_mouth_grim': 'specify the mouth grim', + 'face_editor_mouth_pout': 'specify the mouth pout', + 'face_editor_mouth_purse': 'specify the mouth purse', + 'face_editor_mouth_smile': 'specify the mouth smile', + 'face_editor_mouth_position_horizontal': 'specify the horizontal mouth position', + 'face_editor_mouth_position_vertical': 'specify the vertical mouth position', + 'face_editor_head_pitch': 'specify the head pitch', + 'face_editor_head_yaw': 'specify the head yaw', + 'face_editor_head_roll': 'specify the head roll', + 'face_enhancer_model': 'choose the model responsible for enhancing the face', + 'face_enhancer_blend': 'blend the enhanced into the previous face', + 'face_enhancer_weight': 'specify the degree of weight applied to the face', + 'face_swapper_model': 'choose the model responsible for swapping the face', + 'face_swapper_pixel_boost': 'choose the pixel boost resolution for the face swapper', + 'frame_colorizer_model': 'choose the model responsible for colorizing the frame', + 'frame_colorizer_size': 'specify the frame size provided to the frame colorizer', + 'frame_colorizer_blend': 'blend the colorized into the previous frame', + 'frame_enhancer_model': 'choose the model responsible for enhancing the frame', + 'frame_enhancer_blend': 'blend the enhanced into the previous frame', + 'lip_syncer_model': 'choose the model responsible for syncing the lips', + 'lip_syncer_weight': 'specify the degree of weight applied to the lips', + # uis + 'open_browser': 'open the browser once the program is ready', + 'ui_layouts': 'launch a single or multiple UI layouts (choices: {choices}, ...)', + 'ui_workflow': 'choose the ui workflow', + # download + 'download_providers': 'download using different providers (choices: {choices}, ...)', + 'download_scope': 'specify the download scope', + # benchmark + 'benchmark_resolutions': 'choose the resolutions for the benchmarks (choices: {choices}, ...)', + 'benchmark_cycle_count': 'specify the amount of cycles per benchmark', + # execution + 'execution_device_id': 'specify the device used for processing', + 'execution_providers': 'inference using different providers (choices: {choices}, ...)', + 'execution_thread_count': 'specify the amount of parallel threads while processing', + 'execution_queue_count': 'specify the amount of frames each thread is processing', + # memory + 'video_memory_strategy': 'balance fast processing and low VRAM usage', + 'system_memory_limit': 'limit the available RAM that can be used while processing', + # misc + 'log_level': 'adjust the message severity displayed in the terminal', + 'halt_on_error': 'halt the program once an error occurred', + # run + 'run': 'run the program', + 'headless_run': 'run the program in headless mode', + 'batch_run': 'run the program in batch mode', + 'force_download': 'force automate downloads and exit', + 'benchmark': 'benchmark the program', + # jobs + 'job_id': 'specify the job id', + 'job_status': 'specify the job status', + 'step_index': 'specify the step index', + # job manager + 'job_list': 'list jobs by status', + 'job_create': 'create a drafted job', + 'job_submit': 'submit a drafted job to become a queued job', + 'job_submit_all': 'submit all drafted jobs to become a queued jobs', + 'job_delete': 'delete a drafted, queued, failed or completed job', + 'job_delete_all': 'delete all drafted, queued, failed and completed jobs', + 'job_add_step': 'add a step to a drafted job', + 'job_remix_step': 'remix a previous step from a drafted job', + 'job_insert_step': 
'insert a step to a drafted job', + 'job_remove_step': 'remove a step from a drafted job', + # job runner + 'job_run': 'run a queued job', + 'job_run_all': 'run all queued jobs', + 'job_retry': 'retry a failed job', + 'job_retry_all': 'retry all failed jobs' + }, + 'about': + { + 'become_a_member': 'become a member', + 'join_our_community': 'join our community', + 'read_the_documentation': 'read the documentation' + }, + 'uis': + { + 'age_modifier_direction_slider': 'AGE MODIFIER DIRECTION', + 'age_modifier_model_dropdown': 'AGE MODIFIER MODEL', + 'apply_button': 'APPLY', + 'benchmark_cycle_count_slider': 'BENCHMARK CYCLE COUNT', + 'benchmark_resolutions_checkbox_group': 'BENCHMARK RESOLUTIONS', + 'clear_button': 'CLEAR', + 'common_options_checkbox_group': 'OPTIONS', + 'download_providers_checkbox_group': 'DOWNLOAD PROVIDERS', + 'deep_swapper_model_dropdown': 'DEEP SWAPPER MODEL', + 'deep_swapper_morph_slider': 'DEEP SWAPPER MORPH', + 'execution_providers_checkbox_group': 'EXECUTION PROVIDERS', + 'execution_queue_count_slider': 'EXECUTION QUEUE COUNT', + 'execution_thread_count_slider': 'EXECUTION THREAD COUNT', + 'expression_restorer_factor_slider': 'EXPRESSION RESTORER FACTOR', + 'expression_restorer_model_dropdown': 'EXPRESSION RESTORER MODEL', + 'face_debugger_items_checkbox_group': 'FACE DEBUGGER ITEMS', + 'face_detector_angles_checkbox_group': 'FACE DETECTOR ANGLES', + 'face_detector_model_dropdown': 'FACE DETECTOR MODEL', + 'face_detector_score_slider': 'FACE DETECTOR SCORE', + 'face_detector_size_dropdown': 'FACE DETECTOR SIZE', + 'face_editor_eyebrow_direction_slider': 'FACE EDITOR EYEBROW DIRECTION', + 'face_editor_eye_gaze_horizontal_slider': 'FACE EDITOR EYE GAZE HORIZONTAL', + 'face_editor_eye_gaze_vertical_slider': 'FACE EDITOR EYE GAZE VERTICAL', + 'face_editor_eye_open_ratio_slider': 'FACE EDITOR EYE OPEN RATIO', + 'face_editor_head_pitch_slider': 'FACE EDITOR HEAD PITCH', + 'face_editor_head_roll_slider': 'FACE EDITOR HEAD ROLL', + 'face_editor_head_yaw_slider': 'FACE EDITOR HEAD YAW', + 'face_editor_lip_open_ratio_slider': 'FACE EDITOR LIP OPEN RATIO', + 'face_editor_model_dropdown': 'FACE EDITOR MODEL', + 'face_editor_mouth_grim_slider': 'FACE EDITOR MOUTH GRIM', + 'face_editor_mouth_position_horizontal_slider': 'FACE EDITOR MOUTH POSITION HORIZONTAL', + 'face_editor_mouth_position_vertical_slider': 'FACE EDITOR MOUTH POSITION VERTICAL', + 'face_editor_mouth_pout_slider': 'FACE EDITOR MOUTH POUT', + 'face_editor_mouth_purse_slider': 'FACE EDITOR MOUTH PURSE', + 'face_editor_mouth_smile_slider': 'FACE EDITOR MOUTH SMILE', + 'face_enhancer_blend_slider': 'FACE ENHANCER BLEND', + 'face_enhancer_model_dropdown': 'FACE ENHANCER MODEL', + 'face_enhancer_weight_slider': 'FACE ENHANCER WEIGHT', + 'face_landmarker_model_dropdown': 'FACE LANDMARKER MODEL', + 'face_landmarker_score_slider': 'FACE LANDMARKER SCORE', + 'face_mask_blur_slider': 'FACE MASK BLUR', + 'face_mask_padding_bottom_slider': 'FACE MASK PADDING BOTTOM', + 'face_mask_padding_left_slider': 'FACE MASK PADDING LEFT', + 'face_mask_padding_right_slider': 'FACE MASK PADDING RIGHT', + 'face_mask_padding_top_slider': 'FACE MASK PADDING TOP', + 'face_mask_areas_checkbox_group': 'FACE MASK AREAS', + 'face_mask_regions_checkbox_group': 'FACE MASK REGIONS', + 'face_mask_types_checkbox_group': 'FACE MASK TYPES', + 'face_selector_age_range_slider': 'FACE SELECTOR AGE', + 'face_selector_gender_dropdown': 'FACE SELECTOR GENDER', + 'face_selector_mode_dropdown': 'FACE SELECTOR MODE', + 'face_selector_order_dropdown': 'FACE 
SELECTOR ORDER', + 'face_selector_race_dropdown': 'FACE SELECTOR RACE', + 'face_swapper_model_dropdown': 'FACE SWAPPER MODEL', + 'face_swapper_pixel_boost_dropdown': 'FACE SWAPPER PIXEL BOOST', + 'face_occluder_model_dropdown': 'FACE OCCLUDER MODEL', + 'face_parser_model_dropdown': 'FACE PARSER MODEL', + 'frame_colorizer_blend_slider': 'FRAME COLORIZER BLEND', + 'frame_colorizer_model_dropdown': 'FRAME COLORIZER MODEL', + 'frame_colorizer_size_dropdown': 'FRAME COLORIZER SIZE', + 'frame_enhancer_blend_slider': 'FRAME ENHANCER BLEND', + 'frame_enhancer_model_dropdown': 'FRAME ENHANCER MODEL', + 'job_list_status_checkbox_group': 'JOB STATUS', + 'job_manager_job_action_dropdown': 'JOB_ACTION', + 'job_manager_job_id_dropdown': 'JOB ID', + 'job_manager_step_index_dropdown': 'STEP INDEX', + 'job_runner_job_action_dropdown': 'JOB ACTION', + 'job_runner_job_id_dropdown': 'JOB ID', + 'lip_syncer_model_dropdown': 'LIP SYNCER MODEL', + 'lip_syncer_weight_slider': 'LIP SYNCER WEIGHT', + 'log_level_dropdown': 'LOG LEVEL', + 'output_audio_encoder_dropdown': 'OUTPUT AUDIO ENCODER', + 'output_audio_quality_slider': 'OUTPUT AUDIO QUALITY', + 'output_audio_volume_slider': 'OUTPUT AUDIO VOLUME', + 'output_image_or_video': 'OUTPUT', + 'output_image_quality_slider': 'OUTPUT IMAGE QUALITY', + 'output_image_resolution_dropdown': 'OUTPUT IMAGE RESOLUTION', + 'output_path_textbox': 'OUTPUT PATH', + 'output_video_encoder_dropdown': 'OUTPUT VIDEO ENCODER', + 'output_video_fps_slider': 'OUTPUT VIDEO FPS', + 'output_video_preset_dropdown': 'OUTPUT VIDEO PRESET', + 'output_video_quality_slider': 'OUTPUT VIDEO QUALITY', + 'output_video_resolution_dropdown': 'OUTPUT VIDEO RESOLUTION', + 'preview_frame_slider': 'PREVIEW FRAME', + 'preview_image': 'PREVIEW', + 'processors_checkbox_group': 'PROCESSORS', + 'reference_face_distance_slider': 'REFERENCE FACE DISTANCE', + 'reference_face_gallery': 'REFERENCE FACE', + 'refresh_button': 'REFRESH', + 'source_file': 'SOURCE', + 'start_button': 'START', + 'stop_button': 'STOP', + 'system_memory_limit_slider': 'SYSTEM MEMORY LIMIT', + 'target_file': 'TARGET', + 'temp_frame_format_dropdown': 'TEMP FRAME FORMAT', + 'terminal_textbox': 'TERMINAL', + 'trim_frame_slider': 'TRIM FRAME', + 'ui_workflow': 'UI WORKFLOW', + 'video_memory_strategy_dropdown': 'VIDEO MEMORY STRATEGY', + 'webcam_fps_slider': 'WEBCAM FPS', + 'webcam_image': 'WEBCAM', + 'webcam_device_id_dropdown': 'WEBCAM DEVICE ID', + 'webcam_mode_radio': 'WEBCAM MODE', + 'webcam_resolution_dropdown': 'WEBCAM RESOLUTION' + } +} + + +def get(notation : str) -> Optional[str]: + current = WORDING + + for fragment in notation.split('.'): + if fragment in current: + current = current.get(fragment) + if isinstance(current, str): + return current + + return None diff --git a/install.py b/install.py new file mode 100644 index 0000000000000000000000000000000000000000..000f1e72b9457cc04fae4212479dc2ae10984fa2 --- /dev/null +++ b/install.py @@ -0,0 +1,10 @@ +#!/usr/bin/env python3 + +import os + +os.environ['SYSTEM_VERSION_COMPAT'] = '0' + +from facefusion import installer + +if __name__ == '__main__': + installer.cli() diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 0000000000000000000000000000000000000000..64218bc23688632a08c98ec4a0451ed46f8ed5e5 --- /dev/null +++ b/mypy.ini @@ -0,0 +1,7 @@ +[mypy] +check_untyped_defs = True +disallow_any_generics = True +disallow_untyped_calls = True +disallow_untyped_defs = True +ignore_missing_imports = True +strict_optional = False diff --git a/requirements.txt b/requirements.txt new file 
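# A minimal usage sketch for the wording helper above, assuming the facefusion package is
# importable: get() walks the nested WORDING dictionary along dot notation and returns None
# when a fragment is missing or the resolved value is not a string.
from facefusion import wording

assert wording.get('processing') == 'Processing'
assert wording.get('help.skip_conda') == 'skip the conda environment check'
assert wording.get('uis.start_button') == 'START'
assert wording.get('help') is None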
mode 100644 index 0000000000000000000000000000000000000000..77c061a43b9eee1710de5a0ae1d8ad9462c9fb26 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,9 @@ +gradio-rangeslider==0.0.8 +gradio==5.25.2 +numpy==2.2.4 +onnx==1.17.0 +onnxruntime==1.22.0 +opencv-python==4.11.0.86 +psutil==7.0.0 +tqdm==4.67.1 +scipy==1.15.2 diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tests/helper.py b/tests/helper.py new file mode 100644 index 0000000000000000000000000000000000000000..8902643c40f06f517355a6359cb192e42e0702ab --- /dev/null +++ b/tests/helper.py @@ -0,0 +1,44 @@ +import os +import tempfile + +from facefusion.filesystem import create_directory, is_directory, is_file, remove_directory +from facefusion.types import JobStatus + + +def is_test_job_file(file_path : str, job_status : JobStatus) -> bool: + return is_file(get_test_job_file(file_path, job_status)) + + +def get_test_job_file(file_path : str, job_status : JobStatus) -> str: + return os.path.join(get_test_jobs_directory(), job_status, file_path) + + +def get_test_jobs_directory() -> str: + return os.path.join(tempfile.gettempdir(), 'facefusion-test-jobs') + + +def get_test_example_file(file_path : str) -> str: + return os.path.join(get_test_examples_directory(), file_path) + + +def get_test_examples_directory() -> str: + return os.path.join(tempfile.gettempdir(), 'facefusion-test-examples') + + +def is_test_output_file(file_path : str) -> bool: + return is_file(get_test_output_file(file_path)) + + +def get_test_output_file(file_path : str) -> str: + return os.path.join(get_test_outputs_directory(), file_path) + + +def get_test_outputs_directory() -> str: + return os.path.join(tempfile.gettempdir(), 'facefusion-test-outputs') + + +def prepare_test_output_directory() -> bool: + test_outputs_directory = get_test_outputs_directory() + remove_directory(test_outputs_directory) + create_directory(test_outputs_directory) + return is_directory(test_outputs_directory) diff --git a/tests/test_audio.py b/tests/test_audio.py new file mode 100644 index 0000000000000000000000000000000000000000..36faf9b035c8aae2c3d99b177e75218823d1a711 --- /dev/null +++ b/tests/test_audio.py @@ -0,0 +1,28 @@ +import subprocess + +import pytest + +from facefusion.audio import get_audio_frame, read_static_audio +from facefusion.download import conditional_download +from .helper import get_test_example_file, get_test_examples_directory + + +@pytest.fixture(scope = 'module', autouse = True) +def before_all() -> None: + conditional_download(get_test_examples_directory(), + [ + 'https://github.com/facefusion/facefusion-assets/releases/download/examples-3.0.0/source.mp3' + ]) + subprocess.run([ 'ffmpeg', '-i', get_test_example_file('source.mp3'), get_test_example_file('source.wav') ]) + + +def test_get_audio_frame() -> None: + assert hasattr(get_audio_frame(get_test_example_file('source.mp3'), 25), '__array_interface__') + assert hasattr(get_audio_frame(get_test_example_file('source.wav'), 25), '__array_interface__') + assert get_audio_frame('invalid', 25) is None + + +def test_read_static_audio() -> None: + assert len(read_static_audio(get_test_example_file('source.mp3'), 25)) == 280 + assert len(read_static_audio(get_test_example_file('source.wav'), 25)) == 280 + assert read_static_audio('invalid', 25) is None diff --git a/tests/test_cli_age_modifier.py b/tests/test_cli_age_modifier.py new file mode 100644 index 
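# A short sketch of how the test helpers above compose their paths (file names are
# placeholders), assuming the tests package is importable: every artifact lives under the
# system temp directory, so test runs never write into the repository checkout.
import os
import tempfile

from tests.helper import get_test_job_file, get_test_output_file

assert get_test_job_file('demo.json', 'queued') == os.path.join(tempfile.gettempdir(), 'facefusion-test-jobs', 'queued', 'demo.json')
assert get_test_output_file('demo.jpg') == os.path.join(tempfile.gettempdir(), 'facefusion-test-outputs', 'demo.jpg')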
0000000000000000000000000000000000000000..2184ac8840a54615584799816d92b0fb3ab32dce --- /dev/null +++ b/tests/test_cli_age_modifier.py @@ -0,0 +1,38 @@ +import subprocess +import sys + +import pytest + +from facefusion.download import conditional_download +from facefusion.jobs.job_manager import clear_jobs, init_jobs +from .helper import get_test_example_file, get_test_examples_directory, get_test_jobs_directory, get_test_output_file, is_test_output_file, prepare_test_output_directory + + +@pytest.fixture(scope = 'module', autouse = True) +def before_all() -> None: + conditional_download(get_test_examples_directory(), + [ + 'https://github.com/facefusion/facefusion-assets/releases/download/examples-3.0.0/target-240p.mp4' + ]) + subprocess.run([ 'ffmpeg', '-i', get_test_example_file('target-240p.mp4'), '-vframes', '1', get_test_example_file('target-240p.jpg') ]) + + +@pytest.fixture(scope = 'function', autouse = True) +def before_each() -> None: + clear_jobs(get_test_jobs_directory()) + init_jobs(get_test_jobs_directory()) + prepare_test_output_directory() + + +def test_modify_age_to_image() -> None: + commands = [ sys.executable, 'facefusion.py', 'headless-run', '--jobs-path', get_test_jobs_directory(), '--processors', 'age_modifier', '--age-modifier-direction', '100', '-t', get_test_example_file('target-240p.jpg'), '-o', get_test_output_file('test-age-face-to-image.jpg') ] + + assert subprocess.run(commands).returncode == 0 + assert is_test_output_file('test-age-face-to-image.jpg') is True + + +def test_modify_age_to_video() -> None: + commands = [ sys.executable, 'facefusion.py', 'headless-run', '--jobs-path', get_test_jobs_directory(), '--processors', 'age_modifier', '--age-modifier-direction', '100', '-t', get_test_example_file('target-240p.mp4'), '-o', get_test_output_file('test-age-face-to-video.mp4'), '--trim-frame-end', '1' ] + + assert subprocess.run(commands).returncode == 0 + assert is_test_output_file('test-age-face-to-video.mp4') is True diff --git a/tests/test_cli_batch_runner.py b/tests/test_cli_batch_runner.py new file mode 100644 index 0000000000000000000000000000000000000000..963e5a95494835258f082c11374d760fa7965282 --- /dev/null +++ b/tests/test_cli_batch_runner.py @@ -0,0 +1,45 @@ +import subprocess +import sys + +import pytest + +from facefusion.download import conditional_download +from facefusion.jobs.job_manager import clear_jobs, init_jobs +from .helper import get_test_example_file, get_test_examples_directory, get_test_jobs_directory, get_test_output_file, is_test_output_file, prepare_test_output_directory + + +@pytest.fixture(scope = 'module', autouse = True) +def before_all() -> None: + conditional_download(get_test_examples_directory(), + [ + 'https://github.com/facefusion/facefusion-assets/releases/download/examples-3.0.0/target-240p.mp4' + ]) + subprocess.run([ 'ffmpeg', '-i', get_test_example_file('target-240p.mp4'), '-vframes', '1', get_test_example_file('target-240p-batch-1.jpg') ]) + subprocess.run([ 'ffmpeg', '-i', get_test_example_file('target-240p.mp4'), '-vframes', '2', get_test_example_file('target-240p-batch-2.jpg') ]) + + +@pytest.fixture(scope = 'function', autouse = True) +def before_each() -> None: + clear_jobs(get_test_jobs_directory()) + init_jobs(get_test_jobs_directory()) + prepare_test_output_directory() + + +def test_batch_run_targets() -> None: + commands = [ sys.executable, 'facefusion.py', 'batch-run', '--jobs-path', get_test_jobs_directory(), '--processors', 'face_debugger', '-t', get_test_example_file('target-240p-batch-*.jpg'), '-o', 
get_test_output_file('test-batch-run-targets-{index}.jpg') ] + + assert subprocess.run(commands).returncode == 0 + assert is_test_output_file('test-batch-run-targets-0.jpg') is True + assert is_test_output_file('test-batch-run-targets-1.jpg') is True + assert is_test_output_file('test-batch-run-targets-2.jpg') is False + + +def test_batch_run_sources_to_targets() -> None: + commands = [ sys.executable, 'facefusion.py', 'batch-run', '--jobs-path', get_test_jobs_directory(), '-s', get_test_example_file('target-240p-batch-*.jpg'), '-t', get_test_example_file('target-240p-batch-*.jpg'), '-o', get_test_output_file('test-batch-run-sources-to-targets-{index}.jpg') ] + + assert subprocess.run(commands).returncode == 0 + assert is_test_output_file('test-batch-run-sources-to-targets-0.jpg') is True + assert is_test_output_file('test-batch-run-sources-to-targets-1.jpg') is True + assert is_test_output_file('test-batch-run-sources-to-targets-2.jpg') is True + assert is_test_output_file('test-batch-run-sources-to-targets-3.jpg') is True + assert is_test_output_file('test-batch-run-sources-to-targets-4.jpg') is False diff --git a/tests/test_cli_expression_restorer.py b/tests/test_cli_expression_restorer.py new file mode 100644 index 0000000000000000000000000000000000000000..236cf78f3b3583e5ba46de370ebab8dc2f871282 --- /dev/null +++ b/tests/test_cli_expression_restorer.py @@ -0,0 +1,38 @@ +import subprocess +import sys + +import pytest + +from facefusion.download import conditional_download +from facefusion.jobs.job_manager import clear_jobs, init_jobs +from .helper import get_test_example_file, get_test_examples_directory, get_test_jobs_directory, get_test_output_file, is_test_output_file, prepare_test_output_directory + + +@pytest.fixture(scope = 'module', autouse = True) +def before_all() -> None: + conditional_download(get_test_examples_directory(), + [ + 'https://github.com/facefusion/facefusion-assets/releases/download/examples-3.0.0/target-240p.mp4' + ]) + subprocess.run([ 'ffmpeg', '-i', get_test_example_file('target-240p.mp4'), '-vframes', '1', get_test_example_file('target-240p.jpg') ]) + + +@pytest.fixture(scope = 'function', autouse = True) +def before_each() -> None: + clear_jobs(get_test_jobs_directory()) + init_jobs(get_test_jobs_directory()) + prepare_test_output_directory() + + +def test_restore_expression_to_image() -> None: + commands = [ sys.executable, 'facefusion.py', 'headless-run', '--jobs-path', get_test_jobs_directory(), '--processors', 'expression_restorer', '-t', get_test_example_file('target-240p.jpg'), '-o', get_test_output_file('test-restore-expression-to-image.jpg') ] + + assert subprocess.run(commands).returncode == 0 + assert is_test_output_file('test-restore-expression-to-image.jpg') is True + + +def test_restore_expression_to_video() -> None: + commands = [ sys.executable, 'facefusion.py', 'headless-run', '--jobs-path', get_test_jobs_directory(), '--processors', 'expression_restorer', '-t', get_test_example_file('target-240p.mp4'), '-o', get_test_output_file('test-restore-expression-to-video.mp4'), '--trim-frame-end', '1' ] + + assert subprocess.run(commands).returncode == 0 + assert is_test_output_file('test-restore-expression-to-video.mp4') is True diff --git a/tests/test_cli_face_debugger.py b/tests/test_cli_face_debugger.py new file mode 100644 index 0000000000000000000000000000000000000000..b393881d4f485c65d1781c682c2b83c5a8ef6d15 --- /dev/null +++ b/tests/test_cli_face_debugger.py @@ -0,0 +1,39 @@ +import subprocess +import sys + +import pytest + +from 
facefusion.download import conditional_download +from facefusion.jobs.job_manager import clear_jobs, init_jobs +from .helper import get_test_example_file, get_test_examples_directory, get_test_jobs_directory, get_test_output_file, is_test_output_file, prepare_test_output_directory + + +@pytest.fixture(scope = 'module', autouse = True) +def before_all() -> None: + conditional_download(get_test_examples_directory(), + [ + 'https://github.com/facefusion/facefusion-assets/releases/download/examples-3.0.0/source.jpg', + 'https://github.com/facefusion/facefusion-assets/releases/download/examples-3.0.0/target-240p.mp4' + ]) + subprocess.run([ 'ffmpeg', '-i', get_test_example_file('target-240p.mp4'), '-vframes', '1', get_test_example_file('target-240p.jpg') ]) + + +@pytest.fixture(scope = 'function', autouse = True) +def before_each() -> None: + clear_jobs(get_test_jobs_directory()) + init_jobs(get_test_jobs_directory()) + prepare_test_output_directory() + + +def test_debug_face_to_image() -> None: + commands = [ sys.executable, 'facefusion.py', 'headless-run', '--jobs-path', get_test_jobs_directory(), '--processors', 'face_debugger', '-t', get_test_example_file('target-240p.jpg'), '-o', get_test_output_file('test-debug-face-to-image.jpg') ] + + assert subprocess.run(commands).returncode == 0 + assert is_test_output_file('test-debug-face-to-image.jpg') is True + + +def test_debug_face_to_video() -> None: + commands = [ sys.executable, 'facefusion.py', 'headless-run', '--jobs-path', get_test_jobs_directory(), '--processors', 'face_debugger', '-t', get_test_example_file('target-240p.mp4'), '-o', get_test_output_file('test-debug-face-to-video.mp4'), '--trim-frame-end', '1' ] + + assert subprocess.run(commands).returncode == 0 + assert is_test_output_file('test-debug-face-to-video.mp4') is True diff --git a/tests/test_cli_face_editor.py b/tests/test_cli_face_editor.py new file mode 100644 index 0000000000000000000000000000000000000000..27b289ec8adbacdb8be8073b5098d74ea11061f2 --- /dev/null +++ b/tests/test_cli_face_editor.py @@ -0,0 +1,39 @@ +import subprocess +import sys + +import pytest + +from facefusion.download import conditional_download +from facefusion.jobs.job_manager import clear_jobs, init_jobs +from .helper import get_test_example_file, get_test_examples_directory, get_test_jobs_directory, get_test_output_file, is_test_output_file, prepare_test_output_directory + + +@pytest.fixture(scope = 'module', autouse = True) +def before_all() -> None: + conditional_download(get_test_examples_directory(), + [ + 'https://github.com/facefusion/facefusion-assets/releases/download/examples-3.0.0/source.jpg', + 'https://github.com/facefusion/facefusion-assets/releases/download/examples-3.0.0/target-240p.mp4' + ]) + subprocess.run([ 'ffmpeg', '-i', get_test_example_file('target-240p.mp4'), '-vframes', '1', get_test_example_file('target-240p.jpg') ]) + + +@pytest.fixture(scope = 'function', autouse = True) +def before_each() -> None: + clear_jobs(get_test_jobs_directory()) + init_jobs(get_test_jobs_directory()) + prepare_test_output_directory() + + +def test_edit_face_to_image() -> None: + commands = [ sys.executable, 'facefusion.py', 'headless-run', '--jobs-path', get_test_jobs_directory(), '--processors', 'face_editor', '-t', get_test_example_file('target-240p.jpg'), '-o', get_test_output_file('test-edit-face-to-image.jpg') ] + + assert subprocess.run(commands).returncode == 0 + assert is_test_output_file('test-edit-face-to-image.jpg') is True + + +def test_edit_face_to_video() -> None: + commands = [ 
sys.executable, 'facefusion.py', 'headless-run', '--jobs-path', get_test_jobs_directory(), '--processors', 'face_editor', '-t', get_test_example_file('target-240p.mp4'), '-o', get_test_output_file('test-edit-face-to-video.mp4'), '--trim-frame-end', '1' ] + + assert subprocess.run(commands).returncode == 0 + assert is_test_output_file('test-edit-face-to-video.mp4') is True diff --git a/tests/test_cli_face_enhancer.py b/tests/test_cli_face_enhancer.py new file mode 100644 index 0000000000000000000000000000000000000000..c1e5128f8c58230ff7ec9b26e31892b6ba60d093 --- /dev/null +++ b/tests/test_cli_face_enhancer.py @@ -0,0 +1,39 @@ +import subprocess +import sys + +import pytest + +from facefusion.download import conditional_download +from facefusion.jobs.job_manager import clear_jobs, init_jobs +from .helper import get_test_example_file, get_test_examples_directory, get_test_jobs_directory, get_test_output_file, is_test_output_file, prepare_test_output_directory + + +@pytest.fixture(scope = 'module', autouse = True) +def before_all() -> None: + conditional_download(get_test_examples_directory(), + [ + 'https://github.com/facefusion/facefusion-assets/releases/download/examples-3.0.0/source.jpg', + 'https://github.com/facefusion/facefusion-assets/releases/download/examples-3.0.0/target-240p.mp4' + ]) + subprocess.run([ 'ffmpeg', '-i', get_test_example_file('target-240p.mp4'), '-vframes', '1', get_test_example_file('target-240p.jpg') ]) + + +@pytest.fixture(scope = 'function', autouse = True) +def before_each() -> None: + clear_jobs(get_test_jobs_directory()) + init_jobs(get_test_jobs_directory()) + prepare_test_output_directory() + + +def test_enhance_face_to_image() -> None: + commands = [ sys.executable, 'facefusion.py', 'headless-run', '--jobs-path', get_test_jobs_directory(), '--processors', 'face_enhancer', '-t', get_test_example_file('target-240p.jpg'), '-o', get_test_output_file('test-enhance-face-to-image.jpg') ] + + assert subprocess.run(commands).returncode == 0 + assert is_test_output_file('test-enhance-face-to-image.jpg') is True + + +def test_enhance_face_to_video() -> None: + commands = [ sys.executable, 'facefusion.py', 'headless-run', '--jobs-path', get_test_jobs_directory(), '--processors', 'face_enhancer', '-t', get_test_example_file('target-240p.mp4'), '-o', get_test_output_file('test-enhance-face-to-video.mp4'), '--trim-frame-end', '1' ] + + assert subprocess.run(commands).returncode == 0 + assert is_test_output_file('test-enhance-face-to-video.mp4') is True diff --git a/tests/test_cli_face_swapper.py b/tests/test_cli_face_swapper.py new file mode 100644 index 0000000000000000000000000000000000000000..be68cde56e3f3ab2d19479917a4db0d24c6cb98c --- /dev/null +++ b/tests/test_cli_face_swapper.py @@ -0,0 +1,39 @@ +import subprocess +import sys + +import pytest + +from facefusion.download import conditional_download +from facefusion.jobs.job_manager import clear_jobs, init_jobs +from .helper import get_test_example_file, get_test_examples_directory, get_test_jobs_directory, get_test_output_file, is_test_output_file, prepare_test_output_directory + + +@pytest.fixture(scope = 'module', autouse = True) +def before_all() -> None: + conditional_download(get_test_examples_directory(), + [ + 'https://github.com/facefusion/facefusion-assets/releases/download/examples-3.0.0/source.jpg', + 'https://github.com/facefusion/facefusion-assets/releases/download/examples-3.0.0/target-240p.mp4' + ]) + subprocess.run([ 'ffmpeg', '-i', get_test_example_file('target-240p.mp4'), '-vframes', '1', 
get_test_example_file('target-240p.jpg') ]) + + +@pytest.fixture(scope = 'function', autouse = True) +def before_each() -> None: + clear_jobs(get_test_jobs_directory()) + init_jobs(get_test_jobs_directory()) + prepare_test_output_directory() + + +def test_swap_face_to_image() -> None: + commands = [ sys.executable, 'facefusion.py', 'headless-run', '--jobs-path', get_test_jobs_directory(), '--processors', 'face_swapper', '-s', get_test_example_file('source.jpg'), '-t', get_test_example_file('target-240p.jpg'), '-o', get_test_output_file('test-swap-face-to-image.jpg') ] + + assert subprocess.run(commands).returncode == 0 + assert is_test_output_file('test-swap-face-to-image.jpg') is True + + +def test_swap_face_to_video() -> None: + commands = [ sys.executable, 'facefusion.py', 'headless-run', '--jobs-path', get_test_jobs_directory(), '--processors', 'face_swapper', '-s', get_test_example_file('source.jpg'), '-t', get_test_example_file('target-240p.mp4'), '-o', get_test_output_file('test-swap-face-to-video.mp4'), '--trim-frame-end', '1' ] + + assert subprocess.run(commands).returncode == 0 + assert is_test_output_file('test-swap-face-to-video.mp4') is True diff --git a/tests/test_cli_frame_colorizer.py b/tests/test_cli_frame_colorizer.py new file mode 100644 index 0000000000000000000000000000000000000000..38774c0a4d1ae334417baeca8dd9e2379e62c694 --- /dev/null +++ b/tests/test_cli_frame_colorizer.py @@ -0,0 +1,40 @@ +import subprocess +import sys + +import pytest + +from facefusion.download import conditional_download +from facefusion.jobs.job_manager import clear_jobs, init_jobs +from .helper import get_test_example_file, get_test_examples_directory, get_test_jobs_directory, get_test_output_file, is_test_output_file, prepare_test_output_directory + + +@pytest.fixture(scope = 'module', autouse = True) +def before_all() -> None: + conditional_download(get_test_examples_directory(), + [ + 'https://github.com/facefusion/facefusion-assets/releases/download/examples-3.0.0/source.jpg', + 'https://github.com/facefusion/facefusion-assets/releases/download/examples-3.0.0/target-240p.mp4' + ]) + subprocess.run([ 'ffmpeg', '-i', get_test_example_file('target-240p.mp4'), '-vframes', '1', '-vf', 'hue=s=0', get_test_example_file('target-240p-0sat.jpg') ]) + subprocess.run([ 'ffmpeg', '-i', get_test_example_file('target-240p.mp4'), '-vf', 'hue=s=0', get_test_example_file('target-240p-0sat.mp4') ]) + + +@pytest.fixture(scope = 'function', autouse = True) +def before_each() -> None: + clear_jobs(get_test_jobs_directory()) + init_jobs(get_test_jobs_directory()) + prepare_test_output_directory() + + +def test_colorize_frame_to_image() -> None: + commands = [ sys.executable, 'facefusion.py', 'headless-run', '--jobs-path', get_test_jobs_directory(), '--processors', 'frame_colorizer', '-t', get_test_example_file('target-240p-0sat.jpg'), '-o', get_test_output_file('test_colorize-frame-to-image.jpg') ] + + assert subprocess.run(commands).returncode == 0 + assert is_test_output_file('test_colorize-frame-to-image.jpg') is True + + +def test_colorize_frame_to_video() -> None: + commands = [ sys.executable, 'facefusion.py', 'headless-run', '--jobs-path', get_test_jobs_directory(), '--processors', 'frame_colorizer', '-t', get_test_example_file('target-240p-0sat.mp4'), '-o', get_test_output_file('test-colorize-frame-to-video.mp4'), '--trim-frame-end', '1' ] + + assert subprocess.run(commands).returncode == 0 + assert is_test_output_file('test-colorize-frame-to-video.mp4') is True diff --git 
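# Note on the frame colorizer fixtures above: the ffmpeg filter 'hue=s=0' strips all
# saturation, so the colorizer is exercised against genuinely grayscale media. A comparable
# standalone conversion (paths are placeholders), assuming ffmpeg is available on the PATH:
import subprocess

subprocess.run([ 'ffmpeg', '-i', 'target.mp4', '-vf', 'hue=s=0', 'target-grayscale.mp4' ])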
a/tests/test_cli_frame_enhancer.py b/tests/test_cli_frame_enhancer.py new file mode 100644 index 0000000000000000000000000000000000000000..7892fa2dbe28cb2a31d3f8d38eb885195b348335 --- /dev/null +++ b/tests/test_cli_frame_enhancer.py @@ -0,0 +1,39 @@ +import subprocess +import sys + +import pytest + +from facefusion.download import conditional_download +from facefusion.jobs.job_manager import clear_jobs, init_jobs +from .helper import get_test_example_file, get_test_examples_directory, get_test_jobs_directory, get_test_output_file, is_test_output_file, prepare_test_output_directory + + +@pytest.fixture(scope = 'module', autouse = True) +def before_all() -> None: + conditional_download(get_test_examples_directory(), + [ + 'https://github.com/facefusion/facefusion-assets/releases/download/examples-3.0.0/source.jpg', + 'https://github.com/facefusion/facefusion-assets/releases/download/examples-3.0.0/target-240p.mp4' + ]) + subprocess.run([ 'ffmpeg', '-i', get_test_example_file('target-240p.mp4'), '-vframes', '1', get_test_example_file('target-240p.jpg') ]) + + +@pytest.fixture(scope = 'function', autouse = True) +def before_each() -> None: + clear_jobs(get_test_jobs_directory()) + init_jobs(get_test_jobs_directory()) + prepare_test_output_directory() + + +def test_enhance_frame_to_image() -> None: + commands = [ sys.executable, 'facefusion.py', 'headless-run', '--jobs-path', get_test_jobs_directory(), '--processors', 'frame_enhancer', '-t', get_test_example_file('target-240p.jpg'), '-o', get_test_output_file('test-enhance-frame-to-image.jpg') ] + + assert subprocess.run(commands).returncode == 0 + assert is_test_output_file('test-enhance-frame-to-image.jpg') is True + + +def test_enhance_frame_to_video() -> None: + commands = [ sys.executable, 'facefusion.py', 'headless-run', '--jobs-path', get_test_jobs_directory(), '--processors', 'frame_enhancer', '-t', get_test_example_file('target-240p.mp4'), '-o', get_test_output_file('test-enhance-frame-to-video.mp4'), '--trim-frame-end', '1' ] + + assert subprocess.run(commands).returncode == 0 + assert is_test_output_file('test-enhance-frame-to-video.mp4') is True diff --git a/tests/test_cli_job_manager.py b/tests/test_cli_job_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..a20a27283e67fc2fcee78bd1ba5f8b0876809ea3 --- /dev/null +++ b/tests/test_cli_job_manager.py @@ -0,0 +1,209 @@ +import subprocess +import sys + +import pytest + +from facefusion.download import conditional_download +from facefusion.jobs.job_manager import clear_jobs, count_step_total, init_jobs +from .helper import get_test_example_file, get_test_examples_directory, get_test_jobs_directory, get_test_output_file, is_test_job_file + + +@pytest.fixture(scope = 'module', autouse = True) +def before_all() -> None: + conditional_download(get_test_examples_directory(), + [ + 'https://github.com/facefusion/facefusion-assets/releases/download/examples-3.0.0/source.jpg', + 'https://github.com/facefusion/facefusion-assets/releases/download/examples-3.0.0/target-240p.mp4' + ]) + subprocess.run([ 'ffmpeg', '-i', get_test_example_file('target-240p.mp4'), '-vframes', '1', get_test_example_file('target-240p.jpg') ]) + + +@pytest.fixture(scope = 'function', autouse = True) +def before_each() -> None: + clear_jobs(get_test_jobs_directory()) + init_jobs(get_test_jobs_directory()) + + +@pytest.mark.skip() +def test_job_list() -> None: + pass + + +def test_job_create() -> None: + commands = [ sys.executable, 'facefusion.py', 'job-create', 'test-job-create', '--jobs-path', 
get_test_jobs_directory() ] + + assert subprocess.run(commands).returncode == 0 + assert is_test_job_file('test-job-create.json', 'drafted') is True + + commands = [ sys.executable, 'facefusion.py', 'job-create', 'test-job-create', '--jobs-path', get_test_jobs_directory() ] + + assert subprocess.run(commands).returncode == 1 + + +def test_job_submit() -> None: + commands = [ sys.executable, 'facefusion.py', 'job-submit', 'test-job-submit', '--jobs-path', get_test_jobs_directory() ] + + assert subprocess.run(commands).returncode == 1 + + commands = [ sys.executable, 'facefusion.py', 'job-create', 'test-job-submit', '--jobs-path', get_test_jobs_directory() ] + subprocess.run(commands) + + assert subprocess.run(commands).returncode == 1 + + commands = [ sys.executable, 'facefusion.py', 'job-add-step', 'test-job-submit', '--jobs-path', get_test_jobs_directory(), '-s', get_test_example_file('source.jpg'), '-t', get_test_example_file('target-240p.jpg'), '-o', get_test_output_file('test-job-remix-step.jpg') ] + subprocess.run(commands) + + commands = [ sys.executable, 'facefusion.py', 'job-submit', 'test-job-submit', '--jobs-path', get_test_jobs_directory() ] + + assert subprocess.run(commands).returncode == 0 + assert is_test_job_file('test-job-submit.json', 'queued') is True + assert subprocess.run(commands).returncode == 1 + + +def test_submit_all() -> None: + commands = [ sys.executable, 'facefusion.py', 'job-submit-all', '--jobs-path', get_test_jobs_directory(), '--halt-on-error' ] + + assert subprocess.run(commands).returncode == 1 + + commands = [ sys.executable, 'facefusion.py', 'job-create', 'test-job-submit-all-1', '--jobs-path', get_test_jobs_directory() ] + subprocess.run(commands) + + commands = [ sys.executable, 'facefusion.py', 'job-create', 'test-job-submit-all-2', '--jobs-path', get_test_jobs_directory() ] + subprocess.run(commands) + + assert subprocess.run(commands).returncode == 1 + + commands = [ sys.executable, 'facefusion.py', 'job-add-step', 'test-job-submit-all-1', '--jobs-path', get_test_jobs_directory(), '-s', get_test_example_file('source.jpg'), '-t', get_test_example_file('target-240p.jpg'), '-o', get_test_output_file('test-job-remix-step.jpg') ] + subprocess.run(commands) + + commands = [ sys.executable, 'facefusion.py', 'job-add-step', 'test-job-submit-all-2', '--jobs-path', get_test_jobs_directory(), '-s', get_test_example_file('source.jpg'), '-t', get_test_example_file('target-240p.jpg'), '-o', get_test_output_file('test-job-remix-step.jpg') ] + subprocess.run(commands) + + commands = [ sys.executable, 'facefusion.py', 'job-submit-all', '--jobs-path', get_test_jobs_directory(), '--halt-on-error' ] + + assert subprocess.run(commands).returncode == 0 + assert is_test_job_file('test-job-submit-all-1.json', 'queued') is True + assert is_test_job_file('test-job-submit-all-2.json', 'queued') is True + assert subprocess.run(commands).returncode == 1 + + +def test_job_delete() -> None: + commands = [ sys.executable, 'facefusion.py', 'job-delete', 'test-job-delete', '--jobs-path', get_test_jobs_directory() ] + + assert subprocess.run(commands).returncode == 1 + + commands = [ sys.executable, 'facefusion.py', 'job-create', 'test-job-delete', '--jobs-path', get_test_jobs_directory() ] + subprocess.run(commands) + + commands = [ sys.executable, 'facefusion.py', 'job-delete', 'test-job-delete', '--jobs-path', get_test_jobs_directory() ] + + assert subprocess.run(commands).returncode == 0 + assert is_test_job_file('test-job-delete.json', 'drafted') is False + assert 
subprocess.run(commands).returncode == 1 + + +def test_job_delete_all() -> None: + commands = [ sys.executable, 'facefusion.py', 'job-delete-all', '--jobs-path', get_test_jobs_directory(), '--halt-on-error' ] + + assert subprocess.run(commands).returncode == 1 + + commands = [ sys.executable, 'facefusion.py', 'job-create', 'test-job-delete-all-1', '--jobs-path', get_test_jobs_directory() ] + subprocess.run(commands) + + commands = [ sys.executable, 'facefusion.py', 'job-create', 'test-job-delete-all-2', '--jobs-path', get_test_jobs_directory() ] + subprocess.run(commands) + + commands = [ sys.executable, 'facefusion.py', 'job-delete-all', '--jobs-path', get_test_jobs_directory(), '--halt-on-error' ] + + assert subprocess.run(commands).returncode == 0 + assert is_test_job_file('test-job-delete-all-1.json', 'drafted') is False + assert is_test_job_file('test-job-delete-all-2.json', 'drafted') is False + assert subprocess.run(commands).returncode == 1 + + +def test_job_add_step() -> None: + commands = [ sys.executable, 'facefusion.py', 'job-add-step', 'test-job-add-step', '--jobs-path', get_test_jobs_directory(), '-s', get_test_example_file('source.jpg'), '-t', get_test_example_file('target-240p.jpg'), '-o', get_test_output_file('test-job-remix-step.jpg') ] + + assert subprocess.run(commands).returncode == 1 + assert count_step_total('test-job-add-step') == 0 + + commands = [ sys.executable, 'facefusion.py', 'job-create', 'test-job-add-step', '--jobs-path', get_test_jobs_directory() ] + subprocess.run(commands) + + commands = [ sys.executable, 'facefusion.py', 'job-add-step', 'test-job-add-step', '--jobs-path', get_test_jobs_directory(), '-s', get_test_example_file('source.jpg'), '-t', get_test_example_file('target-240p.jpg'), '-o', get_test_output_file('test-job-remix-step.jpg') ] + + assert subprocess.run(commands).returncode == 0 + assert count_step_total('test-job-add-step') == 1 + + +def test_job_remix() -> None: + commands = [ sys.executable, 'facefusion.py', 'job-remix-step', 'test-job-remix-step', '0', '--jobs-path', get_test_jobs_directory(), '-s', get_test_example_file('source.jpg'), '-o', get_test_output_file('test-job-remix-step.jpg') ] + + assert subprocess.run(commands).returncode == 1 + assert count_step_total('test-job-remix-step') == 0 + + commands = [ sys.executable, 'facefusion.py', 'job-create', 'test-job-remix-step', '--jobs-path', get_test_jobs_directory() ] + subprocess.run(commands) + + commands = [ sys.executable, 'facefusion.py', 'job-add-step', 'test-job-remix-step', '--jobs-path', get_test_jobs_directory(), '-s', get_test_example_file('source.jpg'), '-t', get_test_example_file('target-240p.jpg'), '-o', get_test_output_file('test-job-remix-step.jpg') ] + subprocess.run(commands) + + commands = [ sys.executable, 'facefusion.py', 'job-remix-step', 'test-job-remix-step', '0', '--jobs-path', get_test_jobs_directory(), '-s', get_test_example_file('source.jpg'), '-o', get_test_output_file('test-job-remix-step.jpg') ] + + assert count_step_total('test-job-remix-step') == 1 + assert subprocess.run(commands).returncode == 0 + assert count_step_total('test-job-remix-step') == 2 + + commands = [ sys.executable, 'facefusion.py', 'job-remix-step', 'test-job-remix-step', '-1', '--jobs-path', get_test_jobs_directory(), '-s', get_test_example_file('source.jpg'), '-o', get_test_output_file('test-job-remix-step.jpg') ] + + assert subprocess.run(commands).returncode == 0 + assert count_step_total('test-job-remix-step') == 3 + + +def test_job_insert_step() -> None: + commands = [ 
sys.executable, 'facefusion.py', 'job-insert-step', 'test-job-insert-step', '0', '--jobs-path', get_test_jobs_directory(), '-s', get_test_example_file('source.jpg'), '-t', get_test_example_file('target-240p.jpg'), '-o', get_test_output_file('test-job-remix-step.jpg') ] + + assert subprocess.run(commands).returncode == 1 + assert count_step_total('test-job-insert-step') == 0 + + commands = [ sys.executable, 'facefusion.py', 'job-create', 'test-job-insert-step', '--jobs-path', get_test_jobs_directory() ] + subprocess.run(commands) + + commands = [ sys.executable, 'facefusion.py', 'job-add-step', 'test-job-insert-step', '--jobs-path', get_test_jobs_directory(), '-s', get_test_example_file('source.jpg'), '-t', get_test_example_file('target-240p.jpg'), '-o', get_test_output_file('test-job-remix-step.jpg') ] + subprocess.run(commands) + + commands = [ sys.executable, 'facefusion.py', 'job-insert-step', 'test-job-insert-step', '0', '--jobs-path', get_test_jobs_directory(), '-s', get_test_example_file('source.jpg'), '-t', get_test_example_file('target-240p.jpg'), '-o', get_test_output_file('test-job-remix-step.jpg') ] + + assert count_step_total('test-job-insert-step') == 1 + assert subprocess.run(commands).returncode == 0 + assert count_step_total('test-job-insert-step') == 2 + + commands = [ sys.executable, 'facefusion.py', 'job-insert-step', 'test-job-insert-step', '-1', '--jobs-path', get_test_jobs_directory(), '-s', get_test_example_file('source.jpg'), '-t', get_test_example_file('target-240p.jpg'), '-o', get_test_output_file('test-job-remix-step.jpg') ] + + assert subprocess.run(commands).returncode == 0 + assert count_step_total('test-job-insert-step') == 3 + + +def test_job_remove_step() -> None: + commands = [ sys.executable, 'facefusion.py', 'job-remove-step', 'test-job-remove-step', '0', '--jobs-path', get_test_jobs_directory() ] + + assert subprocess.run(commands).returncode == 1 + + commands = [ sys.executable, 'facefusion.py', 'job-create', 'test-job-remove-step', '--jobs-path', get_test_jobs_directory() ] + subprocess.run(commands) + + commands = [ sys.executable, 'facefusion.py', 'job-add-step', 'test-job-remove-step', '--jobs-path', get_test_jobs_directory(), '-s', get_test_example_file('source.jpg'), '-t', get_test_example_file('target-240p.jpg'), '-o', get_test_output_file('test-job-remix-step.jpg') ] + subprocess.run(commands) + subprocess.run(commands) + + commands = [ sys.executable, 'facefusion.py', 'job-remove-step', 'test-job-remove-step', '0', '--jobs-path', get_test_jobs_directory() ] + + assert count_step_total('test-job-remove-step') == 2 + assert subprocess.run(commands).returncode == 0 + assert count_step_total('test-job-remove-step') == 1 + + commands = [ sys.executable, 'facefusion.py', 'job-remove-step', 'test-job-remove-step', '-1', '--jobs-path', get_test_jobs_directory() ] + + assert subprocess.run(commands).returncode == 0 + assert subprocess.run(commands).returncode == 1 + assert count_step_total('test-job-remove-step') == 0 diff --git a/tests/test_cli_job_runner.py b/tests/test_cli_job_runner.py new file mode 100644 index 0000000000000000000000000000000000000000..23d23cbacd166e7ecd5dec0caac2257e72b48ad3 --- /dev/null +++ b/tests/test_cli_job_runner.py @@ -0,0 +1,147 @@ +import subprocess +import sys + +import pytest + +from facefusion.download import conditional_download +from facefusion.jobs.job_manager import clear_jobs, init_jobs, move_job_file, set_steps_status +from .helper import get_test_example_file, get_test_examples_directory, 
get_test_jobs_directory, get_test_output_file, is_test_output_file, prepare_test_output_directory + + +@pytest.fixture(scope = 'module', autouse = True) +def before_all() -> None: + conditional_download(get_test_examples_directory(), + [ + 'https://github.com/facefusion/facefusion-assets/releases/download/examples-3.0.0/source.jpg', + 'https://github.com/facefusion/facefusion-assets/releases/download/examples-3.0.0/target-240p.mp4' + ]) + subprocess.run([ 'ffmpeg', '-i', get_test_example_file('target-240p.mp4'), '-vframes', '1', get_test_example_file('target-240p.jpg') ]) + + +@pytest.fixture(scope = 'function', autouse = True) +def before_each() -> None: + clear_jobs(get_test_jobs_directory()) + init_jobs(get_test_jobs_directory()) + prepare_test_output_directory() + + +def test_job_run() -> None: + commands = [ sys.executable, 'facefusion.py', 'job-run', 'test-job-run', '--jobs-path', get_test_jobs_directory() ] + + assert subprocess.run(commands).returncode == 1 + + commands = [ sys.executable, 'facefusion.py', 'job-create', 'test-job-run', '--jobs-path', get_test_jobs_directory() ] + subprocess.run(commands) + + commands = [ sys.executable, 'facefusion.py', 'job-add-step', 'test-job-run', '--jobs-path', get_test_jobs_directory(), '--processors', 'face_debugger', '-t', get_test_example_file('target-240p.jpg'), '-o', get_test_output_file('test-job-run.jpg') ] + subprocess.run(commands) + + commands = [ sys.executable, 'facefusion.py', 'job-run', 'test-job-run', '--jobs-path', get_test_jobs_directory() ] + + assert subprocess.run(commands).returncode == 1 + + commands = [ sys.executable, 'facefusion.py', 'job-submit', 'test-job-run', '--jobs-path', get_test_jobs_directory() ] + subprocess.run(commands) + + commands = [ sys.executable, 'facefusion.py', 'job-run', 'test-job-run', '--jobs-path', get_test_jobs_directory() ] + + assert subprocess.run(commands).returncode == 0 + assert subprocess.run(commands).returncode == 1 + assert is_test_output_file('test-job-run.jpg') is True + + +def test_job_run_all() -> None: + commands = [ sys.executable, 'facefusion.py', 'job-run-all', '--jobs-path', get_test_jobs_directory(), '--halt-on-error' ] + + assert subprocess.run(commands).returncode == 1 + + commands = [ sys.executable, 'facefusion.py', 'job-create', 'test-job-run-all-1', '--jobs-path', get_test_jobs_directory() ] + subprocess.run(commands) + + commands = [ sys.executable, 'facefusion.py', 'job-create', 'test-job-run-all-2', '--jobs-path', get_test_jobs_directory() ] + subprocess.run(commands) + + commands = [ sys.executable, 'facefusion.py', 'job-add-step', 'test-job-run-all-1', '--jobs-path', get_test_jobs_directory(), '--processors', 'face_debugger', '-t', get_test_example_file('target-240p.jpg'), '-o', get_test_output_file('test-job-run-all-1.jpg') ] + subprocess.run(commands) + + commands = [ sys.executable, 'facefusion.py', 'job-add-step', 'test-job-run-all-2', '--jobs-path', get_test_jobs_directory(), '--processors', 'face_debugger', '-t', get_test_example_file('target-240p.mp4'), '-o', get_test_output_file('test-job-run-all-2.mp4'), '--trim-frame-end', '1' ] + subprocess.run(commands) + + commands = [ sys.executable, 'facefusion.py', 'job-add-step', 'test-job-run-all-2', '--jobs-path', get_test_jobs_directory(), '--processors', 'face_debugger', '-t', get_test_example_file('target-240p.mp4'), '-o', get_test_output_file('test-job-run-all-2.mp4'), '--trim-frame-start', '0', '--trim-frame-end', '1' ] + subprocess.run(commands) + + commands = [ sys.executable, 'facefusion.py', 
'job-run-all', '--jobs-path', get_test_jobs_directory(), '--halt-on-error' ] + + assert subprocess.run(commands).returncode == 1 + + commands = [ sys.executable, 'facefusion.py', 'job-submit-all', '--jobs-path', get_test_jobs_directory(), '--halt-on-error' ] + subprocess.run(commands) + + commands = [ sys.executable, 'facefusion.py', 'job-run-all', '--jobs-path', get_test_jobs_directory(), '--halt-on-error' ] + + assert subprocess.run(commands).returncode == 0 + assert subprocess.run(commands).returncode == 1 + assert is_test_output_file('test-job-run-all-1.jpg') is True + assert is_test_output_file('test-job-run-all-2.mp4') is True + + +def test_job_retry() -> None: + commands = [ sys.executable, 'facefusion.py', 'job-retry', 'test-job-retry', '--jobs-path', get_test_jobs_directory() ] + + assert subprocess.run(commands).returncode == 1 + + commands = [ sys.executable, 'facefusion.py', 'job-create', 'test-job-retry', '--jobs-path', get_test_jobs_directory() ] + subprocess.run(commands) + + commands = [ sys.executable, 'facefusion.py', 'job-add-step', 'test-job-retry', '--jobs-path', get_test_jobs_directory(), '--processors', 'face_debugger', '-t', get_test_example_file('target-240p.jpg'), '-o', get_test_output_file('test-job-retry.jpg') ] + subprocess.run(commands) + + commands = [ sys.executable, 'facefusion.py', 'job-retry', 'test-job-retry', '--jobs-path', get_test_jobs_directory() ] + + assert subprocess.run(commands).returncode == 1 + + set_steps_status('test-job-retry', 'failed') + move_job_file('test-job-retry', 'failed') + + commands = [ sys.executable, 'facefusion.py', 'job-retry', 'test-job-retry', '--jobs-path', get_test_jobs_directory() ] + + assert subprocess.run(commands).returncode == 0 + assert subprocess.run(commands).returncode == 1 + assert is_test_output_file('test-job-retry.jpg') is True + + +def test_job_retry_all() -> None: + commands = [ sys.executable, 'facefusion.py', 'job-retry-all', '--jobs-path', get_test_jobs_directory(), '--halt-on-error' ] + + assert subprocess.run(commands).returncode == 1 + + commands = [ sys.executable, 'facefusion.py', 'job-create', 'test-job-retry-all-1', '--jobs-path', get_test_jobs_directory() ] + subprocess.run(commands) + + commands = [ sys.executable, 'facefusion.py', 'job-create', 'test-job-retry-all-2', '--jobs-path', get_test_jobs_directory() ] + subprocess.run(commands) + + commands = [ sys.executable, 'facefusion.py', 'job-add-step', 'test-job-retry-all-1', '--jobs-path', get_test_jobs_directory(), '--processors', 'face_debugger', '-t', get_test_example_file('target-240p.jpg'), '-o', get_test_output_file('test-job-retry-all-1.jpg') ] + subprocess.run(commands) + + commands = [ sys.executable, 'facefusion.py', 'job-add-step', 'test-job-retry-all-2', '--jobs-path', get_test_jobs_directory(), '--processors', 'face_debugger', '-t', get_test_example_file('target-240p.mp4'), '-o', get_test_output_file('test-job-retry-all-2.mp4'), '--trim-frame-end', '1' ] + subprocess.run(commands) + + commands = [ sys.executable, 'facefusion.py', 'job-add-step', 'test-job-retry-all-2', '--jobs-path', get_test_jobs_directory(), '--processors', 'face_debugger', '-t', get_test_example_file('target-240p.mp4'), '-o', get_test_output_file('test-job-retry-all-2.mp4'), '--trim-frame-start', '0', '--trim-frame-end', '1' ] + subprocess.run(commands) + + commands = [ sys.executable, 'facefusion.py', 'job-retry-all', '--jobs-path', get_test_jobs_directory(), '--halt-on-error' ] + + assert subprocess.run(commands).returncode == 1 + + 
set_steps_status('test-job-retry-all-1', 'failed') + set_steps_status('test-job-retry-all-2', 'failed') + move_job_file('test-job-retry-all-1', 'failed') + move_job_file('test-job-retry-all-2', 'failed') + + commands = [ sys.executable, 'facefusion.py', 'job-retry-all', '--jobs-path', get_test_jobs_directory(), '--halt-on-error' ] + + assert subprocess.run(commands).returncode == 0 + assert subprocess.run(commands).returncode == 1 + assert is_test_output_file('test-job-retry-all-1.jpg') is True + assert is_test_output_file('test-job-retry-all-2.mp4') is True diff --git a/tests/test_cli_lip_syncer.py b/tests/test_cli_lip_syncer.py new file mode 100644 index 0000000000000000000000000000000000000000..c0e5cc33e1c9ea75f86e8a123b762c2572884ea7 --- /dev/null +++ b/tests/test_cli_lip_syncer.py @@ -0,0 +1,40 @@ +import subprocess +import sys + +import pytest + +from facefusion.download import conditional_download +from facefusion.jobs.job_manager import clear_jobs, init_jobs +from .helper import get_test_example_file, get_test_examples_directory, get_test_jobs_directory, get_test_output_file, is_test_output_file, prepare_test_output_directory + + +@pytest.fixture(scope = 'module', autouse = True) +def before_all() -> None: + conditional_download(get_test_examples_directory(), + [ + 'https://github.com/facefusion/facefusion-assets/releases/download/examples-3.0.0/source.jpg', + 'https://github.com/facefusion/facefusion-assets/releases/download/examples-3.0.0/source.mp3', + 'https://github.com/facefusion/facefusion-assets/releases/download/examples-3.0.0/target-240p.mp4' + ]) + subprocess.run([ 'ffmpeg', '-i', get_test_example_file('target-240p.mp4'), '-vframes', '1', get_test_example_file('target-240p.jpg') ]) + + +@pytest.fixture(scope = 'function', autouse = True) +def before_each() -> None: + clear_jobs(get_test_jobs_directory()) + init_jobs(get_test_jobs_directory()) + prepare_test_output_directory() + + +def test_sync_lip_to_image() -> None: + commands = [ sys.executable, 'facefusion.py', 'headless-run', '--jobs-path', get_test_jobs_directory(), '--processors', 'lip_syncer', '-s', get_test_example_file('source.mp3'), '-t', get_test_example_file('target-240p.jpg'), '-o', get_test_output_file('test_sync_lip_to_image.jpg') ] + + assert subprocess.run(commands).returncode == 0 + assert is_test_output_file('test_sync_lip_to_image.jpg') is True + + +def test_sync_lip_to_video() -> None: + commands = [ sys.executable, 'facefusion.py', 'headless-run', '--jobs-path', get_test_jobs_directory(), '--processors', 'lip_syncer', '-s', get_test_example_file('source.mp3'), '-t', get_test_example_file('target-240p.mp4'), '-o', get_test_output_file('test_sync_lip_to_video.mp4'), '--trim-frame-end', '1' ] + + assert subprocess.run(commands).returncode == 0 + assert is_test_output_file('test_sync_lip_to_video.mp4') is True diff --git a/tests/test_common_helper.py b/tests/test_common_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..ed7dcfdb142d6828364033fd57b0e00d41aa8549 --- /dev/null +++ b/tests/test_common_helper.py @@ -0,0 +1,27 @@ +from facefusion.common_helper import calc_float_step, calc_int_step, create_float_metavar, create_float_range, create_int_metavar, create_int_range + + +def test_create_int_metavar() -> None: + assert create_int_metavar([ 1, 2, 3, 4, 5 ]) == '[1..5:1]' + + +def test_create_float_metavar() -> None: + assert create_float_metavar([ 0.1, 0.2, 0.3, 0.4, 0.5 ]) == '[0.1..0.5:0.1]' + + +def test_create_int_range() -> None: + assert create_int_range(0, 2, 1) == 
[ 0, 1, 2 ] + assert create_float_range(0, 1, 1) == [ 0, 1 ] + + +def test_create_float_range() -> None: + assert create_float_range(0.0, 1.0, 0.5) == [ 0.0, 0.5, 1.0 ] + assert create_float_range(0.0, 1.0, 0.05) == [ 0.0, 0.05, 0.10, 0.15, 0.20, 0.25, 0.30, 0.35, 0.40, 0.45, 0.50, 0.55, 0.60, 0.65, 0.70, 0.75, 0.80, 0.85, 0.90, 0.95, 1.0 ] + + +def test_calc_int_step() -> None: + assert calc_int_step([ 0, 1 ]) == 1 + + +def test_calc_float_step() -> None: + assert calc_float_step([ 0.1, 0.2 ]) == 0.1 diff --git a/tests/test_config.py b/tests/test_config.py new file mode 100644 index 0000000000000000000000000000000000000000..cae75778252b28c5ac1b61bc56a62b6fc8b02982 --- /dev/null +++ b/tests/test_config.py @@ -0,0 +1,85 @@ +from configparser import ConfigParser + +import pytest + +from facefusion import config + + +@pytest.fixture(scope = 'module', autouse = True) +def before_all() -> None: + config.CONFIG_PARSER = ConfigParser() + config.CONFIG_PARSER.read_dict( + { + 'str': + { + 'valid': 'a', + 'unset': '' + }, + 'int': + { + 'valid': '1', + 'unset': '' + }, + 'float': + { + 'valid': '1.0', + 'unset': '' + }, + 'bool': + { + 'valid': 'True', + 'unset': '' + }, + 'str_list': + { + 'valid': 'a b c', + 'unset': '' + }, + 'int_list': + { + 'valid': '1 2 3', + 'unset': '' + } + }) + + +def test_get_str_value() -> None: + assert config.get_str_value('str', 'valid') == 'a' + assert config.get_str_value('str', 'unset', 'b') == 'b' + assert config.get_str_value('str', 'unset') is None + assert config.get_str_value('str', 'invalid') is None + + +def test_get_int_value() -> None: + assert config.get_int_value('int', 'valid') == 1 + assert config.get_int_value('int', 'unset', '1') == 1 + assert config.get_int_value('int', 'unset') is None + assert config.get_int_value('int', 'invalid') is None + + +def test_get_float_value() -> None: + assert config.get_float_value('float', 'valid') == 1.0 + assert config.get_float_value('float', 'unset', '1.0') == 1.0 + assert config.get_float_value('float', 'unset') is None + assert config.get_float_value('float', 'invalid') is None + + +def test_get_bool_value() -> None: + assert config.get_bool_value('bool', 'valid') is True + assert config.get_bool_value('bool', 'unset', 'False') is False + assert config.get_bool_value('bool', 'unset') is None + assert config.get_bool_value('bool', 'invalid') is None + + +def test_get_str_list() -> None: + assert config.get_str_list('str_list', 'valid') == [ 'a', 'b', 'c' ] + assert config.get_str_list('str_list', 'unset', 'c b a') == [ 'c', 'b', 'a' ] + assert config.get_str_list('str_list', 'unset') is None + assert config.get_str_list('str_list', 'invalid') is None + + +def test_get_int_list() -> None: + assert config.get_int_list('int_list', 'valid') == [ 1, 2, 3 ] + assert config.get_int_list('int_list', 'unset', '3 2 1') == [ 3, 2, 1 ] + assert config.get_int_list('int_list', 'unset') is None + assert config.get_int_list('int_list', 'invalid') is None diff --git a/tests/test_curl_builder.py b/tests/test_curl_builder.py new file mode 100644 index 0000000000000000000000000000000000000000..ca0ec5b927d2b54020b5c7c7d3c2618dd8cd4628 --- /dev/null +++ b/tests/test_curl_builder.py @@ -0,0 +1,14 @@ +from shutil import which + +from facefusion import metadata +from facefusion.curl_builder import chain, head, run + + +def test_run() -> None: + user_agent = metadata.get('name') + '/' + metadata.get('version') + + assert run([]) == [ which('curl'), '--user-agent', user_agent, '--insecure', '--location', '--silent' ] + + +def 
test_chain() -> None: + assert chain(head(metadata.get('url'))) == [ '-I', metadata.get('url') ] diff --git a/tests/test_date_helper.py b/tests/test_date_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..e8d0cd0a8fe8150e0105f811108ea0023acb700a --- /dev/null +++ b/tests/test_date_helper.py @@ -0,0 +1,15 @@ +from datetime import datetime, timedelta + +from facefusion.date_helper import describe_time_ago + + +def get_time_ago(days : int, hours : int, minutes : int) -> datetime: + previous_time = datetime.now() - timedelta(days = days, hours = hours, minutes = minutes) + return previous_time.astimezone() + + +def test_describe_time_ago() -> None: + assert describe_time_ago(get_time_ago(0, 0, 0)) == 'just now' + assert describe_time_ago(get_time_ago(0, 0, 10)) == '10 minutes ago' + assert describe_time_ago(get_time_ago(0, 5, 10)) == '5 hours and 10 minutes ago' + assert describe_time_ago(get_time_ago(1, 5, 10)) == '1 days, 5 hours and 10 minutes ago' diff --git a/tests/test_download.py b/tests/test_download.py new file mode 100644 index 0000000000000000000000000000000000000000..48698aad879a5ff8a9743ca49b5bb8247b3233f2 --- /dev/null +++ b/tests/test_download.py @@ -0,0 +1,18 @@ +from facefusion.download import get_static_download_size, ping_static_url, resolve_download_url_by_provider + + +def test_get_static_download_size() -> None: + assert get_static_download_size('https://github.com/facefusion/facefusion-assets/releases/download/models-3.0.0/fairface.onnx') == 85170772 + assert get_static_download_size('https://huggingface.co/facefusion/models-3.0.0/resolve/main/fairface.onnx') == 85170772 + assert get_static_download_size('invalid') == 0 + + +def test_static_ping_url() -> None: + assert ping_static_url('https://github.com') is True + assert ping_static_url('https://huggingface.co') is True + assert ping_static_url('invalid') is False + + +def test_resolve_download_url_by_provider() -> None: + assert resolve_download_url_by_provider('github', 'models-3.0.0', 'fairface.onnx') == 'https://github.com/facefusion/facefusion-assets/releases/download/models-3.0.0/fairface.onnx' + assert resolve_download_url_by_provider('huggingface', 'models-3.0.0', 'fairface.onnx') == 'https://huggingface.co/facefusion/models-3.0.0/resolve/main/fairface.onnx' diff --git a/tests/test_execution.py b/tests/test_execution.py new file mode 100644 index 0000000000000000000000000000000000000000..3c4a8c1a96d0a86756ff4cb0d18c6f1c7a0b40be --- /dev/null +++ b/tests/test_execution.py @@ -0,0 +1,24 @@ +from facefusion.execution import create_inference_session_providers, get_available_execution_providers, has_execution_provider + + +def test_has_execution_provider() -> None: + assert has_execution_provider('cpu') is True + assert has_execution_provider('openvino') is False + + +def test_get_available_execution_providers() -> None: + assert 'cpu' in get_available_execution_providers() + + +def test_create_inference_session_providers() -> None: + inference_session_providers =\ + [ + ('CUDAExecutionProvider', + { + 'device_id': '1', + 'cudnn_conv_algo_search': 'EXHAUSTIVE' + }), + 'CPUExecutionProvider' + ] + + assert create_inference_session_providers('1', [ 'cpu', 'cuda' ]) == inference_session_providers diff --git a/tests/test_face_analyser.py b/tests/test_face_analyser.py new file mode 100644 index 0000000000000000000000000000000000000000..86926846260e7900dcf52ce69ab9933dec878668 --- /dev/null +++ b/tests/test_face_analyser.py @@ -0,0 +1,113 @@ +import subprocess + +import pytest + +from 
facefusion import face_classifier, face_detector, face_landmarker, face_recognizer, state_manager +from facefusion.download import conditional_download +from facefusion.face_analyser import get_many_faces, get_one_face +from facefusion.types import Face +from facefusion.vision import read_static_image +from .helper import get_test_example_file, get_test_examples_directory + + +@pytest.fixture(scope = 'module', autouse = True) +def before_all() -> None: + conditional_download(get_test_examples_directory(), + [ + 'https://github.com/facefusion/facefusion-assets/releases/download/examples-3.0.0/source.jpg' + ]) + subprocess.run([ 'ffmpeg', '-i', get_test_example_file('source.jpg'), '-vf', 'crop=iw*0.8:ih*0.8', get_test_example_file('source-80crop.jpg') ]) + subprocess.run([ 'ffmpeg', '-i', get_test_example_file('source.jpg'), '-vf', 'crop=iw*0.7:ih*0.7', get_test_example_file('source-70crop.jpg') ]) + subprocess.run([ 'ffmpeg', '-i', get_test_example_file('source.jpg'), '-vf', 'crop=iw*0.6:ih*0.6', get_test_example_file('source-60crop.jpg') ]) + state_manager.init_item('execution_device_id', '0') + state_manager.init_item('execution_providers', [ 'cpu' ]) + state_manager.init_item('download_providers', [ 'github' ]) + state_manager.init_item('face_detector_angles', [ 0 ]) + state_manager.init_item('face_detector_model', 'many') + state_manager.init_item('face_detector_score', 0.5) + state_manager.init_item('face_landmarker_model', 'many') + state_manager.init_item('face_landmarker_score', 0.5) + face_classifier.pre_check() + face_landmarker.pre_check() + face_recognizer.pre_check() + + +@pytest.fixture(autouse = True) +def before_each() -> None: + face_classifier.clear_inference_pool() + face_detector.clear_inference_pool() + face_landmarker.clear_inference_pool() + face_recognizer.clear_inference_pool() + + +def test_get_one_face_with_retinaface() -> None: + state_manager.init_item('face_detector_model', 'retinaface') + state_manager.init_item('face_detector_size', '320x320') + face_detector.pre_check() + + source_paths =\ + [ + get_test_example_file('source.jpg'), + get_test_example_file('source-80crop.jpg'), + get_test_example_file('source-70crop.jpg'), + get_test_example_file('source-60crop.jpg') + ] + + for source_path in source_paths: + source_frame = read_static_image(source_path) + many_faces = get_many_faces([ source_frame ]) + face = get_one_face(many_faces) + + assert isinstance(face, Face) + + +def test_get_one_face_with_scrfd() -> None: + state_manager.init_item('face_detector_model', 'scrfd') + state_manager.init_item('face_detector_size', '640x640') + face_detector.pre_check() + + source_paths =\ + [ + get_test_example_file('source.jpg'), + get_test_example_file('source-80crop.jpg'), + get_test_example_file('source-70crop.jpg'), + get_test_example_file('source-60crop.jpg') + ] + + for source_path in source_paths: + source_frame = read_static_image(source_path) + many_faces = get_many_faces([ source_frame ]) + face = get_one_face(many_faces) + + assert isinstance(face, Face) + + +def test_get_one_face_with_yoloface() -> None: + state_manager.init_item('face_detector_model', 'yoloface') + state_manager.init_item('face_detector_size', '640x640') + face_detector.pre_check() + + source_paths =\ + [ + get_test_example_file('source.jpg'), + get_test_example_file('source-80crop.jpg'), + get_test_example_file('source-70crop.jpg'), + get_test_example_file('source-60crop.jpg') + ] + + for source_path in source_paths: + source_frame = read_static_image(source_path) + many_faces = 
get_many_faces([ source_frame ]) + face = get_one_face(many_faces) + + assert isinstance(face, Face) + + +def test_get_many_faces() -> None: + source_path = get_test_example_file('source.jpg') + source_frame = read_static_image(source_path) + many_faces = get_many_faces([ source_frame, source_frame, source_frame ]) + + assert isinstance(many_faces[0], Face) + assert isinstance(many_faces[1], Face) + assert isinstance(many_faces[2], Face) diff --git a/tests/test_ffmpeg.py b/tests/test_ffmpeg.py new file mode 100644 index 0000000000000000000000000000000000000000..677822b74f70f7a8dc96e9b33eb67850dca9ca51 --- /dev/null +++ b/tests/test_ffmpeg.py @@ -0,0 +1,187 @@ +import os +import subprocess +import tempfile + +import pytest + +import facefusion.ffmpeg +from facefusion import process_manager, state_manager +from facefusion.download import conditional_download +from facefusion.ffmpeg import concat_video, extract_frames, merge_video, read_audio_buffer, replace_audio, restore_audio +from facefusion.filesystem import copy_file +from facefusion.temp_helper import clear_temp_directory, create_temp_directory, get_temp_file_path, resolve_temp_frame_paths +from facefusion.types import EncoderSet +from .helper import get_test_example_file, get_test_examples_directory, get_test_output_file, prepare_test_output_directory + + +@pytest.fixture(scope = 'module', autouse = True) +def before_all() -> None: + process_manager.start() + conditional_download(get_test_examples_directory(), + [ + 'https://github.com/facefusion/facefusion-assets/releases/download/examples-3.0.0/source.jpg', + 'https://github.com/facefusion/facefusion-assets/releases/download/examples-3.0.0/source.mp3', + 'https://github.com/facefusion/facefusion-assets/releases/download/examples-3.0.0/target-240p.mp4' + ]) + subprocess.run([ 'ffmpeg', '-i', get_test_example_file('source.mp3'), get_test_example_file('source.wav') ]) + subprocess.run([ 'ffmpeg', '-i', get_test_example_file('target-240p.mp4'), '-vf', 'fps=25', get_test_example_file('target-240p-25fps.mp4') ]) + subprocess.run([ 'ffmpeg', '-i', get_test_example_file('target-240p.mp4'), '-vf', 'fps=30', get_test_example_file('target-240p-30fps.mp4') ]) + subprocess.run([ 'ffmpeg', '-i', get_test_example_file('target-240p.mp4'), '-vf', 'fps=60', get_test_example_file('target-240p-60fps.mp4') ]) + + for output_video_format in [ 'avi', 'm4v', 'mkv', 'mov', 'mp4', 'webm' ]: + subprocess.run([ 'ffmpeg', '-i', get_test_example_file('source.mp3'), '-i', get_test_example_file('target-240p.mp4'), '-ar', '16000', get_test_example_file('target-240p-16khz.' 
+ output_video_format) ]) + + subprocess.run([ 'ffmpeg', '-i', get_test_example_file('source.mp3'), '-i', get_test_example_file('target-240p.mp4'), '-ar', '48000', get_test_example_file('target-240p-48khz.mp4') ]) + state_manager.init_item('temp_path', tempfile.gettempdir()) + state_manager.init_item('temp_frame_format', 'png') + state_manager.init_item('output_audio_encoder', 'aac') + state_manager.init_item('output_audio_quality', 100) + state_manager.init_item('output_audio_volume', 100) + state_manager.init_item('output_video_encoder', 'libx264') + state_manager.init_item('output_video_quality', 100) + state_manager.init_item('output_video_preset', 'ultrafast') + + +@pytest.fixture(scope = 'function', autouse = True) +def before_each() -> None: + prepare_test_output_directory() + + +def get_available_encoder_set() -> EncoderSet: + if os.getenv('CI'): + return\ + { + 'audio': [ 'aac' ], + 'video': [ 'libx264' ] + } + return facefusion.ffmpeg.get_available_encoder_set() + + +def test_get_available_encoder_set() -> None: + available_encoder_set = get_available_encoder_set() + + assert 'aac' in available_encoder_set.get('audio') + assert 'libx264' in available_encoder_set.get('video') + + +def test_extract_frames() -> None: + test_set =\ + [ + (get_test_example_file('target-240p-25fps.mp4'), 0, 270, 324), + (get_test_example_file('target-240p-25fps.mp4'), 224, 270, 55), + (get_test_example_file('target-240p-25fps.mp4'), 124, 224, 120), + (get_test_example_file('target-240p-25fps.mp4'), 0, 100, 120), + (get_test_example_file('target-240p-30fps.mp4'), 0, 324, 324), + (get_test_example_file('target-240p-30fps.mp4'), 224, 324, 100), + (get_test_example_file('target-240p-30fps.mp4'), 124, 224, 100), + (get_test_example_file('target-240p-30fps.mp4'), 0, 100, 100), + (get_test_example_file('target-240p-60fps.mp4'), 0, 648, 324), + (get_test_example_file('target-240p-60fps.mp4'), 224, 648, 212), + (get_test_example_file('target-240p-60fps.mp4'), 124, 224, 50), + (get_test_example_file('target-240p-60fps.mp4'), 0, 100, 50) + ] + + for target_path, trim_frame_start, trim_frame_end, frame_total in test_set: + create_temp_directory(target_path) + + assert extract_frames(target_path, '452x240', 30.0, trim_frame_start, trim_frame_end) is True + assert len(resolve_temp_frame_paths(target_path)) == frame_total + + clear_temp_directory(target_path) + + +def test_merge_video() -> None: + target_paths =\ + [ + get_test_example_file('target-240p-16khz.avi'), + get_test_example_file('target-240p-16khz.m4v'), + get_test_example_file('target-240p-16khz.mkv'), + get_test_example_file('target-240p-16khz.mp4'), + get_test_example_file('target-240p-16khz.mov'), + get_test_example_file('target-240p-16khz.webm') + ] + output_video_encoders = get_available_encoder_set().get('video') + + for target_path in target_paths: + for output_video_encoder in output_video_encoders: + state_manager.init_item('output_video_encoder', output_video_encoder) + create_temp_directory(target_path) + extract_frames(target_path, '452x240', 25.0, 0, 1) + + assert merge_video(target_path, 25.0, '452x240', 25.0, 0, 1) is True + + clear_temp_directory(target_path) + + state_manager.init_item('output_video_encoder', 'libx264') + + +def test_concat_video() -> None: + output_path = get_test_output_file('test-concat-video.mp4') + temp_output_paths =\ + [ + get_test_example_file('target-240p-16khz.mp4'), + get_test_example_file('target-240p-16khz.mp4') + ] + + assert concat_video(output_path, temp_output_paths) is True + + +def 
test_read_audio_buffer() -> None: + assert isinstance(read_audio_buffer(get_test_example_file('source.mp3'), 1, 16, 1), bytes) + assert isinstance(read_audio_buffer(get_test_example_file('source.wav'), 1, 16, 1), bytes) + assert read_audio_buffer(get_test_example_file('invalid.mp3'), 1, 16, 1) is None + + +def test_restore_audio() -> None: + test_set =\ + [ + (get_test_example_file('target-240p-16khz.avi'), get_test_output_file('target-240p-16khz.avi')), + (get_test_example_file('target-240p-16khz.m4v'), get_test_output_file('target-240p-16khz.m4v')), + (get_test_example_file('target-240p-16khz.mkv'), get_test_output_file('target-240p-16khz.mkv')), + (get_test_example_file('target-240p-16khz.mov'), get_test_output_file('target-240p-16khz.mov')), + (get_test_example_file('target-240p-16khz.mp4'), get_test_output_file('target-240p-16khz.mp4')), + (get_test_example_file('target-240p-48khz.mp4'), get_test_output_file('target-240p-48khz.mp4')), + (get_test_example_file('target-240p-16khz.webm'), get_test_output_file('target-240p-16khz.webm')) + ] + output_audio_encoders = get_available_encoder_set().get('audio') + + for target_path, output_path in test_set: + create_temp_directory(target_path) + + for output_audio_encoder in output_audio_encoders: + state_manager.init_item('output_audio_encoder', output_audio_encoder) + copy_file(target_path, get_temp_file_path(target_path)) + + assert restore_audio(target_path, output_path, 0, 270) is True + + clear_temp_directory(target_path) + + state_manager.init_item('output_audio_encoder', 'aac') + + +def test_replace_audio() -> None: + test_set =\ + [ + (get_test_example_file('target-240p-16khz.avi'), get_test_output_file('target-240p-16khz.avi')), + (get_test_example_file('target-240p-16khz.m4v'), get_test_output_file('target-240p-16khz.m4v')), + (get_test_example_file('target-240p-16khz.mkv'), get_test_output_file('target-240p-16khz.mkv')), + (get_test_example_file('target-240p-16khz.mov'), get_test_output_file('target-240p-16khz.mov')), + (get_test_example_file('target-240p-16khz.mp4'), get_test_output_file('target-240p-16khz.mp4')), + (get_test_example_file('target-240p-48khz.mp4'), get_test_output_file('target-240p-48khz.mp4')), + (get_test_example_file('target-240p-16khz.webm'), get_test_output_file('target-240p-16khz.webm')) + ] + output_audio_encoders = get_available_encoder_set().get('audio') + + for target_path, output_path in test_set: + create_temp_directory(target_path) + + for output_audio_encoder in output_audio_encoders: + state_manager.init_item('output_audio_encoder', output_audio_encoder) + copy_file(target_path, get_temp_file_path(target_path)) + + assert replace_audio(target_path, get_test_example_file('source.mp3'), output_path) is True + assert replace_audio(target_path, get_test_example_file('source.wav'), output_path) is True + + clear_temp_directory(target_path) + + state_manager.init_item('output_audio_encoder', 'aac') diff --git a/tests/test_ffmpeg_builder.py b/tests/test_ffmpeg_builder.py new file mode 100644 index 0000000000000000000000000000000000000000..9179b88da66b21458dd341192d7c8585ab376877 --- /dev/null +++ b/tests/test_ffmpeg_builder.py @@ -0,0 +1,83 @@ +from shutil import which + +from facefusion import ffmpeg_builder +from facefusion.ffmpeg_builder import chain, run, select_frame_range, set_audio_quality, set_audio_sample_size, set_stream_mode, set_video_quality + + +def test_run() -> None: + assert run([]) == [ which('ffmpeg'), '-loglevel', 'error' ] + + +def test_chain() -> None: + assert 
chain(ffmpeg_builder.set_progress()) == [ '-progress' ] + + +def test_set_stream_mode() -> None: + assert set_stream_mode('udp') == [ '-f', 'mpegts' ] + assert set_stream_mode('v4l2') == [ '-f', 'v4l2' ] + + +def test_select_frame_range() -> None: + assert select_frame_range(0, None, 30) == [ '-vf', 'trim=start_frame=0,fps=30' ] + assert select_frame_range(None, 100, 30) == [ '-vf', 'trim=end_frame=100,fps=30' ] + assert select_frame_range(0, 100, 30) == [ '-vf', 'trim=start_frame=0:end_frame=100,fps=30' ] + assert select_frame_range(None, None, 30) == [ '-vf', 'fps=30' ] + + +def test_set_audio_sample_size() -> None: + assert set_audio_sample_size(16) == [ '-f', 's16le' ] + assert set_audio_sample_size(32) == [ '-f', 's32le' ] + + +def test_set_audio_quality() -> None: + assert set_audio_quality('aac', 0) == [ '-q:a', '0.1' ] + assert set_audio_quality('aac', 50) == [ '-q:a', '1.0' ] + assert set_audio_quality('aac', 100) == [ '-q:a', '2.0' ] + assert set_audio_quality('libmp3lame', 0) == [ '-q:a', '9' ] + assert set_audio_quality('libmp3lame', 50) == [ '-q:a', '4' ] + assert set_audio_quality('libmp3lame', 100) == [ '-q:a', '0' ] + assert set_audio_quality('libopus', 0) == [ '-b:a', '64k' ] + assert set_audio_quality('libopus', 50) == [ '-b:a', '160k' ] + assert set_audio_quality('libopus', 100) == [ '-b:a', '256k' ] + assert set_audio_quality('libvorbis', 0) == [ '-q:a', '-1.0' ] + assert set_audio_quality('libvorbis', 50) == [ '-q:a', '4.5' ] + assert set_audio_quality('libvorbis', 100) == [ '-q:a', '10.0' ] + assert set_audio_quality('flac', 0) == [] + assert set_audio_quality('flac', 50) == [] + assert set_audio_quality('flac', 100) == [] + + +def test_set_video_quality() -> None: + assert set_video_quality('libx264', 0) == [ '-crf', '51' ] + assert set_video_quality('libx264', 50) == [ '-crf', '26' ] + assert set_video_quality('libx264', 100) == [ '-crf', '0' ] + assert set_video_quality('libx265', 0) == [ '-crf', '51' ] + assert set_video_quality('libx265', 50) == [ '-crf', '26' ] + assert set_video_quality('libx265', 100) == [ '-crf', '0' ] + assert set_video_quality('libvpx-vp9', 0) == [ '-crf', '63' ] + assert set_video_quality('libvpx-vp9', 50) == [ '-crf', '32' ] + assert set_video_quality('libvpx-vp9', 100) == [ '-crf', '0' ] + assert set_video_quality('h264_nvenc', 0) == [ '-cq' , '51' ] + assert set_video_quality('h264_nvenc', 50) == [ '-cq' , '26' ] + assert set_video_quality('h264_nvenc', 100) == [ '-cq' , '0' ] + assert set_video_quality('hevc_nvenc', 0) == [ '-cq' , '51' ] + assert set_video_quality('hevc_nvenc', 50) == [ '-cq' , '26' ] + assert set_video_quality('hevc_nvenc', 100) == [ '-cq' , '0' ] + assert set_video_quality('h264_amf', 0) == [ '-qp_i', '51', '-qp_p', '51', '-qp_b', '51' ] + assert set_video_quality('h264_amf', 50) == [ '-qp_i', '26', '-qp_p', '26', '-qp_b', '26' ] + assert set_video_quality('h264_amf', 100) == [ '-qp_i', '0', '-qp_p', '0', '-qp_b', '0' ] + assert set_video_quality('hevc_amf', 0) == [ '-qp_i', '51', '-qp_p', '51', '-qp_b', '51' ] + assert set_video_quality('hevc_amf', 50) == [ '-qp_i', '26', '-qp_p', '26', '-qp_b', '26' ] + assert set_video_quality('hevc_amf', 100) == [ '-qp_i', '0', '-qp_p', '0', '-qp_b', '0' ] + assert set_video_quality('h264_qsv', 0) == [ '-qp', '51' ] + assert set_video_quality('h264_qsv', 50) == [ '-qp', '26' ] + assert set_video_quality('h264_qsv', 100) == [ '-qp', '0' ] + assert set_video_quality('hevc_qsv', 0) == [ '-qp', '51' ] + assert set_video_quality('hevc_qsv', 50) == [ '-qp', '26' ] + assert 
set_video_quality('hevc_qsv', 100) == [ '-qp', '0' ] + assert set_video_quality('h264_videotoolbox', 0) == [ '-b:v', '1024k' ] + assert set_video_quality('h264_videotoolbox', 50) == [ '-b:v', '25768k' ] + assert set_video_quality('h264_videotoolbox', 100) == [ '-b:v', '50512k' ] + assert set_video_quality('hevc_videotoolbox', 0) == [ '-b:v', '1024k' ] + assert set_video_quality('hevc_videotoolbox', 50) == [ '-b:v', '25768k' ] + assert set_video_quality('hevc_videotoolbox', 100) == [ '-b:v', '50512k' ] diff --git a/tests/test_filesystem.py b/tests/test_filesystem.py new file mode 100644 index 0000000000000000000000000000000000000000..fc2ef245e5f851b93103fcda1f6b97f7a262e5bf --- /dev/null +++ b/tests/test_filesystem.py @@ -0,0 +1,135 @@ +import os.path + +import pytest + +from facefusion.download import conditional_download +from facefusion.filesystem import create_directory, filter_audio_paths, filter_image_paths, get_file_extension, get_file_format, get_file_size, has_audio, has_image, has_video, in_directory, is_audio, is_directory, is_file, is_image, is_video, remove_directory, resolve_file_paths, same_file_extension +from .helper import get_test_example_file, get_test_examples_directory, get_test_outputs_directory + + +@pytest.fixture(scope = 'module', autouse = True) +def before_all() -> None: + conditional_download(get_test_examples_directory(), + [ + 'https://github.com/facefusion/facefusion-assets/releases/download/examples-3.0.0/source.jpg', + 'https://github.com/facefusion/facefusion-assets/releases/download/examples-3.0.0/source.mp3', + 'https://github.com/facefusion/facefusion-assets/releases/download/examples-3.0.0/target-240p.mp4' + ]) + + +def test_get_file_size() -> None: + assert get_file_size(get_test_example_file('source.jpg')) == 549458 + assert get_file_size('invalid') == 0 + + +def test_get_file_extension() -> None: + assert get_file_extension('source.jpg') == '.jpg' + assert get_file_extension('source.mp3') == '.mp3' + assert get_file_extension('invalid') is None + + +def test_get_file_format() -> None: + assert get_file_format('source.jpg') == 'jpeg' + assert get_file_format('source.jpeg') == 'jpeg' + assert get_file_format('source.mp3') == 'mp3' + assert get_file_format('invalid') is None + + +def test_same_file_extension() -> None: + assert same_file_extension('source.jpg', 'source.jpg') is True + assert same_file_extension('source.jpg', 'source.mp3') is False + assert same_file_extension('invalid', 'invalid') is False + + +def test_is_file() -> None: + assert is_file(get_test_example_file('source.jpg')) is True + assert is_file(get_test_examples_directory()) is False + assert is_file('invalid') is False + + +def test_is_audio() -> None: + assert is_audio(get_test_example_file('source.mp3')) is True + assert is_audio(get_test_example_file('source.jpg')) is False + assert is_audio('invalid') is False + + +def test_has_audio() -> None: + assert has_audio([ get_test_example_file('source.mp3') ]) is True + assert has_audio([ get_test_example_file('source.mp3'), get_test_example_file('source.jpg') ]) is True + assert has_audio([ get_test_example_file('source.jpg'), get_test_example_file('source.jpg') ]) is False + assert has_audio([ 'invalid' ]) is False + + +def test_is_image() -> None: + assert is_image(get_test_example_file('source.jpg')) is True + assert is_image(get_test_example_file('target-240p.mp4')) is False + assert is_image('invalid') is False + + +def test_has_image() -> None: + assert has_image([ get_test_example_file('source.jpg') ]) is True + assert 
has_image([ get_test_example_file('source.jpg'), get_test_example_file('source.mp3') ]) is True + assert has_image([ get_test_example_file('source.mp3'), get_test_example_file('source.mp3') ]) is False + assert has_image([ 'invalid' ]) is False + + +def test_is_video() -> None: + assert is_video(get_test_example_file('target-240p.mp4')) is True + assert is_video(get_test_example_file('source.jpg')) is False + assert is_video('invalid') is False + + +def test_has_video() -> None: + assert has_video([ get_test_example_file('target-240p.mp4') ]) is True + assert has_video([ get_test_example_file('target-240p.mp4'), get_test_example_file('source.mp3') ]) is True + assert has_video([ get_test_example_file('source.mp3'), get_test_example_file('source.mp3') ]) is False + assert has_video([ 'invalid' ]) is False + + +def test_filter_audio_paths() -> None: + assert filter_audio_paths([ get_test_example_file('source.jpg'), get_test_example_file('source.mp3') ]) == [ get_test_example_file('source.mp3') ] + assert filter_audio_paths([ get_test_example_file('source.jpg'), get_test_example_file('source.jpg') ]) == [] + assert filter_audio_paths([ 'invalid' ]) == [] + + +def test_filter_image_paths() -> None: + assert filter_image_paths([ get_test_example_file('source.jpg'), get_test_example_file('source.mp3') ]) == [ get_test_example_file('source.jpg') ] + assert filter_image_paths([ get_test_example_file('source.mp3'), get_test_example_file('source.mp3') ]) == [] + assert filter_audio_paths([ 'invalid' ]) == [] + + +def test_resolve_file_paths() -> None: + file_paths = resolve_file_paths(get_test_examples_directory()) + + for file_path in file_paths: + assert file_path == get_test_example_file(file_path) + + assert resolve_file_paths('invalid') == [] + + +def test_create_directory() -> None: + create_directory_path = os.path.join(get_test_outputs_directory(), 'create_directory') + + assert create_directory(create_directory_path) is True + assert create_directory(get_test_example_file('source.jpg')) is False + + +def test_remove_directory() -> None: + remove_directory_path = os.path.join(get_test_outputs_directory(), 'remove_directory') + create_directory(remove_directory_path) + + assert remove_directory(remove_directory_path) is True + assert remove_directory(get_test_example_file('source.jpg')) is False + assert remove_directory('invalid') is False + + +def test_is_directory() -> None: + assert is_directory(get_test_examples_directory()) is True + assert is_directory(get_test_example_file('source.jpg')) is False + assert is_directory('invalid') is False + + +def test_in_directory() -> None: + assert in_directory(get_test_example_file('source.jpg')) is True + assert in_directory('source.jpg') is False + assert in_directory('invalid') is False diff --git a/tests/test_inference_manager.py b/tests/test_inference_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..2b4d61e59898b1bea550dde674c531227c670d2c --- /dev/null +++ b/tests/test_inference_manager.py @@ -0,0 +1,32 @@ +from unittest.mock import patch + +import pytest +from onnxruntime import InferenceSession + +from facefusion import content_analyser, state_manager +from facefusion.inference_manager import INFERENCE_POOL_SET, get_inference_pool + + +@pytest.fixture(scope = 'module', autouse = True) +def before_all() -> None: + state_manager.init_item('execution_device_id', '0') + state_manager.init_item('execution_providers', [ 'cpu' ]) + state_manager.init_item('download_providers', [ 'github' ]) + 
content_analyser.pre_check() + + +def test_get_inference_pool() -> None: + model_names = [ 'nsfw_1', 'nsfw_2', 'nsfw_3' ] + _, model_source_set = content_analyser.collect_model_downloads() + + with patch('facefusion.inference_manager.detect_app_context', return_value = 'cli'): + get_inference_pool('facefusion.content_analyser', model_names, model_source_set) + + assert isinstance(INFERENCE_POOL_SET.get('cli').get('facefusion.content_analyser.nsfw_1.nsfw_2.nsfw_3.0.cpu').get('nsfw_1'), InferenceSession) + + with patch('facefusion.inference_manager.detect_app_context', return_value = 'ui'): + get_inference_pool('facefusion.content_analyser', model_names, model_source_set) + + assert isinstance(INFERENCE_POOL_SET.get('cli').get('facefusion.content_analyser.nsfw_1.nsfw_2.nsfw_3.0.cpu').get('nsfw_1'), InferenceSession) + + assert INFERENCE_POOL_SET.get('cli').get('facefusion.content_analyser.nsfw_1.nsfw_2.nsfw_3.0.cpu').get('nsfw_1') == INFERENCE_POOL_SET.get('ui').get('facefusion.content_analyser.nsfw_1.nsfw_2.nsfw_3.0.cpu').get('nsfw_1') diff --git a/tests/test_job_helper.py b/tests/test_job_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..08fe6f8620abb43fb7bb8e09cf2741bb2e91d7c1 --- /dev/null +++ b/tests/test_job_helper.py @@ -0,0 +1,8 @@ +import os + +from facefusion.jobs.job_helper import get_step_output_path + + +def test_get_step_output_path() -> None: + assert get_step_output_path('test-job', 0, 'test.mp4') == 'test-test-job-0.mp4' + assert get_step_output_path('test-job', 0, 'test/test.mp4') == os.path.join('test', 'test-test-job-0.mp4') diff --git a/tests/test_job_list.py b/tests/test_job_list.py new file mode 100644 index 0000000000000000000000000000000000000000..732a199f6586e76ea0e93c8b6639d7a2c530a237 --- /dev/null +++ b/tests/test_job_list.py @@ -0,0 +1,24 @@ +from time import sleep + +import pytest + +from facefusion.jobs.job_list import compose_job_list +from facefusion.jobs.job_manager import clear_jobs, create_job, init_jobs +from .helper import get_test_jobs_directory + + +@pytest.fixture(scope = 'function', autouse = True) +def before_each() -> None: + clear_jobs(get_test_jobs_directory()) + init_jobs(get_test_jobs_directory()) + + +def test_compose_job_list() -> None: + create_job('job-test-compose-job-list-1') + sleep(0.5) + create_job('job-test-compose-job-list-2') + job_headers, job_contents = compose_job_list('drafted') + + assert job_headers == [ 'job id', 'steps', 'date created', 'date updated', 'job status' ] + assert job_contents[0] == [ 'job-test-compose-job-list-1', 0, 'just now', None, 'drafted' ] + assert job_contents[1] == [ 'job-test-compose-job-list-2', 0, 'just now', None, 'drafted' ] diff --git a/tests/test_job_manager.py b/tests/test_job_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..05b1fedcd7fce6681bfc3e918c6cacb883ec98a0 --- /dev/null +++ b/tests/test_job_manager.py @@ -0,0 +1,386 @@ +from time import sleep + +import pytest + +from facefusion.jobs.job_helper import get_step_output_path +from facefusion.jobs.job_manager import add_step, clear_jobs, count_step_total, create_job, delete_job, delete_jobs, find_job_ids, find_jobs, get_steps, init_jobs, insert_step, move_job_file, remix_step, remove_step, set_step_status, set_steps_status, submit_job, submit_jobs +from .helper import get_test_jobs_directory + + +@pytest.fixture(scope = 'function', autouse = True) +def before_each() -> None: + clear_jobs(get_test_jobs_directory()) + init_jobs(get_test_jobs_directory()) + + +def test_create_job() -> 
None: + args_1 =\ + { + 'source_path': 'source-1.jpg', + 'target_path': 'target-1.jpg', + 'output_path': 'output-1.jpg' + } + + assert create_job('job-test-create-job') is True + assert create_job('job-test-create-job') is False + + add_step('job-test-submit-job', args_1) + submit_job('job-test-create-job') + + assert create_job('job-test-create-job') is False + + +def test_submit_job() -> None: + args_1 =\ + { + 'source_path': 'source-1.jpg', + 'target_path': 'target-1.jpg', + 'output_path': 'output-1.jpg' + } + + assert submit_job('job-invalid') is False + + create_job('job-test-submit-job') + + assert submit_job('job-test-submit-job') is False + + add_step('job-test-submit-job', args_1) + + assert submit_job('job-test-submit-job') is True + assert submit_job('job-test-submit-job') is False + + +def test_submit_jobs() -> None: + args_1 =\ + { + 'source_path': 'source-1.jpg', + 'target_path': 'target-1.jpg', + 'output_path': 'output-1.jpg' + } + args_2 =\ + { + 'source_path': 'source-2.jpg', + 'target_path': 'target-2.jpg', + 'output_path': 'output-2.jpg' + } + halt_on_error = True + + assert submit_jobs(halt_on_error) is False + + create_job('job-test-submit-jobs-1') + create_job('job-test-submit-jobs-2') + + assert submit_jobs(halt_on_error) is False + + add_step('job-test-submit-jobs-1', args_1) + add_step('job-test-submit-jobs-2', args_2) + + assert submit_jobs(halt_on_error) is True + assert submit_jobs(halt_on_error) is False + + +def test_delete_job() -> None: + assert delete_job('job-invalid') is False + + create_job('job-test-delete-job') + + assert delete_job('job-test-delete-job') is True + assert delete_job('job-test-delete-job') is False + + +def test_delete_jobs() -> None: + halt_on_error = True + + assert delete_jobs(halt_on_error) is False + + create_job('job-test-delete-jobs-1') + create_job('job-test-delete-jobs-2') + + assert delete_jobs(halt_on_error) is True + + +def test_find_jobs() -> None: + create_job('job-test-find-jobs-1') + sleep(0.5) + create_job('job-test-find-jobs-2') + + assert 'job-test-find-jobs-1' in find_jobs('drafted') + assert 'job-test-find-jobs-2' in find_jobs('drafted') + assert not find_jobs('queued') + + move_job_file('job-test-find-jobs-1', 'queued') + + assert 'job-test-find-jobs-2' in find_jobs('drafted') + assert 'job-test-find-jobs-1' in find_jobs('queued') + + +def test_find_job_ids() -> None: + create_job('job-test-find-job-ids-1') + sleep(0.5) + create_job('job-test-find-job-ids-2') + sleep(0.5) + create_job('job-test-find-job-ids-3') + + assert find_job_ids('drafted') == [ 'job-test-find-job-ids-1', 'job-test-find-job-ids-2', 'job-test-find-job-ids-3' ] + assert find_job_ids('queued') == [] + assert find_job_ids('completed') == [] + assert find_job_ids('failed') == [] + + move_job_file('job-test-find-job-ids-1', 'queued') + move_job_file('job-test-find-job-ids-2', 'queued') + move_job_file('job-test-find-job-ids-3', 'queued') + + assert find_job_ids('drafted') == [] + assert find_job_ids('queued') == [ 'job-test-find-job-ids-1', 'job-test-find-job-ids-2', 'job-test-find-job-ids-3' ] + assert find_job_ids('completed') == [] + assert find_job_ids('failed') == [] + + move_job_file('job-test-find-job-ids-1', 'completed') + + assert find_job_ids('drafted') == [] + assert find_job_ids('queued') == [ 'job-test-find-job-ids-2', 'job-test-find-job-ids-3' ] + assert find_job_ids('completed') == [ 'job-test-find-job-ids-1' ] + assert find_job_ids('failed') == [] + + move_job_file('job-test-find-job-ids-2', 'failed') + + assert 
find_job_ids('drafted') == [] + assert find_job_ids('queued') == [ 'job-test-find-job-ids-3' ] + assert find_job_ids('completed') == [ 'job-test-find-job-ids-1' ] + assert find_job_ids('failed') == [ 'job-test-find-job-ids-2' ] + + move_job_file('job-test-find-job-ids-3', 'completed') + + assert find_job_ids('drafted') == [] + assert find_job_ids('queued') == [] + assert find_job_ids('completed') == [ 'job-test-find-job-ids-1', 'job-test-find-job-ids-3' ] + assert find_job_ids('failed') == [ 'job-test-find-job-ids-2' ] + + +def test_add_step() -> None: + args_1 =\ + { + 'source_path': 'source-1.jpg', + 'target_path': 'target-1.jpg', + 'output_path': 'output-1.jpg' + } + args_2 =\ + { + 'source_path': 'source-2.jpg', + 'target_path': 'target-2.jpg', + 'output_path': 'output-2.jpg' + } + + assert add_step('job-invalid', args_1) is False + + create_job('job-test-add-step') + + assert add_step('job-test-add-step', args_1) is True + assert add_step('job-test-add-step', args_2) is True + + steps = get_steps('job-test-add-step') + + assert steps[0].get('args') == args_1 + assert steps[1].get('args') == args_2 + assert count_step_total('job-test-add-step') == 2 + + +def test_remix_step() -> None: + args_1 =\ + { + 'source_path': 'source-1.jpg', + 'target_path': 'target-1.jpg', + 'output_path': 'output-1.jpg' + } + args_2 =\ + { + 'source_path': 'source-2.jpg', + 'target_path': 'target-2.jpg', + 'output_path': 'output-2.jpg' + } + + assert remix_step('job-invalid', 0, args_1) is False + + create_job('job-test-remix-step') + add_step('job-test-remix-step', args_1) + add_step('job-test-remix-step', args_2) + + assert remix_step('job-test-remix-step', 99, args_1) is False + assert remix_step('job-test-remix-step', 0, args_2) is True + assert remix_step('job-test-remix-step', -1, args_2) is True + + steps = get_steps('job-test-remix-step') + + assert steps[0].get('args') == args_1 + assert steps[1].get('args') == args_2 + assert steps[2].get('args').get('source_path') == args_2.get('source_path') + assert steps[2].get('args').get('target_path') == get_step_output_path('job-test-remix-step', 0, args_1.get('output_path')) + assert steps[2].get('args').get('output_path') == args_2.get('output_path') + assert steps[3].get('args').get('source_path') == args_2.get('source_path') + assert steps[3].get('args').get('target_path') == get_step_output_path('job-test-remix-step', 2, args_2.get('output_path')) + assert steps[3].get('args').get('output_path') == args_2.get('output_path') + assert count_step_total('job-test-remix-step') == 4 + + +def test_insert_step() -> None: + args_1 =\ + { + 'source_path': 'source-1.jpg', + 'target_path': 'target-1.jpg', + 'output_path': 'output-1.jpg' + } + args_2 =\ + { + 'source_path': 'source-2.jpg', + 'target_path': 'target-2.jpg', + 'output_path': 'output-2.jpg' + } + args_3 =\ + { + 'source_path': 'source-3.jpg', + 'target_path': 'target-3.jpg', + 'output_path': 'output-3.jpg' + } + + assert insert_step('job-invalid', 0, args_1) is False + + create_job('job-test-insert-step') + add_step('job-test-insert-step', args_1) + add_step('job-test-insert-step', args_1) + + assert insert_step('job-test-insert-step', 99, args_1) is False + assert insert_step('job-test-insert-step', 0, args_2) is True + assert insert_step('job-test-insert-step', -1, args_3) is True + + steps = get_steps('job-test-insert-step') + + assert steps[0].get('args') == args_2 + assert steps[1].get('args') == args_1 + assert steps[2].get('args') == args_3 + assert steps[3].get('args') == args_1 + assert 
count_step_total('job-test-insert-step') == 4 + + +def test_remove_step() -> None: + args_1 =\ + { + 'source_path': 'source-1.jpg', + 'target_path': 'target-1.jpg', + 'output_path': 'output-1.jpg' + } + args_2 =\ + { + 'source_path': 'source-2.jpg', + 'target_path': 'target-2.jpg', + 'output_path': 'output-2.jpg' + } + args_3 =\ + { + 'source_path': 'source-3.jpg', + 'target_path': 'target-3.jpg', + 'output_path': 'output-3.jpg' + } + + assert remove_step('job-invalid', 0) is False + + create_job('job-test-remove-step') + add_step('job-test-remove-step', args_1) + add_step('job-test-remove-step', args_2) + add_step('job-test-remove-step', args_1) + add_step('job-test-remove-step', args_3) + + assert remove_step('job-test-remove-step', 99) is False + assert remove_step('job-test-remove-step', 0) is True + assert remove_step('job-test-remove-step', -1) is True + + steps = get_steps('job-test-remove-step') + + assert steps[0].get('args') == args_2 + assert steps[1].get('args') == args_1 + assert count_step_total('job-test-remove-step') == 2 + + +def test_get_steps() -> None: + args_1 =\ + { + 'source_path': 'source-1.jpg', + 'target_path': 'target-1.jpg', + 'output_path': 'output-1.jpg' + } + args_2 =\ + { + 'source_path': 'source-2.jpg', + 'target_path': 'target-2.jpg', + 'output_path': 'output-2.jpg' + } + + assert get_steps('job-invalid') == [] + + create_job('job-test-get-steps') + add_step('job-test-get-steps', args_1) + add_step('job-test-get-steps', args_2) + steps = get_steps('job-test-get-steps') + + assert steps[0].get('args') == args_1 + assert steps[1].get('args') == args_2 + assert count_step_total('job-test-get-steps') == 2 + + +def test_set_step_status() -> None: + args_1 =\ + { + 'source_path': 'source-1.jpg', + 'target_path': 'target-1.jpg', + 'output_path': 'output-1.jpg' + } + args_2 =\ + { + 'source_path': 'source-2.jpg', + 'target_path': 'target-2.jpg', + 'output_path': 'output-2.jpg' + } + + assert set_step_status('job-invalid', 0, 'completed') is False + + create_job('job-test-set-step-status') + add_step('job-test-set-step-status', args_1) + add_step('job-test-set-step-status', args_2) + + assert set_step_status('job-test-set-step-status', 99, 'completed') is False + assert set_step_status('job-test-set-step-status', 0, 'completed') is True + assert set_step_status('job-test-set-step-status', 1, 'failed') is True + + steps = get_steps('job-test-set-step-status') + + assert steps[0].get('status') == 'completed' + assert steps[1].get('status') == 'failed' + assert count_step_total('job-test-set-step-status') == 2 + + +def test_set_steps_status() -> None: + args_1 =\ + { + 'source_path': 'source-1.jpg', + 'target_path': 'target-1.jpg', + 'output_path': 'output-1.jpg' + } + args_2 =\ + { + 'source_path': 'source-2.jpg', + 'target_path': 'target-2.jpg', + 'output_path': 'output-2.jpg' + } + + assert set_steps_status('job-invalid', 'queued') is False + + create_job('job-test-set-steps-status') + add_step('job-test-set-steps-status', args_1) + add_step('job-test-set-steps-status', args_2) + + assert set_steps_status('job-test-set-steps-status', 'queued') is True + + steps = get_steps('job-test-set-steps-status') + + assert steps[0].get('status') == 'queued' + assert steps[1].get('status') == 'queued' + assert count_step_total('job-test-set-steps-status') == 2 diff --git a/tests/test_job_runner.py b/tests/test_job_runner.py new file mode 100644 index 0000000000000000000000000000000000000000..4e0983ae1d125bf9b08d1e2f9d4833b078c0a975 --- /dev/null +++ b/tests/test_job_runner.py 
@@ -0,0 +1,276 @@ +import subprocess + +import pytest + +from facefusion.download import conditional_download +from facefusion.filesystem import copy_file +from facefusion.jobs.job_manager import add_step, clear_jobs, create_job, init_jobs, move_job_file, submit_job, submit_jobs +from facefusion.jobs.job_runner import collect_output_set, finalize_steps, retry_job, retry_jobs, run_job, run_jobs, run_steps +from facefusion.types import Args +from .helper import get_test_example_file, get_test_examples_directory, get_test_jobs_directory, get_test_output_file, is_test_output_file, prepare_test_output_directory + + +@pytest.fixture(scope = 'module', autouse = True) +def before_all() -> None: + conditional_download(get_test_examples_directory(), + [ + 'https://github.com/facefusion/facefusion-assets/releases/download/examples-3.0.0/source.jpg', + 'https://github.com/facefusion/facefusion-assets/releases/download/examples-3.0.0/target-240p.mp4' + ]) + subprocess.run([ 'ffmpeg', '-i', get_test_example_file('target-240p.mp4'), '-vframes', '1', get_test_example_file('target-240p.jpg') ]) + + +@pytest.fixture(scope = 'function', autouse = True) +def before_each() -> None: + clear_jobs(get_test_jobs_directory()) + init_jobs(get_test_jobs_directory()) + prepare_test_output_directory() + + +def process_step(job_id : str, step_index : int, step_args : Args) -> bool: + return copy_file(step_args.get('target_path'), step_args.get('output_path')) + + +def test_run_job() -> None: + args_1 =\ + { + 'source_path': get_test_example_file('source.jpg'), + 'target_path': get_test_example_file('target-240p.mp4'), + 'output_path': get_test_output_file('output-1.mp4') + } + args_2 =\ + { + 'source_path': get_test_example_file('source.jpg'), + 'target_path': get_test_example_file('target-240p.mp4'), + 'output_path': get_test_output_file('output-2.mp4') + } + args_3 =\ + { + 'source_path': get_test_example_file('source.jpg'), + 'target_path': get_test_example_file('target-240p.jpg'), + 'output_path': get_test_output_file('output-3.jpg') + } + + assert run_job('job-invalid', process_step) is False + + create_job('job-test-run-job') + add_step('job-test-run-job', args_1) + add_step('job-test-run-job', args_2) + add_step('job-test-run-job', args_2) + add_step('job-test-run-job', args_3) + + assert run_job('job-test-run-job', process_step) is False + + submit_job('job-test-run-job') + + assert run_job('job-test-run-job', process_step) is True + + +def test_run_jobs() -> None: + args_1 =\ + { + 'source_path': get_test_example_file('source.jpg'), + 'target_path': get_test_example_file('target-240p.mp4'), + 'output_path': get_test_output_file('output-1.mp4') + } + args_2 =\ + { + 'source_path': get_test_example_file('source.jpg'), + 'target_path': get_test_example_file('target-240p.mp4'), + 'output_path': get_test_output_file('output-2.mp4') + } + args_3 =\ + { + 'source_path': get_test_example_file('source.jpg'), + 'target_path': get_test_example_file('target-240p.jpg'), + 'output_path': get_test_output_file('output-3.jpg') + } + halt_on_error = True + + assert run_jobs(process_step, halt_on_error) is False + + create_job('job-test-run-jobs-1') + create_job('job-test-run-jobs-2') + add_step('job-test-run-jobs-1', args_1) + add_step('job-test-run-jobs-1', args_1) + add_step('job-test-run-jobs-2', args_2) + add_step('job-test-run-jobs-3', args_3) + + assert run_jobs(process_step, halt_on_error) is False + + submit_jobs(halt_on_error) + + assert run_jobs(process_step, halt_on_error) is True + + +def test_retry_job() -> None: + 
args_1 =\ + { + 'source_path': get_test_example_file('source.jpg'), + 'target_path': get_test_example_file('target-240p.mp4'), + 'output_path': get_test_output_file('output-1.mp4') + } + + assert retry_job('job-invalid', process_step) is False + + create_job('job-test-retry-job') + add_step('job-test-retry-job', args_1) + submit_job('job-test-retry-job') + + assert retry_job('job-test-retry-job', process_step) is False + + move_job_file('job-test-retry-job', 'failed') + + assert retry_job('job-test-retry-job', process_step) is True + + +def test_retry_jobs() -> None: + args_1 =\ + { + 'source_path': get_test_example_file('source.jpg'), + 'target_path': get_test_example_file('target-240p.mp4'), + 'output_path': get_test_output_file('output-1.mp4') + } + args_2 =\ + { + 'source_path': get_test_example_file('source.jpg'), + 'target_path': get_test_example_file('target-240p.mp4'), + 'output_path': get_test_output_file('output-2.mp4') + } + args_3 =\ + { + 'source_path': get_test_example_file('source.jpg'), + 'target_path': get_test_example_file('target-240p.jpg'), + 'output_path': get_test_output_file('output-3.jpg') + } + halt_on_error = True + + assert retry_jobs(process_step, halt_on_error) is False + + create_job('job-test-retry-jobs-1') + create_job('job-test-retry-jobs-2') + add_step('job-test-retry-jobs-1', args_1) + add_step('job-test-retry-jobs-1', args_1) + add_step('job-test-retry-jobs-2', args_2) + add_step('job-test-retry-jobs-3', args_3) + + assert retry_jobs(process_step, halt_on_error) is False + + move_job_file('job-test-retry-jobs-1', 'failed') + move_job_file('job-test-retry-jobs-2', 'failed') + + assert retry_jobs(process_step, halt_on_error) is True + + +def test_run_steps() -> None: + args_1 =\ + { + 'source_path': get_test_example_file('source.jpg'), + 'target_path': get_test_example_file('target-240p.mp4'), + 'output_path': get_test_output_file('output-1.mp4') + } + args_2 =\ + { + 'source_path': get_test_example_file('source.jpg'), + 'target_path': get_test_example_file('target-240p.mp4'), + 'output_path': get_test_output_file('output-2.mp4') + } + args_3 =\ + { + 'source_path': get_test_example_file('source.jpg'), + 'target_path': get_test_example_file('target-240p.jpg'), + 'output_path': get_test_output_file('output-3.jpg') + } + + assert run_steps('job-invalid', process_step) is False + + create_job('job-test-run-steps') + add_step('job-test-run-steps', args_1) + add_step('job-test-run-steps', args_1) + add_step('job-test-run-steps', args_2) + add_step('job-test-run-steps', args_3) + + assert run_steps('job-test-run-steps', process_step) is True + + +def test_finalize_steps() -> None: + args_1 =\ + { + 'source_path': get_test_example_file('source.jpg'), + 'target_path': get_test_example_file('target-240p.mp4'), + 'output_path': get_test_output_file('output-1.mp4') + } + args_2 =\ + { + 'source_path': get_test_example_file('source.jpg'), + 'target_path': get_test_example_file('target-240p.mp4'), + 'output_path': get_test_output_file('output-2.mp4') + } + args_3 =\ + { + 'source_path': get_test_example_file('source.jpg'), + 'target_path': get_test_example_file('target-240p.jpg'), + 'output_path': get_test_output_file('output-3.jpg') + } + + create_job('job-test-finalize-steps') + add_step('job-test-finalize-steps', args_1) + add_step('job-test-finalize-steps', args_1) + add_step('job-test-finalize-steps', args_2) + add_step('job-test-finalize-steps', args_3) + + copy_file(args_1.get('target_path'), get_test_output_file('output-1-job-test-finalize-steps-0.mp4')) + 
copy_file(args_1.get('target_path'), get_test_output_file('output-1-job-test-finalize-steps-1.mp4')) + copy_file(args_2.get('target_path'), get_test_output_file('output-2-job-test-finalize-steps-2.mp4')) + copy_file(args_3.get('target_path'), get_test_output_file('output-3-job-test-finalize-steps-3.jpg')) + + assert finalize_steps('job-test-finalize-steps') is True + assert is_test_output_file('output-1.mp4') is True + assert is_test_output_file('output-2.mp4') is True + assert is_test_output_file('output-3.jpg') is True + + +def test_collect_output_set() -> None: + args_1 =\ + { + 'source_path': get_test_example_file('source.jpg'), + 'target_path': get_test_example_file('target-240p.mp4'), + 'output_path': get_test_output_file('output-1.mp4') + } + args_2 =\ + { + 'source_path': get_test_example_file('source.jpg'), + 'target_path': get_test_example_file('target-240p.mp4'), + 'output_path': get_test_output_file('output-2.mp4') + } + args_3 =\ + { + 'source_path': get_test_example_file('source.jpg'), + 'target_path': get_test_example_file('target-240p.jpg'), + 'output_path': get_test_output_file('output-3.jpg') + } + + create_job('job-test-collect-output-set') + add_step('job-test-collect-output-set', args_1) + add_step('job-test-collect-output-set', args_1) + add_step('job-test-collect-output-set', args_2) + add_step('job-test-collect-output-set', args_3) + + output_set =\ + { + get_test_output_file('output-1.mp4'): + [ + get_test_output_file('output-1-job-test-collect-output-set-0.mp4'), + get_test_output_file('output-1-job-test-collect-output-set-1.mp4') + ], + get_test_output_file('output-2.mp4'): + [ + get_test_output_file('output-2-job-test-collect-output-set-2.mp4') + ], + get_test_output_file('output-3.jpg'): + [ + get_test_output_file('output-3-job-test-collect-output-set-3.jpg') + ] + } + + assert collect_output_set('job-test-collect-output-set') == output_set diff --git a/tests/test_json.py b/tests/test_json.py new file mode 100644 index 0000000000000000000000000000000000000000..c1d8a387456a2520189beb0cc6056584b087a099 --- /dev/null +++ b/tests/test_json.py @@ -0,0 +1,19 @@ +import tempfile + +from facefusion.json import read_json, write_json + + +def test_read_json() -> None: + _, json_path = tempfile.mkstemp(suffix = '.json') + + assert not read_json(json_path) + + write_json(json_path, {}) + + assert read_json(json_path) == {} + + +def test_write_json() -> None: + _, json_path = tempfile.mkstemp(suffix = '.json') + + assert write_json(json_path, {}) diff --git a/tests/test_memory.py b/tests/test_memory.py new file mode 100644 index 0000000000000000000000000000000000000000..e637ea10e97b9fedf01c3fff8540b3a74701aca1 --- /dev/null +++ b/tests/test_memory.py @@ -0,0 +1,8 @@ +from facefusion.common_helper import is_linux, is_macos +from facefusion.memory import limit_system_memory + + +def test_limit_system_memory() -> None: + assert limit_system_memory(4) is True + if is_linux() or is_macos(): + assert limit_system_memory(1024) is False diff --git a/tests/test_normalizer.py b/tests/test_normalizer.py new file mode 100644 index 0000000000000000000000000000000000000000..0673f64ffb72cae74540bae10cdcbbb2a201d4fb --- /dev/null +++ b/tests/test_normalizer.py @@ -0,0 +1,16 @@ +from facefusion.normalizer import normalize_fps, normalize_padding + + +def test_normalize_padding() -> None: + assert normalize_padding([ 0, 0, 0, 0 ]) == (0, 0, 0, 0) + assert normalize_padding([ 1 ]) == (1, 1, 1, 1) + assert normalize_padding([ 1, 2 ]) == (1, 2, 1, 2) + assert normalize_padding([ 1, 2, 3 ]) == 
(1, 2, 3, 2) + assert normalize_padding(None) is None + + +def test_normalize_fps() -> None: + assert normalize_fps(0.0) == 1.0 + assert normalize_fps(25.0) == 25.0 + assert normalize_fps(61.0) == 60.0 + assert normalize_fps(None) is None diff --git a/tests/test_process_manager.py b/tests/test_process_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..85e64645937fa85539dad2c97bf54d516e41ae29 --- /dev/null +++ b/tests/test_process_manager.py @@ -0,0 +1,22 @@ +from facefusion.process_manager import end, is_pending, is_processing, is_stopping, set_process_state, start, stop + + +def test_start() -> None: + set_process_state('pending') + start() + + assert is_processing() + + +def test_stop() -> None: + set_process_state('processing') + stop() + + assert is_stopping() + + +def test_end() -> None: + set_process_state('processing') + end() + + assert is_pending() diff --git a/tests/test_program_helper.py b/tests/test_program_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..92b64fb29e57d3877bf3c1bab9071e75ab96ba2c --- /dev/null +++ b/tests/test_program_helper.py @@ -0,0 +1,40 @@ +from argparse import ArgumentParser + +import pytest + +from facefusion.program_helper import find_argument_group, validate_actions + + +def test_find_argument_group() -> None: + program = ArgumentParser() + program.add_argument_group('test-1') + program.add_argument_group('test-2') + + assert find_argument_group(program, 'test-1') + assert find_argument_group(program, 'test-2') + assert find_argument_group(program, 'invalid') is None + + +@pytest.mark.skip() +def test_validate_args() -> None: + pass + + +def test_validate_actions() -> None: + program = ArgumentParser() + program.add_argument('--test-1', default = 'test_1', choices = [ 'test_1', 'test_2' ]) + program.add_argument('--test-2', default = 'test_2', choices= [ 'test_1', 'test_2' ], nargs = '+') + + assert validate_actions(program) is True + + args =\ + { + 'test_1': 'test_2', + 'test_2': [ 'test_1', 'test_3' ] + } + + for action in program._actions: + if action.dest in args: + action.default = args[action.dest] + + assert validate_actions(program) is False diff --git a/tests/test_state_manager.py b/tests/test_state_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..5d3da066db35280b0b95c0796097d6a78426ed5f --- /dev/null +++ b/tests/test_state_manager.py @@ -0,0 +1,35 @@ +from typing import Union + +import pytest + +from facefusion.processors.types import ProcessorState +from facefusion.state_manager import STATE_SET, get_item, init_item, set_item +from facefusion.types import AppContext, State + + +def get_state(app_context : AppContext) -> Union[State, ProcessorState]: + return STATE_SET.get(app_context) + + +def clear_state(app_context : AppContext) -> None: + STATE_SET[app_context] = {} #type:ignore[typeddict-item] + + +@pytest.fixture(scope = 'function', autouse = True) +def before_each() -> None: + clear_state('cli') + clear_state('ui') + + +def test_init_item() -> None: + init_item('video_memory_strategy', 'tolerant') + + assert get_state('cli').get('video_memory_strategy') == 'tolerant' + assert get_state('ui').get('video_memory_strategy') == 'tolerant' + + +def test_get_item_and_set_item() -> None: + set_item('video_memory_strategy', 'tolerant') + + assert get_item('video_memory_strategy') == 'tolerant' + assert get_state('ui').get('video_memory_strategy') is None diff --git a/tests/test_temp_helper.py b/tests/test_temp_helper.py new file mode 100644 index 
0000000000000000000000000000000000000000..6903d2ca737040425ec6470ee12e28e91568e3e8 --- /dev/null +++ b/tests/test_temp_helper.py @@ -0,0 +1,34 @@ +import os.path +import tempfile + +import pytest + +from facefusion import state_manager +from facefusion.download import conditional_download +from facefusion.temp_helper import get_temp_directory_path, get_temp_file_path, get_temp_frames_pattern +from .helper import get_test_example_file, get_test_examples_directory + + +@pytest.fixture(scope = 'module', autouse = True) +def before_all() -> None: + conditional_download(get_test_examples_directory(), + [ + 'https://github.com/facefusion/facefusion-assets/releases/download/examples-3.0.0/target-240p.mp4' + ]) + state_manager.init_item('temp_path', tempfile.gettempdir()) + state_manager.init_item('temp_frame_format', 'png') + + +def test_get_temp_file_path() -> None: + temp_directory = tempfile.gettempdir() + assert get_temp_file_path(get_test_example_file('target-240p.mp4')) == os.path.join(temp_directory, 'facefusion', 'target-240p', 'temp.mp4') + + +def test_get_temp_directory_path() -> None: + temp_directory = tempfile.gettempdir() + assert get_temp_directory_path(get_test_example_file('target-240p.mp4')) == os.path.join(temp_directory, 'facefusion', 'target-240p') + + +def test_get_temp_frames_pattern() -> None: + temp_directory = tempfile.gettempdir() + assert get_temp_frames_pattern(get_test_example_file('target-240p.mp4'), '%04d') == os.path.join(temp_directory, 'facefusion', 'target-240p', '%04d.png') diff --git a/tests/test_vision.py b/tests/test_vision.py new file mode 100644 index 0000000000000000000000000000000000000000..01463a431ed8b7076d7d792439bea91a2da72df9 --- /dev/null +++ b/tests/test_vision.py @@ -0,0 +1,179 @@ +import subprocess + +import pytest + +from facefusion.download import conditional_download +from facefusion.vision import calc_histogram_difference, count_trim_frame_total, count_video_frame_total, create_image_resolutions, create_video_resolutions, detect_image_resolution, detect_video_duration, detect_video_fps, detect_video_resolution, match_frame_color, normalize_resolution, pack_resolution, predict_video_frame_total, read_image, read_video_frame, restrict_image_resolution, restrict_trim_frame, restrict_video_fps, restrict_video_resolution, unpack_resolution, write_image +from .helper import get_test_example_file, get_test_examples_directory, get_test_output_file, prepare_test_output_directory + + +@pytest.fixture(scope = 'module', autouse = True) +def before_all() -> None: + conditional_download(get_test_examples_directory(), + [ + 'https://github.com/facefusion/facefusion-assets/releases/download/examples-3.0.0/source.jpg', + 'https://github.com/facefusion/facefusion-assets/releases/download/examples-3.0.0/target-240p.mp4', + 'https://github.com/facefusion/facefusion-assets/releases/download/examples-3.0.0/target-1080p.mp4' + ]) + subprocess.run([ 'ffmpeg', '-i', get_test_example_file('target-240p.mp4'), '-vframes', '1', get_test_example_file('target-240p.jpg') ]) + subprocess.run([ 'ffmpeg', '-i', get_test_example_file('target-240p.mp4'), '-vframes', '1', get_test_example_file('目标-240p.webp') ]) + subprocess.run([ 'ffmpeg', '-i', get_test_example_file('target-1080p.mp4'), '-vframes', '1', get_test_example_file('target-1080p.jpg') ]) + subprocess.run([ 'ffmpeg', '-i', get_test_example_file('target-240p.mp4'), '-vframes', '1', '-vf', 'hue=s=0', get_test_example_file('target-240p-0sat.jpg') ]) + subprocess.run([ 'ffmpeg', '-i', 
get_test_example_file('target-240p.mp4'), '-vframes', '1', '-vf', 'transpose=0', get_test_example_file('target-240p-90deg.jpg') ]) + subprocess.run([ 'ffmpeg', '-i', get_test_example_file('target-1080p.mp4'), '-vframes', '1', '-vf', 'transpose=0', get_test_example_file('target-1080p-90deg.jpg') ]) + subprocess.run([ 'ffmpeg', '-i', get_test_example_file('target-240p.mp4'), '-vf', 'fps=25', get_test_example_file('target-240p-25fps.mp4') ]) + subprocess.run([ 'ffmpeg', '-i', get_test_example_file('target-240p.mp4'), '-vf', 'fps=30', get_test_example_file('target-240p-30fps.mp4') ]) + subprocess.run([ 'ffmpeg', '-i', get_test_example_file('target-240p.mp4'), '-vf', 'fps=60', get_test_example_file('target-240p-60fps.mp4') ]) + subprocess.run([ 'ffmpeg', '-i', get_test_example_file('target-240p.mp4'), '-vf', 'transpose=0', get_test_example_file('target-240p-90deg.mp4') ]) + subprocess.run([ 'ffmpeg', '-i', get_test_example_file('target-1080p.mp4'), '-vf', 'transpose=0', get_test_example_file('target-1080p-90deg.mp4') ]) + + +@pytest.fixture(scope = 'function', autouse = True) +def before_each() -> None: + prepare_test_output_directory() + + +def test_read_image() -> None: + assert read_image(get_test_example_file('target-240p.jpg')).shape == (226, 426, 3) + assert read_image(get_test_example_file('目标-240p.webp')).shape == (226, 426, 3) + assert read_image('invalid') is None + + +def test_write_image() -> None: + vision_frame = read_image(get_test_example_file('target-240p.jpg')) + + assert write_image(get_test_output_file('target-240p.jpg'), vision_frame) is True + assert write_image(get_test_output_file('目标-240p.webp'), vision_frame) is True + + +def test_detect_image_resolution() -> None: + assert detect_image_resolution(get_test_example_file('target-240p.jpg')) == (426, 226) + assert detect_image_resolution(get_test_example_file('target-240p-90deg.jpg')) == (226, 426) + assert detect_image_resolution(get_test_example_file('target-1080p.jpg')) == (2048, 1080) + assert detect_image_resolution(get_test_example_file('target-1080p-90deg.jpg')) == (1080, 2048) + assert detect_image_resolution('invalid') is None + + +def test_restrict_image_resolution() -> None: + assert restrict_image_resolution(get_test_example_file('target-1080p.jpg'), (426, 226)) == (426, 226) + assert restrict_image_resolution(get_test_example_file('target-1080p.jpg'), (2048, 1080)) == (2048, 1080) + assert restrict_image_resolution(get_test_example_file('target-1080p.jpg'), (4096, 2160)) == (2048, 1080) + + +def test_create_image_resolutions() -> None: + assert create_image_resolutions((426, 226)) == [ '106x56', '212x112', '320x170', '426x226', '640x340', '852x452', '1064x564', '1278x678', '1492x792', '1704x904' ] + assert create_image_resolutions((226, 426)) == [ '56x106', '112x212', '170x320', '226x426', '340x640', '452x852', '564x1064', '678x1278', '792x1492', '904x1704' ] + assert create_image_resolutions((2048, 1080)) == [ '512x270', '1024x540', '1536x810', '2048x1080', '3072x1620', '4096x2160', '5120x2700', '6144x3240', '7168x3780', '8192x4320' ] + assert create_image_resolutions((1080, 2048)) == [ '270x512', '540x1024', '810x1536', '1080x2048', '1620x3072', '2160x4096', '2700x5120', '3240x6144', '3780x7168', '4320x8192' ] + assert create_image_resolutions(None) == [] + + +def test_read_video_frame() -> None: + assert hasattr(read_video_frame(get_test_example_file('target-240p-25fps.mp4')), '__array_interface__') + assert read_video_frame('invalid') is None + + +def test_count_video_frame_total() -> None: + assert 
count_video_frame_total(get_test_example_file('target-240p-25fps.mp4')) == 270 + assert count_video_frame_total(get_test_example_file('target-240p-30fps.mp4')) == 324 + assert count_video_frame_total(get_test_example_file('target-240p-60fps.mp4')) == 648 + assert count_video_frame_total('invalid') == 0 + + +def test_predict_video_frame_total() -> None: + assert predict_video_frame_total(get_test_example_file('target-240p-25fps.mp4'), 12.5, 0, 100) == 50 + assert predict_video_frame_total(get_test_example_file('target-240p-25fps.mp4'), 25, 0, 100) == 100 + assert predict_video_frame_total(get_test_example_file('target-240p-25fps.mp4'), 25, 0, 200) == 200 + assert predict_video_frame_total('invalid', 25, 0, 100) == 0 + + +def test_detect_video_fps() -> None: + assert detect_video_fps(get_test_example_file('target-240p-25fps.mp4')) == 25.0 + assert detect_video_fps(get_test_example_file('target-240p-30fps.mp4')) == 30.0 + assert detect_video_fps(get_test_example_file('target-240p-60fps.mp4')) == 60.0 + assert detect_video_fps('invalid') is None + + +def test_restrict_video_fps() -> None: + assert restrict_video_fps(get_test_example_file('target-1080p.mp4'), 20.0) == 20.0 + assert restrict_video_fps(get_test_example_file('target-1080p.mp4'), 25.0) == 25.0 + assert restrict_video_fps(get_test_example_file('target-1080p.mp4'), 60.0) == 25.0 + + +def test_detect_video_duration() -> None: + assert detect_video_duration(get_test_example_file('target-240p.mp4')) == 10.8 + assert detect_video_duration('invalid') == 0 + + +def test_count_trim_frame_total() -> None: + assert count_trim_frame_total(get_test_example_file('target-240p.mp4'), 0, 200) == 200 + assert count_trim_frame_total(get_test_example_file('target-240p.mp4'), 70, 270) == 200 + assert count_trim_frame_total(get_test_example_file('target-240p.mp4'), -10, None) == 270 + assert count_trim_frame_total(get_test_example_file('target-240p.mp4'), None, -10) == 0 + assert count_trim_frame_total(get_test_example_file('target-240p.mp4'), 280, None) == 0 + assert count_trim_frame_total(get_test_example_file('target-240p.mp4'), None, 280) == 270 + assert count_trim_frame_total(get_test_example_file('target-240p.mp4'), None, None) == 270 + + +def test_restrict_trim_frame() -> None: + assert restrict_trim_frame(get_test_example_file('target-240p.mp4'), 0, 200) == (0, 200) + assert restrict_trim_frame(get_test_example_file('target-240p.mp4'), 70, 270) == (70, 270) + assert restrict_trim_frame(get_test_example_file('target-240p.mp4'), -10, None) == (0, 270) + assert restrict_trim_frame(get_test_example_file('target-240p.mp4'), None, -10) == (0, 0) + assert restrict_trim_frame(get_test_example_file('target-240p.mp4'), 280, None) == (270, 270) + assert restrict_trim_frame(get_test_example_file('target-240p.mp4'), None, 280) == (0, 270) + assert restrict_trim_frame(get_test_example_file('target-240p.mp4'), None, None) == (0, 270) + + +def test_detect_video_resolution() -> None: + assert detect_video_resolution(get_test_example_file('target-240p.mp4')) == (426, 226) + assert detect_video_resolution(get_test_example_file('target-240p-90deg.mp4')) == (226, 426) + assert detect_video_resolution(get_test_example_file('target-1080p.mp4')) == (2048, 1080) + assert detect_video_resolution(get_test_example_file('target-1080p-90deg.mp4')) == (1080, 2048) + assert detect_video_resolution('invalid') is None + + +def test_restrict_video_resolution() -> None: + assert restrict_video_resolution(get_test_example_file('target-1080p.mp4'), (426, 226)) == (426, 226) + assert 
restrict_video_resolution(get_test_example_file('target-1080p.mp4'), (2048, 1080)) == (2048, 1080) + assert restrict_video_resolution(get_test_example_file('target-1080p.mp4'), (4096, 2160)) == (2048, 1080) + + +def test_create_video_resolutions() -> None: + assert create_video_resolutions((426, 226)) == [ '426x226', '452x240', '678x360', '904x480', '1018x540', '1358x720', '2036x1080', '2714x1440', '4072x2160', '8144x4320' ] + assert create_video_resolutions((226, 426)) == [ '226x426', '240x452', '360x678', '480x904', '540x1018', '720x1358', '1080x2036', '1440x2714', '2160x4072', '4320x8144' ] + assert create_video_resolutions((2048, 1080)) == [ '456x240', '682x360', '910x480', '1024x540', '1366x720', '2048x1080', '2730x1440', '4096x2160', '8192x4320' ] + assert create_video_resolutions((1080, 2048)) == [ '240x456', '360x682', '480x910', '540x1024', '720x1366', '1080x2048', '1440x2730', '2160x4096', '4320x8192' ] + assert create_video_resolutions(None) == [] + + +def test_normalize_resolution() -> None: + assert normalize_resolution((2.5, 2.5)) == (2, 2) + assert normalize_resolution((3.0, 3.0)) == (4, 4) + assert normalize_resolution((6.5, 6.5)) == (6, 6) + + +def test_pack_resolution() -> None: + assert pack_resolution((1, 1)) == '0x0' + assert pack_resolution((2, 2)) == '2x2' + + +def test_unpack_resolution() -> None: + assert unpack_resolution('0x0') == (0, 0) + assert unpack_resolution('2x2') == (2, 2) + + +def test_calc_histogram_difference() -> None: + source_vision_frame = read_image(get_test_example_file('target-240p.jpg')) + target_vision_frame = read_image(get_test_example_file('target-240p-0sat.jpg')) + + assert calc_histogram_difference(source_vision_frame, source_vision_frame) == 1.0 + assert calc_histogram_difference(source_vision_frame, target_vision_frame) < 0.5 + + +def test_match_frame_color() -> None: + source_vision_frame = read_image(get_test_example_file('target-240p.jpg')) + target_vision_frame = read_image(get_test_example_file('target-240p-0sat.jpg')) + output_vision_frame = match_frame_color(source_vision_frame, target_vision_frame) + + assert calc_histogram_difference(source_vision_frame, output_vision_frame) > 0.5 diff --git a/tests/test_wording.py b/tests/test_wording.py new file mode 100644 index 0000000000000000000000000000000000000000..5d987f9e155216f0156fb1d264efbef36505b69e --- /dev/null +++ b/tests/test_wording.py @@ -0,0 +1,7 @@ +from facefusion import wording + + +def test_get() -> None: + assert wording.get('python_not_supported') + assert wording.get('help.source_paths') + assert wording.get('invalid') is None