Hyoung-Kyu Song committed
Commit 16c8067 · 0 Parent(s):

Reinitialize demo with published github repository. With Gradio 4.x

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full set.
Files changed (50)
  1. .gitattributes +35 -0
  2. CONTRIBUTING.md +23 -0
  3. Dockerfile +9 -0
  4. README.md +84 -0
  5. app.py +105 -0
  6. app.sh +4 -0
  7. checkpoints/.gitkeep +0 -0
  8. config/__init__.py +5 -0
  9. config/gradio.yaml +14 -0
  10. config/nota_wav2lip.yaml +44 -0
  11. data/.gitkeep +0 -0
  12. docker-compose.yml +11 -0
  13. docs/assets/fig5.png +0 -0
  14. docs/description.md +22 -0
  15. docs/footer.md +5 -0
  16. docs/header.md +10 -0
  17. docs/main.css +4 -0
  18. download.py +44 -0
  19. download.sh +7 -0
  20. face_detection/README.md +1 -0
  21. face_detection/__init__.py +7 -0
  22. face_detection/api.py +79 -0
  23. face_detection/detection/__init__.py +1 -0
  24. face_detection/detection/core.py +130 -0
  25. face_detection/detection/sfd/__init__.py +1 -0
  26. face_detection/detection/sfd/bbox.py +129 -0
  27. face_detection/detection/sfd/detect.py +112 -0
  28. face_detection/detection/sfd/net_s3fd.py +129 -0
  29. face_detection/detection/sfd/sfd_detector.py +59 -0
  30. face_detection/models.py +261 -0
  31. face_detection/utils.py +313 -0
  32. inference.py +82 -0
  33. inference.sh +15 -0
  34. nota_wav2lip/__init__.py +2 -0
  35. nota_wav2lip/audio.py +135 -0
  36. nota_wav2lip/demo.py +91 -0
  37. nota_wav2lip/gradio.py +91 -0
  38. nota_wav2lip/inference.py +111 -0
  39. nota_wav2lip/models/__init__.py +3 -0
  40. nota_wav2lip/models/base.py +55 -0
  41. nota_wav2lip/models/conv.py +34 -0
  42. nota_wav2lip/models/util.py +32 -0
  43. nota_wav2lip/models/wav2lip.py +85 -0
  44. nota_wav2lip/models/wav2lip_compressed.py +72 -0
  45. nota_wav2lip/preprocess/__init__.py +2 -0
  46. nota_wav2lip/preprocess/core.py +98 -0
  47. nota_wav2lip/preprocess/ffmpeg.py +5 -0
  48. nota_wav2lip/preprocess/lrs3_download.py +259 -0
  49. nota_wav2lip/util.py +5 -0
  50. nota_wav2lip/video.py +68 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
CONTRIBUTING.md ADDED
@@ -0,0 +1,23 @@
+ # Contributing to this repository
+
+ ## Install linter
+
+ First, install the `ruff` package to verify that your changes pass all formatting checks.
+
+ ```bash
+ pip install ruff==0.0.287
+ ```
+
+ ### Apply linter before PR
+
+ Please run the ruff check with the following command:
+
+ ```bash
+ ruff check .
+ ```
+
+ ### Auto-fix fixable errors
+
+ ```bash
+ ruff check . --fix
+ ```
Dockerfile ADDED
@@ -0,0 +1,9 @@
+ FROM nvcr.io/nvidia/pytorch:22.03-py3
+
+ ARG DEBIAN_FRONTEND=noninteractive
+ RUN apt-get update
+ RUN apt-get install ffmpeg libsm6 libxext6 tmux git -y
+
+ WORKDIR /workspace
+ COPY requirements.txt .
+ RUN pip install --no-cache -r requirements.txt
README.md ADDED
@@ -0,0 +1,84 @@
+ ---
+ title: Compressed Wav2Lip
+ emoji: 🌟
+ colorFrom: indigo
+ colorTo: pink
+ sdk: gradio
+ sdk_version: 4.13.0
+ app_file: app.py
+ pinned: true
+ license: apache-2.0
+ ---
+
+ # 28× Compressed Wav2Lip by Nota AI
+
+ Official codebase for [**Accelerating Speech-Driven Talking Face Generation with 28× Compressed Wav2Lip**](https://arxiv.org/abs/2304.00471).
+
+ - Presented at the [ICCV'23 Demo](https://iccv2023.thecvf.com/demos-111.php) Track; the [On-Device Intelligence Workshop](https://sites.google.com/g.harvard.edu/on-device-workshop-23/home) @ MLSys'23; and as an [NVIDIA GTC 2023](https://www.nvidia.com/en-us/on-demand/search/?facet.mimetype[]=event%20session&layout=list&page=1&q=52409&sort=relevance&sortDir=desc) Poster.
+
+
+ ## Installation
+ #### Docker (recommended)
+ ```bash
+ git clone https://github.com/Nota-NetsPresso/nota-wav2lip.git
+ cd nota-wav2lip
+ docker compose run --service-ports --name nota-compressed-wav2lip compressed-wav2lip bash
+ ```
+
+ #### Conda
+ <details>
+ <summary>Click</summary>
+
+ ```bash
+ git clone https://github.com/Nota-NetsPresso/nota-wav2lip.git
+ cd nota-wav2lip
+ apt-get update
+ apt-get install ffmpeg libsm6 libxext6 tmux git -y
+ conda create -n nota-wav2lip python=3.9
+ conda activate nota-wav2lip
+ pip install -r requirements.txt
+ ```
+ </details>
+
+ ## Gradio Demo
+ Use the script below to run the [nota-ai/compressed-wav2lip demo](https://huggingface.co/spaces/nota-ai/compressed-wav2lip). The models and sample data are downloaded automatically.
+
+ ```bash
+ bash app.sh
+ ```
+
+ ## Inference
+ (1) Download the YouTube videos listed in the LRS3-TED label text files and preprocess them:
+ - Download `lrs3_v0.4_txt.zip` from [this link](https://mmai.io/datasets/lip_reading/).
+ - Unzip the file into the folder structure `./data/lrs3_v0.4_txt/lrs3_v0.4/test`.
+ - Run `bash download.sh`.
+ - Run `bash preprocess.sh`.
+
+ (2) Run the script to compare the original Wav2Lip with Nota's compressed version.
+
+ ```bash
+ bash inference.sh
+ ```
+
+ ## License
+ - All rights related to this repository and the compressed models are reserved by Nota Inc.
+ - The intended use is strictly limited to research and non-commercial projects.
+
+ ## Contact
+ - To obtain the compression code and assistance, kindly contact Nota AI ([email protected]); these are provided as part of our business solutions.
+ - For Q&A about this repository, use this board: [Nota-NetsPresso/discussions](https://github.com/orgs/Nota-NetsPresso/discussions)
+
+ ## Acknowledgment
+ - [NVIDIA Applied Research Accelerator Program](https://www.nvidia.com/en-us/industries/higher-education-research/applied-research-program/) for supporting this research.
+ - [Wav2Lip](https://github.com/Rudrabha/Wav2Lip) and [LRS3-TED](https://www.robots.ox.ac.uk/~vgg/data/lip_reading/) for facilitating the development of the original Wav2Lip.
+
+ ## Citation
+ ```bibtex
+ @article{kim2023unified,
+   title={A Unified Compression Framework for Efficient Speech-Driven Talking-Face Generation},
+   author={Kim, Bo-Kyeong and Kang, Jaemin and Seo, Daeun and Park, Hancheol and Choi, Shinkook and Song, Hyoung-Kyu and Kim, Hyungshin and Lim, Sungsu},
+   journal={MLSys Workshop on On-Device Intelligence (ODIW)},
+   year={2023},
+   url={https://arxiv.org/abs/2304.00471}
+ }
+ ```
app.py ADDED
@@ -0,0 +1,105 @@
1
+ import os
2
+ import subprocess
3
+ from pathlib import Path
4
+
5
+ import gradio as gr
6
+
7
+ from config import hparams as hp
8
+ from config import hparams_gradio as hp_gradio
9
+ from nota_wav2lip import Wav2LipModelComparisonGradio
10
+
11
+ # device = 'cuda' if torch.cuda.is_available() else 'cpu'
12
+ device = hp_gradio.device
13
+ print(f'Using {device} for inference.')
14
+ video_label_dict = hp_gradio.sample.video
15
+ audio_label_dict = hp_gradio.sample.audio
16
+
17
+ LRS_ORIGINAL_URL = os.getenv('LRS_ORIGINAL_URL', None)
18
+ LRS_COMPRESSED_URL = os.getenv('LRS_COMPRESSED_URL', None)
19
+ LRS_INFERENCE_SAMPLE = os.getenv('LRS_INFERENCE_SAMPLE', None)
20
+
21
+ if not Path(hp.inference.model.wav2lip.checkpoint).exists() and LRS_ORIGINAL_URL is not None:
22
+ subprocess.call(f"wget --no-check-certificate -O {hp.inference.model.wav2lip.checkpoint} {LRS_ORIGINAL_URL}", shell=True)
23
+ if not Path(hp.inference.model.nota_wav2lip.checkpoint).exists() and LRS_COMPRESSED_URL is not None:
24
+ subprocess.call(f"wget --no-check-certificate -O {hp.inference.model.nota_wav2lip.checkpoint} {LRS_COMPRESSED_URL}", shell=True)
25
+
26
+ path_inference_sample = "sample.tar.gz"
27
+ if not Path(path_inference_sample).exists() and LRS_INFERENCE_SAMPLE is not None:
28
+ subprocess.call(f"wget --no-check-certificate -O {path_inference_sample} {LRS_INFERENCE_SAMPLE}", shell=True)
29
+ subprocess.call(f"tar -zxvf {path_inference_sample}", shell=True)
30
+
31
+
32
+ if __name__ == "__main__":
33
+
34
+ servicer = Wav2LipModelComparisonGradio(
35
+ device=device,
36
+ video_label_dict=video_label_dict,
37
+ audio_label_list=audio_label_dict,
38
+ default_video='v1',
39
+ default_audio='a1'
40
+ )
41
+
42
+ for video_name in sorted(video_label_dict):
43
+ video_stem = Path(video_label_dict[video_name])
44
+ servicer.update_video(video_stem, video_stem.with_suffix('.json'),
45
+ name=video_name)
46
+
47
+ for audio_name in sorted(audio_label_dict):
48
+ audio_path = Path(audio_label_dict[audio_name])
49
+ servicer.update_audio(audio_path, name=audio_name)
50
+
51
+ with gr.Blocks(theme='nota-ai/theme', css=Path('docs/main.css').read_text()) as demo:
52
+ gr.Markdown(Path('docs/header.md').read_text())
53
+ gr.Markdown(Path('docs/description.md').read_text())
54
+ with gr.Row():
55
+ with gr.Column(variant='panel'):
56
+
57
+ gr.Markdown('## Select input video and audio', sanitize_html=False)
58
+ # Define samples
59
+ sample_video = gr.Video(interactive=False, label="Input Video")
60
+ sample_audio = gr.Audio(interactive=False, label="Input Audio")
61
+
62
+ # Define radio inputs
63
+ video_selection = gr.components.Radio(video_label_dict,
64
+ type='value', label="Select an input video:")
65
+ audio_selection = gr.components.Radio(audio_label_dict,
66
+ type='value', label="Select an input audio:")
67
+ # Define button inputs
68
+ with gr.Row(equal_height=True):
69
+ generate_original_button = gr.Button(value="Generate with Original Model", variant="primary")
70
+ generate_compressed_button = gr.Button(value="Generate with Compressed Model", variant="primary")
71
+ with gr.Column(variant='panel'):
72
+ # Define original model output components
73
+ gr.Markdown('## Original Wav2Lip')
74
+ original_model_output = gr.Video(label="Original Model", interactive=False)
75
+ with gr.Column():
76
+ with gr.Row(equal_height=True):
77
+ original_model_inference_time = gr.Textbox(value="", label="Total inference time (sec)")
78
+ original_model_fps = gr.Textbox(value="", label="FPS")
79
+ original_model_params = gr.Textbox(value=servicer.params['wav2lip'], label="# Parameters")
80
+ with gr.Column(variant='panel'):
81
+ # Define compressed model output components
82
+ gr.Markdown('## Compressed Wav2Lip (Ours)')
83
+ compressed_model_output = gr.Video(label="Compressed Model", interactive=False)
84
+ with gr.Column():
85
+ with gr.Row(equal_height=True):
86
+ compressed_model_inference_time = gr.Textbox(value="", label="Total inference time (sec)")
87
+ compressed_model_fps = gr.Textbox(value="", label="FPS")
88
+ compressed_model_params = gr.Textbox(value=servicer.params['nota_wav2lip'], label="# Parameters")
89
+
90
+ # Switch video and audio samples when selecting the radio button
91
+ video_selection.change(fn=servicer.switch_video_samples, inputs=video_selection, outputs=sample_video)
92
+ audio_selection.change(fn=servicer.switch_audio_samples, inputs=audio_selection, outputs=sample_audio)
93
+
94
+ # Click the generate button for original model
95
+ generate_original_button.click(servicer.generate_original_model,
96
+ inputs=[video_selection, audio_selection],
97
+ outputs=[original_model_output, original_model_inference_time, original_model_fps])
98
+ # Click the generate button for compressed model
99
+ generate_compressed_button.click(servicer.generate_compressed_model,
100
+ inputs=[video_selection, audio_selection],
101
+ outputs=[compressed_model_output, compressed_model_inference_time, compressed_model_fps])
102
+
103
+ gr.Markdown(Path('docs/footer.md').read_text())
104
+
105
+ demo.queue().launch()
app.sh ADDED
@@ -0,0 +1,4 @@
+ export LRS_ORIGINAL_URL=https://netspresso-huggingface-demo-checkpoint.s3.us-east-2.amazonaws.com/compressed-wav2lip/lrs3-wav2lip.pth && \
+ export LRS_COMPRESSED_URL=https://netspresso-huggingface-demo-checkpoint.s3.us-east-2.amazonaws.com/compressed-wav2lip/lrs3-nota-wav2lip.pth && \
+ export LRS_INFERENCE_SAMPLE=https://netspresso-huggingface-demo-checkpoint.s3.us-east-2.amazonaws.com/data/compressed-wav2lip-inference/sample.tar.gz && \
+ python app.py
checkpoints/.gitkeep ADDED
File without changes
config/__init__.py ADDED
@@ -0,0 +1,5 @@
+ from omegaconf import DictConfig, OmegaConf
+
+ hparams: DictConfig = OmegaConf.load("config/nota_wav2lip.yaml")
+
+ hparams_gradio: DictConfig = OmegaConf.load("config/gradio.yaml")
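Downstream code in this commit reads these objects through OmegaConf attribute paths. A minimal usage sketch (not part of the commit), assuming the working directory is the repository root so the relative YAML paths above resolve:

```python
# Minimal sketch mirroring how app.py consumes the loaded configs in this commit.
# Run from the repository root so the relative paths in config/__init__.py resolve.
from config import hparams as hp
from config import hparams_gradio as hp_gradio

print(hp_gradio.device)                        # "cpu" (config/gradio.yaml)
print(hp_gradio.sample.video.v1)               # "sample/2145_orig"
print(hp.inference.model.wav2lip.checkpoint)   # "checkpoints/lrs3-wav2lip.pth"
print(hp.audio.sample_rate)                    # 16000 (config/nota_wav2lip.yaml)
```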
config/gradio.yaml ADDED
@@ -0,0 +1,14 @@
+ device: cpu
+ sample:
+   video:
+     v1: "sample/2145_orig"
+     v2: "sample/2942_orig"
+     v3: "sample/4598_orig"
+     v4: "sample/4653_orig"
+     v5: "sample/13692_orig"
+   audio:
+     a1: "sample/1673_orig.wav"
+     a2: "sample/9948_orig.wav"
+     a3: "sample/11028_orig.wav"
+     a4: "sample/12640_orig.wav"
+     a5: "sample/5592_orig.wav"
config/nota_wav2lip.yaml ADDED
@@ -0,0 +1,44 @@
+
+ inference:
+   batch_size: 1
+   frame:
+     h: 224
+     w: 224
+   model:
+     wav2lip:
+       checkpoint: "checkpoints/lrs3-wav2lip.pth"
+     nota_wav2lip:
+       checkpoint: "checkpoints/lrs3-nota-wav2lip.pth"
+
+ audio:
+   num_mels: 80
+   rescale: True
+   rescaling_max: 0.9
+
+   use_lws: False
+
+   n_fft: 800 # Extra window size is filled with 0 paddings to match this parameter
+   hop_size: 200 # For 16000Hz, 200 : 12.5 ms (0.0125 * sample_rate)
+   win_size: 800 # For 16000Hz, 800 : 50 ms (If None, win_size : n_fft) (0.05 * sample_rate)
+   sample_rate: 16000 # 16000Hz (corresponding to librispeech) (sox --i <filename>)
+
+   frame_shift_ms: ~
+
+   signal_normalization: True
+   allow_clipping_in_normalization: True
+   symmetric_mels: True
+   max_abs_value: 4.
+   preemphasize: True
+   preemphasis: 0.97
+
+   # Limits
+   min_level_db: -100
+   ref_level_db: 20
+   fmin: 55
+   fmax: 7600
+
+ face:
+   video_fps: 25
+   img_size: 96
+   mel_step_size: 16
+
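For orientation, a hedged sketch of how the `audio` parameters above map onto a standard mel-spectrogram computation. `librosa` is used here purely for illustration; the repository's own implementation lives in nota_wav2lip/audio.py and applies additional pre-emphasis and dB normalization governed by the remaining keys.

```python
# Illustration only: the `audio` section above expressed as a librosa mel-spectrogram call.
import librosa

wav, sr = librosa.load("sample/1673_orig.wav", sr=16000)   # sample_rate: 16000
mel = librosa.feature.melspectrogram(
    y=wav,
    sr=sr,
    n_fft=800,        # n_fft
    hop_length=200,   # hop_size (12.5 ms at 16 kHz)
    win_length=800,   # win_size (50 ms at 16 kHz)
    n_mels=80,        # num_mels
    fmin=55,          # fmin
    fmax=7600,        # fmax
)
mel_db = librosa.power_to_db(mel)  # rough analogue of the dB scaling governed by min/ref_level_db
print(mel_db.shape)                # (80, num_frames)
```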
data/.gitkeep ADDED
File without changes
docker-compose.yml ADDED
@@ -0,0 +1,11 @@
+ version: "3.9"
+ services:
+   compressed-wav2lip:
+     image: nota-compressed-wav2lip:dev
+     build: ./
+     container_name: nota-compressed-wav2lip
+     ipc: host
+     ports:
+       - "7860:7860"
+     volumes:
+       - ./:/workspace
docs/assets/fig5.png ADDED
docs/description.md ADDED
@@ -0,0 +1,22 @@
+ This demo showcases a lightweight model for speech-driven talking-face synthesis, a **28× Compressed Wav2Lip**. The key features of our approach are:
+ - a compact generator built by removing the residual blocks and reducing the channel width of Wav2Lip.
+ - knowledge distillation to effectively train the small-capacity generator without adversarial learning.
+ - selective quantization to accelerate inference on edge GPUs without noticeable performance degradation.
+
+ <!-- To demonstrate the efficacy of our approach, we provide a latency comparison of different precisions on NVIDIA Jetson edge GPUs in Figure 5. Our approach achieves a remarkable 8× to 17× speedup with FP16 precision, and a 19× speedup on Xavier NX with mixed precision. -->
+ The figure below shows a latency comparison at different precisions on NVIDIA Jetson edge GPUs, highlighting an 8× to 17× speedup at FP16 and a 19× speedup on Xavier NX at mixed precision.
+
+ <center>
+ <img alt="compressed-wav2lip-performance" src="https://huggingface.co/spaces/nota-ai/compressed-wav2lip/resolve/2b86e2aa4921d3422f0769ed02dce9898d1e0470/docs/assets/fig5.png" width="70%" />
+ </center>
+
+ <br/>
+
+ The generation speed may vary depending on network traffic. Nevertheless, our compressed Wav2Lip _consistently_ delivers faster inference than the original model while maintaining similar visual quality. Unlike the paper, this demo measures **total processing time** and **FPS** across loading the preprocessed video and audio, generating frames with the model, and merging the lip-synced facial images back into the original video.
+
+ <br/>
+
+
+ ### Notice
+ - This work was presented in the [Demo] [**ICCV 2023 Demo Track**](https://iccv2023.thecvf.com/demos-111.php); as a [Paper](https://arxiv.org/abs/2304.00471) at the [**On-Device Intelligence Workshop (ODIW) @ MLSys 2023**](https://sites.google.com/g.harvard.edu/on-device-workshop-23/home); and as a Poster Spotlight at the [**NVIDIA GPU Technology Conference (GTC)**](https://www.nvidia.com/en-us/on-demand/search/?facet.mimetype[]=event%20session&layout=list&page=1&q=52409&sort=relevance&sortDir=desc).
+ - We thank the [NVIDIA Applied Research Accelerator Program](https://www.nvidia.com/en-us/industries/higher-education-research/applied-research-program/) for supporting this research and [Wav2Lip's authors](https://github.com/Rudrabha/Wav2Lip) for their pioneering work.
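The total-time/FPS definition in description.md can be made concrete with a small timing sketch. This is not part of the commit; `load_inputs`, `generate`, and `merge_with_video` are hypothetical stand-ins for the corresponding steps implemented in nota_wav2lip/inference.py and nota_wav2lip/video.py.

```python
# Illustrative timing sketch for the metrics shown in the demo (hypothetical helper names).
import time

def measure_total_time_and_fps(load_inputs, generate, merge_with_video):
    start = time.time()
    frames, mel_chunks = load_inputs()             # load the preprocessed video and audio
    synced_faces = generate(frames, mel_chunks)    # run the (compressed) Wav2Lip generator
    output_video = merge_with_video(synced_faces)  # paste lip-synced crops back into the video
    elapsed = time.time() - start
    fps = len(synced_faces) / elapsed              # frames produced per second of wall-clock time
    return output_video, round(elapsed, 2), round(fps, 1)
```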
docs/footer.md ADDED
@@ -0,0 +1,5 @@
+ <p align="center">
+   <a href="https://netspresso.ai/"><img src="https://huggingface.co/spaces/nota-ai/theme/resolve/main/docs/logo/nota_favicon_800x800.png" width="96px" height="96px"></a>
+ </p>
+
+ <br/>
docs/header.md ADDED
@@ -0,0 +1,10 @@
+ # <center>Lightweight Speech-Driven Talking-Face Synthesis Demo</center>
+
+ <br/>
+
+ <p align="center">
+   <a href="https://arxiv.org/abs/2304.00471"><img src="https://img.shields.io/badge/arXiv-2304.00471-b31b1b.svg?style=flat-square" style="display:inline;"></a>
+   <a href="https://huggingface.co/spaces/nota-ai/efficient_wav2lip"><img src="https://hits.seeyoufarm.com/api/count/incr/badge.svg?url=https%3A%2F%2Fhuggingface.co%2Fspaces%2Fnota-ai%2Fefficient_wav2lip&count_bg=%23325AC8&title_bg=%23112344&icon=&icon_color=%23E7E7E7&title=HITS&edge_flat=true" style="display:inline;"></a>
+ </p>
+
+ <br/>
docs/main.css ADDED
@@ -0,0 +1,4 @@
+ h1, h2, h3 {
+   text-align: center;
+   display: block;
+ }
download.py ADDED
@@ -0,0 +1,44 @@
+ import argparse
+
+ from nota_wav2lip.preprocess import get_cropped_face_from_lrs3_label
+
+
+ def parse_args():
+
+     parser = argparse.ArgumentParser(description="NotaWav2Lip: Get LRS3 video sample with the label text file")
+
+     parser.add_argument(
+         '-i',
+         '--input-file',
+         type=str,
+         required=True,
+         help="Path of the label text file downloaded from https://mmai.io/datasets/lip_reading"
+     )
+
+     parser.add_argument(
+         '-o',
+         '--output-dir',
+         type=str,
+         default="sample_video_lrs3",
+         help="Output directory to save the result. Default: sample_video_lrs3"
+     )
+
+     parser.add_argument(
+         '--ignore-cache',
+         action='store_true',
+         help="Whether to force downloading and resampling the video and overwrite pre-existing files"
+     )
+
+     args = parser.parse_args()
+
+     return args
+
+
+ if __name__ == '__main__':
+     args = parse_args()
+
+     get_cropped_face_from_lrs3_label(
+         args.input_file,
+         video_root_dir=args.output_dir,
+         ignore_cache=args.ignore_cache
+     )
download.sh ADDED
@@ -0,0 +1,7 @@
+ # example for audio source
+ python download.py \
+   -i data/lrs3_v0.4_txt/lrs3_v0.4/test/sxnlvwprfSc/00007.txt
+
+ # example for video source
+ python download.py \
+   -i data/lrs3_v0.4_txt/lrs3_v0.4/test/Li4S1yyrsTI/00010.txt
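For reference, the shell examples above boil down to the single library call that download.py wraps. A sketch, with the label path taken from the audio-source example and the LRS3 layout described in the README:

```python
# Equivalent to `python download.py -i <label.txt> -o sample_video_lrs3`
from nota_wav2lip.preprocess import get_cropped_face_from_lrs3_label

get_cropped_face_from_lrs3_label(
    "data/lrs3_v0.4_txt/lrs3_v0.4/test/sxnlvwprfSc/00007.txt",  # label file from the example above
    video_root_dir="sample_video_lrs3",   # default output directory of download.py
    ignore_cache=False,                   # set True to force re-download and resampling
)
```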
face_detection/README.md ADDED
@@ -0,0 +1 @@
+ The code for face detection in this folder has been taken from the wonderful [face_alignment](https://github.com/1adrianb/face-alignment) repository and modified to take batches of faces at a time.
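A minimal batch-detection sketch (not part of the commit) that relies only on the interfaces added in this diff, `face_detection/__init__.py` and `face_detection/api.py`; the frame array here is a dummy placeholder:

```python
# Illustrative batch usage of the vendored face detector (see face_detection/api.py).
import numpy as np
from face_detection import FaceAlignment, LandmarksType

# Downloads the S3FD weights on first use if they are not present locally.
fa = FaceAlignment(LandmarksType._2D, flip_input=False, device='cpu')

frames = np.zeros((4, 224, 224, 3), dtype=np.uint8)  # dummy batch of frames (N, H, W, 3)
boxes = fa.get_detections_for_batch(frames)          # one (x1, y1, x2, y2) tuple or None per frame
print(boxes)
```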
face_detection/__init__.py ADDED
@@ -0,0 +1,7 @@
+ # -*- coding: utf-8 -*-
+
+ __author__ = """Adrian Bulat"""
+ __email__ = '[email protected]'
+ __version__ = '1.0.1'
+
+ from .api import FaceAlignment, LandmarksType, NetworkSize
face_detection/api.py ADDED
@@ -0,0 +1,79 @@
1
+ from __future__ import print_function
2
+ import os
3
+ import torch
4
+ from torch.utils.model_zoo import load_url
5
+ from enum import Enum
6
+ import numpy as np
7
+ import cv2
8
+ try:
9
+ import urllib.request as request_file
10
+ except BaseException:
11
+ import urllib as request_file
12
+
13
+ from .models import FAN, ResNetDepth
14
+ from .utils import *
15
+
16
+
17
+ class LandmarksType(Enum):
18
+ """Enum class defining the type of landmarks to detect.
19
+
20
+ ``_2D`` - the detected points ``(x,y)`` are detected in a 2D space and follow the visible contour of the face
21
+ ``_2halfD`` - these points represent the projection of the 3D points into 3D
22
+ ``_3D`` - detect the points ``(x,y,z)``` in a 3D space
23
+
24
+ """
25
+ _2D = 1
26
+ _2halfD = 2
27
+ _3D = 3
28
+
29
+
30
+ class NetworkSize(Enum):
31
+ # TINY = 1
32
+ # SMALL = 2
33
+ # MEDIUM = 3
34
+ LARGE = 4
35
+
36
+ def __new__(cls, value):
37
+ member = object.__new__(cls)
38
+ member._value_ = value
39
+ return member
40
+
41
+ def __int__(self):
42
+ return self.value
43
+
44
+ ROOT = os.path.dirname(os.path.abspath(__file__))
45
+
46
+ class FaceAlignment:
47
+ def __init__(self, landmarks_type, network_size=NetworkSize.LARGE,
48
+ device='cuda', flip_input=False, face_detector='sfd', verbose=False):
49
+ self.device = device
50
+ self.flip_input = flip_input
51
+ self.landmarks_type = landmarks_type
52
+ self.verbose = verbose
53
+
54
+ network_size = int(network_size)
55
+
56
+ if 'cuda' in device:
57
+ torch.backends.cudnn.benchmark = True
58
+
59
+ # Get the face detector
60
+ face_detector_module = __import__('face_detection.detection.' + face_detector,
61
+ globals(), locals(), [face_detector], 0)
62
+ self.face_detector = face_detector_module.FaceDetector(device=device, verbose=verbose)
63
+
64
+ def get_detections_for_batch(self, images):
65
+ images = images[..., ::-1]
66
+ detected_faces = self.face_detector.detect_from_batch(images.copy())
67
+ results = []
68
+
69
+ for i, d in enumerate(detected_faces):
70
+ if len(d) == 0:
71
+ results.append(None)
72
+ continue
73
+ d = d[0]
74
+ d = np.clip(d, 0, None)
75
+
76
+ x1, y1, x2, y2 = map(int, d[:-1])
77
+ results.append((x1, y1, x2, y2))
78
+
79
+ return results
face_detection/detection/__init__.py ADDED
@@ -0,0 +1 @@
+ from .core import FaceDetector
face_detection/detection/core.py ADDED
@@ -0,0 +1,130 @@
1
+ import logging
2
+ import glob
3
+ from tqdm import tqdm
4
+ import numpy as np
5
+ import torch
6
+ import cv2
7
+
8
+
9
+ class FaceDetector(object):
10
+ """An abstract class representing a face detector.
11
+
12
+ Any other face detection implementation must subclass it. All subclasses
13
+ must implement ``detect_from_image``, which returns a list of detected
14
+ bounding boxes. Optionally, for speed considerations detect from path is
15
+ recommended.
16
+ """
17
+
18
+ def __init__(self, device, verbose):
19
+ self.device = device
20
+ self.verbose = verbose
21
+
22
+ if verbose:
23
+ if 'cpu' in device:
24
+ logger = logging.getLogger(__name__)
25
+ logger.warning("Detection running on CPU, this may be potentially slow.")
26
+
27
+ if 'cpu' not in device and 'cuda' not in device:
28
+ if verbose:
29
+ logger.error("Expected values for device are: {cpu, cuda} but got: %s", device)
30
+ raise ValueError
31
+
32
+ def detect_from_image(self, tensor_or_path):
33
+ """Detects faces in a given image.
34
+
35
+ This function detects the faces present in a provided BGR(usually)
36
+ image. The input can be either the image itself or the path to it.
37
+
38
+ Arguments:
39
+ tensor_or_path {numpy.ndarray, torch.tensor or string} -- the path
40
+ to an image or the image itself.
41
+
42
+ Example::
43
+
44
+ >>> path_to_image = 'data/image_01.jpg'
45
+ ... detected_faces = detect_from_image(path_to_image)
46
+ [A list of bounding boxes (x1, y1, x2, y2)]
47
+ >>> image = cv2.imread(path_to_image)
48
+ ... detected_faces = detect_from_image(image)
49
+ [A list of bounding boxes (x1, y1, x2, y2)]
50
+
51
+ """
52
+ raise NotImplementedError
53
+
54
+ def detect_from_directory(self, path, extensions=['.jpg', '.png'], recursive=False, show_progress_bar=True):
55
+ """Detects faces from all the images present in a given directory.
56
+
57
+ Arguments:
58
+ path {string} -- a string containing a path that points to the folder containing the images
59
+
60
+ Keyword Arguments:
61
+ extensions {list} -- list of string containing the extensions to be
62
+ consider in the following format: ``.extension_name`` (default:
63
+ {['.jpg', '.png']}) recursive {bool} -- option wherever to scan the
64
+ folder recursively (default: {False}) show_progress_bar {bool} --
65
+ display a progressbar (default: {True})
66
+
67
+ Example:
68
+ >>> directory = 'data'
69
+ ... detected_faces = detect_from_directory(directory)
70
+ {A dictionary of [lists containing bounding boxes(x1, y1, x2, y2)]}
71
+
72
+ """
73
+ if self.verbose:
74
+ logger = logging.getLogger(__name__)
75
+
76
+ if len(extensions) == 0:
77
+ if self.verbose:
78
+ logger.error("Expected at least one extension, but none was received.")
79
+ raise ValueError
80
+
81
+ if self.verbose:
82
+ logger.info("Constructing the list of images.")
83
+ additional_pattern = '/**/*' if recursive else '/*'
84
+ files = []
85
+ for extension in extensions:
86
+ files.extend(glob.glob(path + additional_pattern + extension, recursive=recursive))
87
+
88
+ if self.verbose:
89
+ logger.info("Finished searching for images. %s images found", len(files))
90
+ logger.info("Preparing to run the detection.")
91
+
92
+ predictions = {}
93
+ for image_path in tqdm(files, disable=not show_progress_bar):
94
+ if self.verbose:
95
+ logger.info("Running the face detector on image: %s", image_path)
96
+ predictions[image_path] = self.detect_from_image(image_path)
97
+
98
+ if self.verbose:
99
+ logger.info("The detector was successfully run on all %s images", len(files))
100
+
101
+ return predictions
102
+
103
+ @property
104
+ def reference_scale(self):
105
+ raise NotImplementedError
106
+
107
+ @property
108
+ def reference_x_shift(self):
109
+ raise NotImplementedError
110
+
111
+ @property
112
+ def reference_y_shift(self):
113
+ raise NotImplementedError
114
+
115
+ @staticmethod
116
+ def tensor_or_path_to_ndarray(tensor_or_path, rgb=True):
117
+ """Convert path (represented as a string) or torch.tensor to a numpy.ndarray
118
+
119
+ Arguments:
120
+ tensor_or_path {numpy.ndarray, torch.tensor or string} -- path to the image, or the image itself
121
+ """
122
+ if isinstance(tensor_or_path, str):
123
+ return cv2.imread(tensor_or_path) if not rgb else cv2.imread(tensor_or_path)[..., ::-1]
124
+ elif torch.is_tensor(tensor_or_path):
125
+ # Call cpu in case its coming from cuda
126
+ return tensor_or_path.cpu().numpy()[..., ::-1].copy() if not rgb else tensor_or_path.cpu().numpy()
127
+ elif isinstance(tensor_or_path, np.ndarray):
128
+ return tensor_or_path[..., ::-1].copy() if not rgb else tensor_or_path
129
+ else:
130
+ raise TypeError
face_detection/detection/sfd/__init__.py ADDED
@@ -0,0 +1 @@
+ from .sfd_detector import SFDDetector as FaceDetector
face_detection/detection/sfd/bbox.py ADDED
@@ -0,0 +1,129 @@
1
+ from __future__ import print_function
2
+ import os
3
+ import sys
4
+ import cv2
5
+ import random
6
+ import datetime
7
+ import time
8
+ import math
9
+ import argparse
10
+ import numpy as np
11
+ import torch
12
+
13
+ try:
14
+ from iou import IOU
15
+ except BaseException:
16
+ # IOU cython speedup 10x
17
+ def IOU(ax1, ay1, ax2, ay2, bx1, by1, bx2, by2):
18
+ sa = abs((ax2 - ax1) * (ay2 - ay1))
19
+ sb = abs((bx2 - bx1) * (by2 - by1))
20
+ x1, y1 = max(ax1, bx1), max(ay1, by1)
21
+ x2, y2 = min(ax2, bx2), min(ay2, by2)
22
+ w = x2 - x1
23
+ h = y2 - y1
24
+ if w < 0 or h < 0:
25
+ return 0.0
26
+ else:
27
+ return 1.0 * w * h / (sa + sb - w * h)
28
+
29
+
30
+ def bboxlog(x1, y1, x2, y2, axc, ayc, aww, ahh):
31
+ xc, yc, ww, hh = (x2 + x1) / 2, (y2 + y1) / 2, x2 - x1, y2 - y1
32
+ dx, dy = (xc - axc) / aww, (yc - ayc) / ahh
33
+ dw, dh = math.log(ww / aww), math.log(hh / ahh)
34
+ return dx, dy, dw, dh
35
+
36
+
37
+ def bboxloginv(dx, dy, dw, dh, axc, ayc, aww, ahh):
38
+ xc, yc = dx * aww + axc, dy * ahh + ayc
39
+ ww, hh = math.exp(dw) * aww, math.exp(dh) * ahh
40
+ x1, x2, y1, y2 = xc - ww / 2, xc + ww / 2, yc - hh / 2, yc + hh / 2
41
+ return x1, y1, x2, y2
42
+
43
+
44
+ def nms(dets, thresh):
45
+ if 0 == len(dets):
46
+ return []
47
+ x1, y1, x2, y2, scores = dets[:, 0], dets[:, 1], dets[:, 2], dets[:, 3], dets[:, 4]
48
+ areas = (x2 - x1 + 1) * (y2 - y1 + 1)
49
+ order = scores.argsort()[::-1]
50
+
51
+ keep = []
52
+ while order.size > 0:
53
+ i = order[0]
54
+ keep.append(i)
55
+ xx1, yy1 = np.maximum(x1[i], x1[order[1:]]), np.maximum(y1[i], y1[order[1:]])
56
+ xx2, yy2 = np.minimum(x2[i], x2[order[1:]]), np.minimum(y2[i], y2[order[1:]])
57
+
58
+ w, h = np.maximum(0.0, xx2 - xx1 + 1), np.maximum(0.0, yy2 - yy1 + 1)
59
+ ovr = w * h / (areas[i] + areas[order[1:]] - w * h)
60
+
61
+ inds = np.where(ovr <= thresh)[0]
62
+ order = order[inds + 1]
63
+
64
+ return keep
65
+
66
+
67
+ def encode(matched, priors, variances):
68
+ """Encode the variances from the priorbox layers into the ground truth boxes
69
+ we have matched (based on jaccard overlap) with the prior boxes.
70
+ Args:
71
+ matched: (tensor) Coords of ground truth for each prior in point-form
72
+ Shape: [num_priors, 4].
73
+ priors: (tensor) Prior boxes in center-offset form
74
+ Shape: [num_priors,4].
75
+ variances: (list[float]) Variances of priorboxes
76
+ Return:
77
+ encoded boxes (tensor), Shape: [num_priors, 4]
78
+ """
79
+
80
+ # dist b/t match center and prior's center
81
+ g_cxcy = (matched[:, :2] + matched[:, 2:]) / 2 - priors[:, :2]
82
+ # encode variance
83
+ g_cxcy /= (variances[0] * priors[:, 2:])
84
+ # match wh / prior wh
85
+ g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:]
86
+ g_wh = torch.log(g_wh) / variances[1]
87
+ # return target for smooth_l1_loss
88
+ return torch.cat([g_cxcy, g_wh], 1) # [num_priors,4]
89
+
90
+
91
+ def decode(loc, priors, variances):
92
+ """Decode locations from predictions using priors to undo
93
+ the encoding we did for offset regression at train time.
94
+ Args:
95
+ loc (tensor): location predictions for loc layers,
96
+ Shape: [num_priors,4]
97
+ priors (tensor): Prior boxes in center-offset form.
98
+ Shape: [num_priors,4].
99
+ variances: (list[float]) Variances of priorboxes
100
+ Return:
101
+ decoded bounding box predictions
102
+ """
103
+
104
+ boxes = torch.cat((
105
+ priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:],
106
+ priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1)
107
+ boxes[:, :2] -= boxes[:, 2:] / 2
108
+ boxes[:, 2:] += boxes[:, :2]
109
+ return boxes
110
+
111
+ def batch_decode(loc, priors, variances):
112
+ """Decode locations from predictions using priors to undo
113
+ the encoding we did for offset regression at train time.
114
+ Args:
115
+ loc (tensor): location predictions for loc layers,
116
+ Shape: [num_priors,4]
117
+ priors (tensor): Prior boxes in center-offset form.
118
+ Shape: [num_priors,4].
119
+ variances: (list[float]) Variances of priorboxes
120
+ Return:
121
+ decoded bounding box predictions
122
+ """
123
+
124
+ boxes = torch.cat((
125
+ priors[:, :, :2] + loc[:, :, :2] * variances[0] * priors[:, :, 2:],
126
+ priors[:, :, 2:] * torch.exp(loc[:, :, 2:] * variances[1])), 2)
127
+ boxes[:, :, :2] -= boxes[:, :, 2:] / 2
128
+ boxes[:, :, 2:] += boxes[:, :, :2]
129
+ return boxes
face_detection/detection/sfd/detect.py ADDED
@@ -0,0 +1,112 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+
4
+ import os
5
+ import sys
6
+ import cv2
7
+ import random
8
+ import datetime
9
+ import math
10
+ import argparse
11
+ import numpy as np
12
+
13
+ import scipy.io as sio
14
+ import zipfile
15
+ from .net_s3fd import s3fd
16
+ from .bbox import *
17
+
18
+
19
+ def detect(net, img, device):
20
+ img = img - np.array([104, 117, 123])
21
+ img = img.transpose(2, 0, 1)
22
+ img = img.reshape((1,) + img.shape)
23
+
24
+ if 'cuda' in device:
25
+ torch.backends.cudnn.benchmark = True
26
+
27
+ img = torch.from_numpy(img).float().to(device)
28
+ BB, CC, HH, WW = img.size()
29
+ with torch.no_grad():
30
+ olist = net(img)
31
+
32
+ bboxlist = []
33
+ for i in range(len(olist) // 2):
34
+ olist[i * 2] = F.softmax(olist[i * 2], dim=1)
35
+ olist = [oelem.data.cpu() for oelem in olist]
36
+ for i in range(len(olist) // 2):
37
+ ocls, oreg = olist[i * 2], olist[i * 2 + 1]
38
+ FB, FC, FH, FW = ocls.size() # feature map size
39
+ stride = 2**(i + 2) # 4,8,16,32,64,128
40
+ anchor = stride * 4
41
+ poss = zip(*np.where(ocls[:, 1, :, :] > 0.05))
42
+ for Iindex, hindex, windex in poss:
43
+ axc, ayc = stride / 2 + windex * stride, stride / 2 + hindex * stride
44
+ score = ocls[0, 1, hindex, windex]
45
+ loc = oreg[0, :, hindex, windex].contiguous().view(1, 4)
46
+ priors = torch.Tensor([[axc / 1.0, ayc / 1.0, stride * 4 / 1.0, stride * 4 / 1.0]])
47
+ variances = [0.1, 0.2]
48
+ box = decode(loc, priors, variances)
49
+ x1, y1, x2, y2 = box[0] * 1.0
50
+ # cv2.rectangle(imgshow,(int(x1),int(y1)),(int(x2),int(y2)),(0,0,255),1)
51
+ bboxlist.append([x1, y1, x2, y2, score])
52
+ bboxlist = np.array(bboxlist)
53
+ if 0 == len(bboxlist):
54
+ bboxlist = np.zeros((1, 5))
55
+
56
+ return bboxlist
57
+
58
+ def batch_detect(net, imgs, device):
59
+ imgs = imgs - np.array([104, 117, 123])
60
+ imgs = imgs.transpose(0, 3, 1, 2)
61
+
62
+ if 'cuda' in device:
63
+ torch.backends.cudnn.benchmark = True
64
+
65
+ imgs = torch.from_numpy(imgs).float().to(device)
66
+ BB, CC, HH, WW = imgs.size()
67
+ with torch.no_grad():
68
+ olist = net(imgs)
69
+
70
+ bboxlist = []
71
+ for i in range(len(olist) // 2):
72
+ olist[i * 2] = F.softmax(olist[i * 2], dim=1)
73
+ olist = [oelem.data.cpu() for oelem in olist]
74
+ for i in range(len(olist) // 2):
75
+ ocls, oreg = olist[i * 2], olist[i * 2 + 1]
76
+ FB, FC, FH, FW = ocls.size() # feature map size
77
+ stride = 2**(i + 2) # 4,8,16,32,64,128
78
+ anchor = stride * 4
79
+ poss = zip(*np.where(ocls[:, 1, :, :] > 0.05))
80
+ for Iindex, hindex, windex in poss:
81
+ axc, ayc = stride / 2 + windex * stride, stride / 2 + hindex * stride
82
+ score = ocls[:, 1, hindex, windex]
83
+ loc = oreg[:, :, hindex, windex].contiguous().view(BB, 1, 4)
84
+ priors = torch.Tensor([[axc / 1.0, ayc / 1.0, stride * 4 / 1.0, stride * 4 / 1.0]]).view(1, 1, 4)
85
+ variances = [0.1, 0.2]
86
+ box = batch_decode(loc, priors, variances)
87
+ box = box[:, 0] * 1.0
88
+ # cv2.rectangle(imgshow,(int(x1),int(y1)),(int(x2),int(y2)),(0,0,255),1)
89
+ bboxlist.append(torch.cat([box, score.unsqueeze(1)], 1).cpu().numpy())
90
+ bboxlist = np.array(bboxlist)
91
+ if 0 == len(bboxlist):
92
+ bboxlist = np.zeros((1, BB, 5))
93
+
94
+ return bboxlist
95
+
96
+ def flip_detect(net, img, device):
97
+ img = cv2.flip(img, 1)
98
+ b = detect(net, img, device)
99
+
100
+ bboxlist = np.zeros(b.shape)
101
+ bboxlist[:, 0] = img.shape[1] - b[:, 2]
102
+ bboxlist[:, 1] = b[:, 1]
103
+ bboxlist[:, 2] = img.shape[1] - b[:, 0]
104
+ bboxlist[:, 3] = b[:, 3]
105
+ bboxlist[:, 4] = b[:, 4]
106
+ return bboxlist
107
+
108
+
109
+ def pts_to_bb(pts):
110
+ min_x, min_y = np.min(pts, axis=0)
111
+ max_x, max_y = np.max(pts, axis=0)
112
+ return np.array([min_x, min_y, max_x, max_y])
face_detection/detection/sfd/net_s3fd.py ADDED
@@ -0,0 +1,129 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ class L2Norm(nn.Module):
7
+ def __init__(self, n_channels, scale=1.0):
8
+ super(L2Norm, self).__init__()
9
+ self.n_channels = n_channels
10
+ self.scale = scale
11
+ self.eps = 1e-10
12
+ self.weight = nn.Parameter(torch.Tensor(self.n_channels))
13
+ self.weight.data *= 0.0
14
+ self.weight.data += self.scale
15
+
16
+ def forward(self, x):
17
+ norm = x.pow(2).sum(dim=1, keepdim=True).sqrt() + self.eps
18
+ x = x / norm * self.weight.view(1, -1, 1, 1)
19
+ return x
20
+
21
+
22
+ class s3fd(nn.Module):
23
+ def __init__(self):
24
+ super(s3fd, self).__init__()
25
+ self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
26
+ self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
27
+
28
+ self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
29
+ self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
30
+
31
+ self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
32
+ self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
33
+ self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
34
+
35
+ self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
36
+ self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
37
+ self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
38
+
39
+ self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
40
+ self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
41
+ self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
42
+
43
+ self.fc6 = nn.Conv2d(512, 1024, kernel_size=3, stride=1, padding=3)
44
+ self.fc7 = nn.Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0)
45
+
46
+ self.conv6_1 = nn.Conv2d(1024, 256, kernel_size=1, stride=1, padding=0)
47
+ self.conv6_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1)
48
+
49
+ self.conv7_1 = nn.Conv2d(512, 128, kernel_size=1, stride=1, padding=0)
50
+ self.conv7_2 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)
51
+
52
+ self.conv3_3_norm = L2Norm(256, scale=10)
53
+ self.conv4_3_norm = L2Norm(512, scale=8)
54
+ self.conv5_3_norm = L2Norm(512, scale=5)
55
+
56
+ self.conv3_3_norm_mbox_conf = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)
57
+ self.conv3_3_norm_mbox_loc = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)
58
+ self.conv4_3_norm_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1)
59
+ self.conv4_3_norm_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
60
+ self.conv5_3_norm_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1)
61
+ self.conv5_3_norm_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
62
+
63
+ self.fc7_mbox_conf = nn.Conv2d(1024, 2, kernel_size=3, stride=1, padding=1)
64
+ self.fc7_mbox_loc = nn.Conv2d(1024, 4, kernel_size=3, stride=1, padding=1)
65
+ self.conv6_2_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1)
66
+ self.conv6_2_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
67
+ self.conv7_2_mbox_conf = nn.Conv2d(256, 2, kernel_size=3, stride=1, padding=1)
68
+ self.conv7_2_mbox_loc = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)
69
+
70
+ def forward(self, x):
71
+ h = F.relu(self.conv1_1(x))
72
+ h = F.relu(self.conv1_2(h))
73
+ h = F.max_pool2d(h, 2, 2)
74
+
75
+ h = F.relu(self.conv2_1(h))
76
+ h = F.relu(self.conv2_2(h))
77
+ h = F.max_pool2d(h, 2, 2)
78
+
79
+ h = F.relu(self.conv3_1(h))
80
+ h = F.relu(self.conv3_2(h))
81
+ h = F.relu(self.conv3_3(h))
82
+ f3_3 = h
83
+ h = F.max_pool2d(h, 2, 2)
84
+
85
+ h = F.relu(self.conv4_1(h))
86
+ h = F.relu(self.conv4_2(h))
87
+ h = F.relu(self.conv4_3(h))
88
+ f4_3 = h
89
+ h = F.max_pool2d(h, 2, 2)
90
+
91
+ h = F.relu(self.conv5_1(h))
92
+ h = F.relu(self.conv5_2(h))
93
+ h = F.relu(self.conv5_3(h))
94
+ f5_3 = h
95
+ h = F.max_pool2d(h, 2, 2)
96
+
97
+ h = F.relu(self.fc6(h))
98
+ h = F.relu(self.fc7(h))
99
+ ffc7 = h
100
+ h = F.relu(self.conv6_1(h))
101
+ h = F.relu(self.conv6_2(h))
102
+ f6_2 = h
103
+ h = F.relu(self.conv7_1(h))
104
+ h = F.relu(self.conv7_2(h))
105
+ f7_2 = h
106
+
107
+ f3_3 = self.conv3_3_norm(f3_3)
108
+ f4_3 = self.conv4_3_norm(f4_3)
109
+ f5_3 = self.conv5_3_norm(f5_3)
110
+
111
+ cls1 = self.conv3_3_norm_mbox_conf(f3_3)
112
+ reg1 = self.conv3_3_norm_mbox_loc(f3_3)
113
+ cls2 = self.conv4_3_norm_mbox_conf(f4_3)
114
+ reg2 = self.conv4_3_norm_mbox_loc(f4_3)
115
+ cls3 = self.conv5_3_norm_mbox_conf(f5_3)
116
+ reg3 = self.conv5_3_norm_mbox_loc(f5_3)
117
+ cls4 = self.fc7_mbox_conf(ffc7)
118
+ reg4 = self.fc7_mbox_loc(ffc7)
119
+ cls5 = self.conv6_2_mbox_conf(f6_2)
120
+ reg5 = self.conv6_2_mbox_loc(f6_2)
121
+ cls6 = self.conv7_2_mbox_conf(f7_2)
122
+ reg6 = self.conv7_2_mbox_loc(f7_2)
123
+
124
+ # max-out background label
125
+ chunk = torch.chunk(cls1, 4, 1)
126
+ bmax = torch.max(torch.max(chunk[0], chunk[1]), chunk[2])
127
+ cls1 = torch.cat([bmax, chunk[3]], dim=1)
128
+
129
+ return [cls1, reg1, cls2, reg2, cls3, reg3, cls4, reg4, cls5, reg5, cls6, reg6]
face_detection/detection/sfd/sfd_detector.py ADDED
@@ -0,0 +1,59 @@
1
+ import os
2
+ import cv2
3
+ from torch.utils.model_zoo import load_url
4
+
5
+ from ..core import FaceDetector
6
+
7
+ from .net_s3fd import s3fd
8
+ from .bbox import *
9
+ from .detect import *
10
+
11
+ models_urls = {
12
+ 's3fd': 'https://www.adrianbulat.com/downloads/python-fan/s3fd-619a316812.pth',
13
+ }
14
+
15
+
16
+ class SFDDetector(FaceDetector):
17
+ def __init__(self, device, path_to_detector=os.path.join(os.path.dirname(os.path.abspath(__file__)), 's3fd.pth'), verbose=False):
18
+ super(SFDDetector, self).__init__(device, verbose)
19
+
20
+ # Initialise the face detector
21
+ if not os.path.isfile(path_to_detector):
22
+ model_weights = load_url(models_urls['s3fd'])
23
+ else:
24
+ model_weights = torch.load(path_to_detector)
25
+
26
+ self.face_detector = s3fd()
27
+ self.face_detector.load_state_dict(model_weights)
28
+ self.face_detector.to(device)
29
+ self.face_detector.eval()
30
+
31
+ def detect_from_image(self, tensor_or_path):
32
+ image = self.tensor_or_path_to_ndarray(tensor_or_path)
33
+
34
+ bboxlist = detect(self.face_detector, image, device=self.device)
35
+ keep = nms(bboxlist, 0.3)
36
+ bboxlist = bboxlist[keep, :]
37
+ bboxlist = [x for x in bboxlist if x[-1] > 0.5]
38
+
39
+ return bboxlist
40
+
41
+ def detect_from_batch(self, images):
42
+ bboxlists = batch_detect(self.face_detector, images, device=self.device)
43
+ keeps = [nms(bboxlists[:, i, :], 0.3) for i in range(bboxlists.shape[1])]
44
+ bboxlists = [bboxlists[keep, i, :] for i, keep in enumerate(keeps)]
45
+ bboxlists = [[x for x in bboxlist if x[-1] > 0.5] for bboxlist in bboxlists]
46
+
47
+ return bboxlists
48
+
49
+ @property
50
+ def reference_scale(self):
51
+ return 195
52
+
53
+ @property
54
+ def reference_x_shift(self):
55
+ return 0
56
+
57
+ @property
58
+ def reference_y_shift(self):
59
+ return 0
face_detection/models.py ADDED
@@ -0,0 +1,261 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import math
5
+
6
+
7
+ def conv3x3(in_planes, out_planes, strd=1, padding=1, bias=False):
8
+ "3x3 convolution with padding"
9
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3,
10
+ stride=strd, padding=padding, bias=bias)
11
+
12
+
13
+ class ConvBlock(nn.Module):
14
+ def __init__(self, in_planes, out_planes):
15
+ super(ConvBlock, self).__init__()
16
+ self.bn1 = nn.BatchNorm2d(in_planes)
17
+ self.conv1 = conv3x3(in_planes, int(out_planes / 2))
18
+ self.bn2 = nn.BatchNorm2d(int(out_planes / 2))
19
+ self.conv2 = conv3x3(int(out_planes / 2), int(out_planes / 4))
20
+ self.bn3 = nn.BatchNorm2d(int(out_planes / 4))
21
+ self.conv3 = conv3x3(int(out_planes / 4), int(out_planes / 4))
22
+
23
+ if in_planes != out_planes:
24
+ self.downsample = nn.Sequential(
25
+ nn.BatchNorm2d(in_planes),
26
+ nn.ReLU(True),
27
+ nn.Conv2d(in_planes, out_planes,
28
+ kernel_size=1, stride=1, bias=False),
29
+ )
30
+ else:
31
+ self.downsample = None
32
+
33
+ def forward(self, x):
34
+ residual = x
35
+
36
+ out1 = self.bn1(x)
37
+ out1 = F.relu(out1, True)
38
+ out1 = self.conv1(out1)
39
+
40
+ out2 = self.bn2(out1)
41
+ out2 = F.relu(out2, True)
42
+ out2 = self.conv2(out2)
43
+
44
+ out3 = self.bn3(out2)
45
+ out3 = F.relu(out3, True)
46
+ out3 = self.conv3(out3)
47
+
48
+ out3 = torch.cat((out1, out2, out3), 1)
49
+
50
+ if self.downsample is not None:
51
+ residual = self.downsample(residual)
52
+
53
+ out3 += residual
54
+
55
+ return out3
56
+
57
+
58
+ class Bottleneck(nn.Module):
59
+
60
+ expansion = 4
61
+
62
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
63
+ super(Bottleneck, self).__init__()
64
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
65
+ self.bn1 = nn.BatchNorm2d(planes)
66
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
67
+ padding=1, bias=False)
68
+ self.bn2 = nn.BatchNorm2d(planes)
69
+ self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
70
+ self.bn3 = nn.BatchNorm2d(planes * 4)
71
+ self.relu = nn.ReLU(inplace=True)
72
+ self.downsample = downsample
73
+ self.stride = stride
74
+
75
+ def forward(self, x):
76
+ residual = x
77
+
78
+ out = self.conv1(x)
79
+ out = self.bn1(out)
80
+ out = self.relu(out)
81
+
82
+ out = self.conv2(out)
83
+ out = self.bn2(out)
84
+ out = self.relu(out)
85
+
86
+ out = self.conv3(out)
87
+ out = self.bn3(out)
88
+
89
+ if self.downsample is not None:
90
+ residual = self.downsample(x)
91
+
92
+ out += residual
93
+ out = self.relu(out)
94
+
95
+ return out
96
+
97
+
98
+ class HourGlass(nn.Module):
99
+ def __init__(self, num_modules, depth, num_features):
100
+ super(HourGlass, self).__init__()
101
+ self.num_modules = num_modules
102
+ self.depth = depth
103
+ self.features = num_features
104
+
105
+ self._generate_network(self.depth)
106
+
107
+ def _generate_network(self, level):
108
+ self.add_module('b1_' + str(level), ConvBlock(self.features, self.features))
109
+
110
+ self.add_module('b2_' + str(level), ConvBlock(self.features, self.features))
111
+
112
+ if level > 1:
113
+ self._generate_network(level - 1)
114
+ else:
115
+ self.add_module('b2_plus_' + str(level), ConvBlock(self.features, self.features))
116
+
117
+ self.add_module('b3_' + str(level), ConvBlock(self.features, self.features))
118
+
119
+ def _forward(self, level, inp):
120
+ # Upper branch
121
+ up1 = inp
122
+ up1 = self._modules['b1_' + str(level)](up1)
123
+
124
+ # Lower branch
125
+ low1 = F.avg_pool2d(inp, 2, stride=2)
126
+ low1 = self._modules['b2_' + str(level)](low1)
127
+
128
+ if level > 1:
129
+ low2 = self._forward(level - 1, low1)
130
+ else:
131
+ low2 = low1
132
+ low2 = self._modules['b2_plus_' + str(level)](low2)
133
+
134
+ low3 = low2
135
+ low3 = self._modules['b3_' + str(level)](low3)
136
+
137
+ up2 = F.interpolate(low3, scale_factor=2, mode='nearest')
138
+
139
+ return up1 + up2
140
+
141
+ def forward(self, x):
142
+ return self._forward(self.depth, x)
143
+
144
+
145
+ class FAN(nn.Module):
146
+
147
+ def __init__(self, num_modules=1):
148
+ super(FAN, self).__init__()
149
+ self.num_modules = num_modules
150
+
151
+ # Base part
152
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
153
+ self.bn1 = nn.BatchNorm2d(64)
154
+ self.conv2 = ConvBlock(64, 128)
155
+ self.conv3 = ConvBlock(128, 128)
156
+ self.conv4 = ConvBlock(128, 256)
157
+
158
+ # Stacking part
159
+ for hg_module in range(self.num_modules):
160
+ self.add_module('m' + str(hg_module), HourGlass(1, 4, 256))
161
+ self.add_module('top_m_' + str(hg_module), ConvBlock(256, 256))
162
+ self.add_module('conv_last' + str(hg_module),
163
+ nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
164
+ self.add_module('bn_end' + str(hg_module), nn.BatchNorm2d(256))
165
+ self.add_module('l' + str(hg_module), nn.Conv2d(256,
166
+ 68, kernel_size=1, stride=1, padding=0))
167
+
168
+ if hg_module < self.num_modules - 1:
169
+ self.add_module(
170
+ 'bl' + str(hg_module), nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
171
+ self.add_module('al' + str(hg_module), nn.Conv2d(68,
172
+ 256, kernel_size=1, stride=1, padding=0))
173
+
174
+ def forward(self, x):
175
+ x = F.relu(self.bn1(self.conv1(x)), True)
176
+ x = F.avg_pool2d(self.conv2(x), 2, stride=2)
177
+ x = self.conv3(x)
178
+ x = self.conv4(x)
179
+
180
+ previous = x
181
+
182
+ outputs = []
183
+ for i in range(self.num_modules):
184
+ hg = self._modules['m' + str(i)](previous)
185
+
186
+ ll = hg
187
+ ll = self._modules['top_m_' + str(i)](ll)
188
+
189
+ ll = F.relu(self._modules['bn_end' + str(i)]
190
+ (self._modules['conv_last' + str(i)](ll)), True)
191
+
192
+ # Predict heatmaps
193
+ tmp_out = self._modules['l' + str(i)](ll)
194
+ outputs.append(tmp_out)
195
+
196
+ if i < self.num_modules - 1:
197
+ ll = self._modules['bl' + str(i)](ll)
198
+ tmp_out_ = self._modules['al' + str(i)](tmp_out)
199
+ previous = previous + ll + tmp_out_
200
+
201
+ return outputs
202
+
203
+
204
+ class ResNetDepth(nn.Module):
205
+
206
+ def __init__(self, block=Bottleneck, layers=[3, 8, 36, 3], num_classes=68):
207
+ self.inplanes = 64
208
+ super(ResNetDepth, self).__init__()
209
+ self.conv1 = nn.Conv2d(3 + 68, 64, kernel_size=7, stride=2, padding=3,
210
+ bias=False)
211
+ self.bn1 = nn.BatchNorm2d(64)
212
+ self.relu = nn.ReLU(inplace=True)
213
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
214
+ self.layer1 = self._make_layer(block, 64, layers[0])
215
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
216
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
217
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
218
+ self.avgpool = nn.AvgPool2d(7)
219
+ self.fc = nn.Linear(512 * block.expansion, num_classes)
220
+
221
+ for m in self.modules():
222
+ if isinstance(m, nn.Conv2d):
223
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
224
+ m.weight.data.normal_(0, math.sqrt(2. / n))
225
+ elif isinstance(m, nn.BatchNorm2d):
226
+ m.weight.data.fill_(1)
227
+ m.bias.data.zero_()
228
+
229
+ def _make_layer(self, block, planes, blocks, stride=1):
230
+ downsample = None
231
+ if stride != 1 or self.inplanes != planes * block.expansion:
232
+ downsample = nn.Sequential(
233
+ nn.Conv2d(self.inplanes, planes * block.expansion,
234
+ kernel_size=1, stride=stride, bias=False),
235
+ nn.BatchNorm2d(planes * block.expansion),
236
+ )
237
+
238
+ layers = []
239
+ layers.append(block(self.inplanes, planes, stride, downsample))
240
+ self.inplanes = planes * block.expansion
241
+ for i in range(1, blocks):
242
+ layers.append(block(self.inplanes, planes))
243
+
244
+ return nn.Sequential(*layers)
245
+
246
+ def forward(self, x):
247
+ x = self.conv1(x)
248
+ x = self.bn1(x)
249
+ x = self.relu(x)
250
+ x = self.maxpool(x)
251
+
252
+ x = self.layer1(x)
253
+ x = self.layer2(x)
254
+ x = self.layer3(x)
255
+ x = self.layer4(x)
256
+
257
+ x = self.avgpool(x)
258
+ x = x.view(x.size(0), -1)
259
+ x = self.fc(x)
260
+
261
+ return x
face_detection/utils.py ADDED
@@ -0,0 +1,313 @@
1
+ from __future__ import print_function
2
+ import os
3
+ import sys
4
+ import time
5
+ import torch
6
+ import math
7
+ import numpy as np
8
+ import cv2
9
+
10
+
11
+ def _gaussian(
12
+ size=3, sigma=0.25, amplitude=1, normalize=False, width=None,
13
+ height=None, sigma_horz=None, sigma_vert=None, mean_horz=0.5,
14
+ mean_vert=0.5):
15
+ # handle some defaults
16
+ if width is None:
17
+ width = size
18
+ if height is None:
19
+ height = size
20
+ if sigma_horz is None:
21
+ sigma_horz = sigma
22
+ if sigma_vert is None:
23
+ sigma_vert = sigma
24
+ center_x = mean_horz * width + 0.5
25
+ center_y = mean_vert * height + 0.5
26
+ gauss = np.empty((height, width), dtype=np.float32)
27
+ # generate kernel
28
+ for i in range(height):
29
+ for j in range(width):
30
+ gauss[i][j] = amplitude * math.exp(-(math.pow((j + 1 - center_x) / (
31
+ sigma_horz * width), 2) / 2.0 + math.pow((i + 1 - center_y) / (sigma_vert * height), 2) / 2.0))
32
+ if normalize:
33
+ gauss = gauss / np.sum(gauss)
34
+ return gauss
35
+
36
+
37
+ def draw_gaussian(image, point, sigma):
38
+ # Check if the gaussian is inside
39
+ ul = [math.floor(point[0] - 3 * sigma), math.floor(point[1] - 3 * sigma)]
40
+ br = [math.floor(point[0] + 3 * sigma), math.floor(point[1] + 3 * sigma)]
41
+ if (ul[0] > image.shape[1] or ul[1] > image.shape[0] or br[0] < 1 or br[1] < 1):
42
+ return image
43
+ size = 6 * sigma + 1
44
+ g = _gaussian(size)
45
+ g_x = [int(max(1, -ul[0])), int(min(br[0], image.shape[1])) - int(max(1, ul[0])) + int(max(1, -ul[0]))]
46
+ g_y = [int(max(1, -ul[1])), int(min(br[1], image.shape[0])) - int(max(1, ul[1])) + int(max(1, -ul[1]))]
47
+ img_x = [int(max(1, ul[0])), int(min(br[0], image.shape[1]))]
48
+ img_y = [int(max(1, ul[1])), int(min(br[1], image.shape[0]))]
49
+ assert (g_x[0] > 0 and g_y[1] > 0)
50
+ image[img_y[0] - 1:img_y[1], img_x[0] - 1:img_x[1]
51
+ ] = image[img_y[0] - 1:img_y[1], img_x[0] - 1:img_x[1]] + g[g_y[0] - 1:g_y[1], g_x[0] - 1:g_x[1]]
52
+ image[image > 1] = 1
53
+ return image
54
+
55
+
56
+ def transform(point, center, scale, resolution, invert=False):
57
+ """Generate and affine transformation matrix.
58
+
59
+ Given a set of points, a center, a scale and a target resolution, the
60
+ function generates an affine transformation matrix. If invert is ``True``
61
+ it will produce the inverse transformation.
62
+
63
+ Arguments:
64
+ point {torch.tensor} -- the input 2D point
65
+ center {torch.tensor or numpy.array} -- the center around which to perform the transformations
66
+ scale {float} -- the scale of the face/object
67
+ resolution {float} -- the output resolution
68
+
69
+ Keyword Arguments:
70
+ invert {bool} -- define whether the function should produce the direct or the
71
+ inverse transformation matrix (default: {False})
72
+ """
73
+ _pt = torch.ones(3)
74
+ _pt[0] = point[0]
75
+ _pt[1] = point[1]
76
+
77
+ h = 200.0 * scale
78
+ t = torch.eye(3)
79
+ t[0, 0] = resolution / h
80
+ t[1, 1] = resolution / h
81
+ t[0, 2] = resolution * (-center[0] / h + 0.5)
82
+ t[1, 2] = resolution * (-center[1] / h + 0.5)
83
+
84
+ if invert:
85
+ t = torch.inverse(t)
86
+
87
+ new_point = (torch.matmul(t, _pt))[0:2]
88
+
89
+ return new_point.int()
90
+
91
+
92
+ def crop(image, center, scale, resolution=256.0):
93
+ """Center crops an image or set of heatmaps
94
+
95
+ Arguments:
96
+ image {numpy.array} -- an rgb image
97
+ center {numpy.array} -- the center of the object, usually the same as of the bounding box
98
+ scale {float} -- scale of the face
99
+
100
+ Keyword Arguments:
101
+ resolution {float} -- the size of the output cropped image (default: {256.0})
102
+
103
+ Returns:
104
+ numpy.array -- the cropped image, resized to (resolution, resolution)
105
+ """
106
+ # Crop around the center point; the input is expected to be an np.ndarray
107
+ ul = transform([1, 1], center, scale, resolution, True)
108
+ br = transform([resolution, resolution], center, scale, resolution, True)
109
+ # pad = math.ceil(torch.norm((ul - br).float()) / 2.0 - (br[0] - ul[0]) / 2.0)
110
+ if image.ndim > 2:
111
+ newDim = np.array([br[1] - ul[1], br[0] - ul[0],
112
+ image.shape[2]], dtype=np.int32)
113
+ newImg = np.zeros(newDim, dtype=np.uint8)
114
+ else:
115
+ newDim = np.array([br[1] - ul[1], br[0] - ul[0]], dtype=np.int32)
116
+ newImg = np.zeros(newDim, dtype=np.uint8)
117
+ ht = image.shape[0]
118
+ wd = image.shape[1]
119
+ newX = np.array(
120
+ [max(1, -ul[0] + 1), min(br[0], wd) - ul[0]], dtype=np.int32)
121
+ newY = np.array(
122
+ [max(1, -ul[1] + 1), min(br[1], ht) - ul[1]], dtype=np.int32)
123
+ oldX = np.array([max(1, ul[0] + 1), min(br[0], wd)], dtype=np.int32)
124
+ oldY = np.array([max(1, ul[1] + 1), min(br[1], ht)], dtype=np.int32)
125
+ newImg[newY[0] - 1:newY[1], newX[0] - 1:newX[1]
126
+ ] = image[oldY[0] - 1:oldY[1], oldX[0] - 1:oldX[1], :]
127
+ newImg = cv2.resize(newImg, dsize=(int(resolution), int(resolution)),
128
+ interpolation=cv2.INTER_LINEAR)
129
+ return newImg
130
+
131
+
132
+ def get_preds_fromhm(hm, center=None, scale=None):
133
+ """Obtain (x,y) coordinates given a set of N heatmaps. If the center
134
+ and the scale are provided, the function will also return the points in
135
+ the original coordinate frame.
136
+
137
+ Arguments:
138
+ hm {torch.tensor} -- the predicted heatmaps, of shape [B, N, W, H]
139
+
140
+ Keyword Arguments:
141
+ center {torch.tensor} -- the center of the bounding box (default: {None})
142
+ scale {float} -- face scale (default: {None})
143
+ """
144
+ max, idx = torch.max(
145
+ hm.view(hm.size(0), hm.size(1), hm.size(2) * hm.size(3)), 2)
146
+ idx += 1
147
+ preds = idx.view(idx.size(0), idx.size(1), 1).repeat(1, 1, 2).float()
148
+ preds[..., 0].apply_(lambda x: (x - 1) % hm.size(3) + 1)
149
+ preds[..., 1].add_(-1).div_(hm.size(2)).floor_().add_(1)
150
+
151
+ for i in range(preds.size(0)):
152
+ for j in range(preds.size(1)):
153
+ hm_ = hm[i, j, :]
154
+ pX, pY = int(preds[i, j, 0]) - 1, int(preds[i, j, 1]) - 1
155
+ if pX > 0 and pX < 63 and pY > 0 and pY < 63:
156
+ diff = torch.FloatTensor(
157
+ [hm_[pY, pX + 1] - hm_[pY, pX - 1],
158
+ hm_[pY + 1, pX] - hm_[pY - 1, pX]])
159
+ preds[i, j].add_(diff.sign_().mul_(.25))
160
+
161
+ preds.add_(-.5)
162
+
163
+ preds_orig = torch.zeros(preds.size())
164
+ if center is not None and scale is not None:
165
+ for i in range(hm.size(0)):
166
+ for j in range(hm.size(1)):
167
+ preds_orig[i, j] = transform(
168
+ preds[i, j], center, scale, hm.size(2), True)
169
+
170
+ return preds, preds_orig
171
+
172
+ def get_preds_fromhm_batch(hm, centers=None, scales=None):
173
+ """Obtain (x,y) coordinates given a set of N heatmaps. If the centers
174
+ and the scales are provided, the function will also return the points in
175
+ the original coordinate frame.
176
+
177
+ Arguments:
178
+ hm {torch.tensor} -- the predicted heatmaps, of shape [B, N, W, H]
179
+
180
+ Keyword Arguments:
181
+ centers {torch.tensor} -- the centers of the bounding box (default: {None})
182
+ scales {float} -- face scales (default: {None})
183
+ """
184
+ max, idx = torch.max(
185
+ hm.view(hm.size(0), hm.size(1), hm.size(2) * hm.size(3)), 2)
186
+ idx += 1
187
+ preds = idx.view(idx.size(0), idx.size(1), 1).repeat(1, 1, 2).float()
188
+ preds[..., 0].apply_(lambda x: (x - 1) % hm.size(3) + 1)
189
+ preds[..., 1].add_(-1).div_(hm.size(2)).floor_().add_(1)
190
+
191
+ for i in range(preds.size(0)):
192
+ for j in range(preds.size(1)):
193
+ hm_ = hm[i, j, :]
194
+ pX, pY = int(preds[i, j, 0]) - 1, int(preds[i, j, 1]) - 1
195
+ if pX > 0 and pX < 63 and pY > 0 and pY < 63:
196
+ diff = torch.FloatTensor(
197
+ [hm_[pY, pX + 1] - hm_[pY, pX - 1],
198
+ hm_[pY + 1, pX] - hm_[pY - 1, pX]])
199
+ preds[i, j].add_(diff.sign_().mul_(.25))
200
+
201
+ preds.add_(-.5)
202
+
203
+ preds_orig = torch.zeros(preds.size())
204
+ if centers is not None and scales is not None:
205
+ for i in range(hm.size(0)):
206
+ for j in range(hm.size(1)):
207
+ preds_orig[i, j] = transform(
208
+ preds[i, j], centers[i], scales[i], hm.size(2), True)
209
+
210
+ return preds, preds_orig
211
+
212
+ def shuffle_lr(parts, pairs=None):
213
+ """Shuffle the points left-right according to the axis of symmetry
214
+ of the object.
215
+
216
+ Arguments:
217
+ parts {torch.tensor} -- a 3D or 4D object containing the
218
+ heatmaps.
219
+
220
+ Keyword Arguments:
221
+ pairs {list of integers} -- [order of the flipped points] (default: {None})
222
+ """
223
+ if pairs is None:
224
+ pairs = [16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0,
225
+ 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 27, 28, 29, 30, 35,
226
+ 34, 33, 32, 31, 45, 44, 43, 42, 47, 46, 39, 38, 37, 36, 41,
227
+ 40, 54, 53, 52, 51, 50, 49, 48, 59, 58, 57, 56, 55, 64, 63,
228
+ 62, 61, 60, 67, 66, 65]
229
+ if parts.ndimension() == 3:
230
+ parts = parts[pairs, ...]
231
+ else:
232
+ parts = parts[:, pairs, ...]
233
+
234
+ return parts
235
+
236
+
237
+ def flip(tensor, is_label=False):
238
+ """Flip an image or a set of heatmaps left-right
239
+
240
+ Arguments:
241
+ tensor {numpy.array or torch.tensor} -- [the input image or heatmaps]
242
+
243
+ Keyword Arguments:
244
+ is_label {bool} -- [denote whether the input is an image or a set of heatmaps] (default: {False})
245
+ """
246
+ if not torch.is_tensor(tensor):
247
+ tensor = torch.from_numpy(tensor)
248
+
249
+ if is_label:
250
+ tensor = shuffle_lr(tensor).flip(tensor.ndimension() - 1)
251
+ else:
252
+ tensor = tensor.flip(tensor.ndimension() - 1)
253
+
254
+ return tensor
255
+
256
+ # From pyzolib/paths.py (https://bitbucket.org/pyzo/pyzolib/src/tip/paths.py)
257
+
258
+
259
+ def appdata_dir(appname=None, roaming=False):
260
+ """ appdata_dir(appname=None, roaming=False)
261
+
262
+ Get the path to the application directory, where applications are allowed
263
+ to write user specific files (e.g. configurations). For non-user specific
264
+ data, consider using common_appdata_dir().
265
+ If appname is given, a subdir is appended (and created if necessary).
266
+ If roaming is True, will prefer a roaming directory (Windows Vista/7).
267
+ """
268
+
269
+ # Define default user directory
270
+ userDir = os.getenv('FACEALIGNMENT_USERDIR', None)
271
+ if userDir is None:
272
+ userDir = os.path.expanduser('~')
273
+ if not os.path.isdir(userDir): # pragma: no cover
274
+ userDir = '/var/tmp' # issue #54
275
+
276
+ # Get system app data dir
277
+ path = None
278
+ if sys.platform.startswith('win'):
279
+ path1, path2 = os.getenv('LOCALAPPDATA'), os.getenv('APPDATA')
280
+ path = (path2 or path1) if roaming else (path1 or path2)
281
+ elif sys.platform.startswith('darwin'):
282
+ path = os.path.join(userDir, 'Library', 'Application Support')
283
+ # On Linux and as fallback
284
+ if not (path and os.path.isdir(path)):
285
+ path = userDir
286
+
287
+ # Maybe we should store things local to the executable (in case of a
288
+ # portable distro or a frozen application that wants to be portable)
289
+ prefix = sys.prefix
290
+ if getattr(sys, 'frozen', None):
291
+ prefix = os.path.abspath(os.path.dirname(sys.executable))
292
+ for reldir in ('settings', '../settings'):
293
+ localpath = os.path.abspath(os.path.join(prefix, reldir))
294
+ if os.path.isdir(localpath): # pragma: no cover
295
+ try:
296
+ open(os.path.join(localpath, 'test.write'), 'wb').close()
297
+ os.remove(os.path.join(localpath, 'test.write'))
298
+ except IOError:
299
+ pass # We cannot write in this directory
300
+ else:
301
+ path = localpath
302
+ break
303
+
304
+ # Get path specific for this app
305
+ if appname:
306
+ if path == userDir:
307
+ appname = '.' + appname.lstrip('.') # Make it a hidden directory
308
+ path = os.path.join(path, appname)
309
+ if not os.path.isdir(path): # pragma: no cover
310
+ os.mkdir(path)
311
+
312
+ # Done
313
+ return path
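A minimal usage sketch for the heatmap utilities above (not part of the commit); the heatmap tensor, center and scale values are illustrative assumptions.

import torch
from face_detection.utils import get_preds_fromhm

# Dummy heatmaps for 68 landmarks on a 64x64 grid, shaped [B, N, H, W] as in the docstring.
hm = torch.rand(1, 68, 64, 64)
center = torch.tensor([128.0, 128.0])   # assumed face center in the original image
scale = 1.8                             # assumed face scale (see transform())
preds, preds_orig = get_preds_fromhm(hm, center, scale)
print(preds.shape, preds_orig.shape)    # both torch.Size([1, 68, 2])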
inference.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import subprocess
4
+ from pathlib import Path
5
+
6
+ from config import hparams as hp
7
+ from nota_wav2lip import Wav2LipModelComparisonDemo
8
+
9
+ LRS_ORIGINAL_URL = os.getenv('LRS_ORIGINAL_URL', None)
10
+ LRS_COMPRESSED_URL = os.getenv('LRS_COMPRESSED_URL', None)
11
+
12
+ if not Path(hp.inference.model.wav2lip.checkpoint).exists() and LRS_ORIGINAL_URL is not None:
13
+ subprocess.call(f"wget --no-check-certificate -O {hp.inference.model.wav2lip.checkpoint} {LRS_ORIGINAL_URL}", shell=True)
14
+ if not Path(hp.inference.model.nota_wav2lip.checkpoint).exists() and LRS_COMPRESSED_URL is not None:
15
+ subprocess.call(f"wget --no-check-certificate -O {hp.inference.model.nota_wav2lip.checkpoint} {LRS_COMPRESSED_URL}", shell=True)
16
+
17
+ def parse_args():
18
+
19
+ parser = argparse.ArgumentParser(description="NotaWav2Lip: Inference snippet for your own video and audio pair")
20
+
21
+ parser.add_argument(
22
+ '-a',
23
+ '--audio-input',
24
+ type=str,
25
+ required=True,
26
+ help="Path of the audio file"
27
+ )
28
+
29
+ parser.add_argument(
30
+ '-v',
31
+ '--video-frame-input',
32
+ type=str,
33
+ required=True,
34
+ help="Input directory with face image sequence. We recommend to extract the face image sequence with `preprocess.py`."
35
+ )
36
+
37
+ parser.add_argument(
38
+ '-b',
39
+ '--bbox-input',
40
+ type=str,
41
+ help="Path of the file with bbox coordinates. We recommend to extract the json file with `preprocess.py`."
42
+ "If None, it pretends that the json file is located at the same directory with face images: {VIDEO_FRAME_INPUT}.with_suffix('.json')."
43
+ )
44
+
45
+ parser.add_argument(
46
+ '-m',
47
+ '--model',
48
+ choices=['wav2lip', 'nota_wav2lip'],
49
+ default='nota_wav2lip',
51
+ help="Model for generating the talking video. Default: nota_wav2lip"
51
+ )
52
+
53
+ parser.add_argument(
54
+ '-o',
55
+ '--output-dir',
56
+ type=str,
57
+ default="result",
58
+ help="Output directory to save the result. Defaults: result"
59
+ )
60
+
61
+ parser.add_argument(
62
+ '-d',
63
+ '--device',
64
+ choices=['cpu', 'cuda'],
65
+ default='cpu',
66
+ help="Device setting for model inference. Defaults: cpu"
67
+ )
68
+
69
+ args = parser.parse_args()
70
+
71
+ return args
72
+
73
+ if __name__ == "__main__":
74
+ args = parse_args()
75
+ bbox_input = args.bbox_input if args.bbox_input is not None \
76
+ else Path(args.video_frame_input).with_suffix('.json')
77
+
78
+ servicer = Wav2LipModelComparisonDemo(device=args.device, result_dir=args.output_dir, model_list=args.model)
79
+ servicer.update_audio(args.audio_input, name='a0')
80
+ servicer.update_video(args.video_frame_input, bbox_input, name='v0')
81
+
82
+ servicer.save_as_video('a0', 'v0', args.model)
inference.sh ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Original Wav2Lip
2
+ python inference.py\
3
+ -a "sample_video_lrs3/sxnlvwprf_c-00007.wav"\
4
+ -v "sample_video_lrs3/Li4-1yyrsTI-00010"\
5
+ -m "wav2lip"\
6
+ -o "result_original"\
7
+ --device cpu
8
+
9
+ # Nota's Wav2Lip (28× Compressed)
10
+ python inference.py\
11
+ -a "sample_video_lrs3/sxnlvwprf_c-00007.wav"\
12
+ -v "sample_video_lrs3/Li4-1yyrsTI-00010"\
13
+ -m "nota_wav2lip"\
14
+ -o "result_nota"\
15
+ --device cpu
nota_wav2lip/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from nota_wav2lip.demo import Wav2LipModelComparisonDemo
2
+ from nota_wav2lip.gradio import Wav2LipModelComparisonGradio
nota_wav2lip/audio.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import librosa
2
+ import librosa.filters
3
+ import numpy as np
4
+ from scipy import signal
5
+ from scipy.io import wavfile
6
+
7
+ from config import hparams
8
+
9
+ hp = hparams.audio
10
+
11
+ def load_wav(path, sr):
12
+ return librosa.core.load(path, sr=sr)[0]
13
+
14
+ def save_wav(wav, path, sr):
15
+ wav *= 32767 / max(0.01, np.max(np.abs(wav)))
16
+ #proposed by @dsmiller
17
+ wavfile.write(path, sr, wav.astype(np.int16))
18
+
19
+ def save_wavenet_wav(wav, path, sr):
20
+ librosa.output.write_wav(path, wav, sr=sr)
21
+
22
+ def preemphasis(wav, k, preemphasize=True):
23
+ if preemphasize:
24
+ return signal.lfilter([1, -k], [1], wav)
25
+ return wav
26
+
27
+ def inv_preemphasis(wav, k, inv_preemphasize=True):
28
+ if inv_preemphasize:
29
+ return signal.lfilter([1], [1, -k], wav)
30
+ return wav
31
+
32
+ def get_hop_size():
33
+ hop_size = hp.hop_size
34
+ if hop_size is None:
35
+ assert hp.frame_shift_ms is not None
36
+ hop_size = int(hp.frame_shift_ms / 1000 * hp.sample_rate)
37
+ return hop_size
38
+
39
+ def linearspectrogram(wav):
40
+ D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize))
41
+ S = _amp_to_db(np.abs(D)) - hp.ref_level_db
42
+
43
+ if hp.signal_normalization:
44
+ return _normalize(S)
45
+ return S
46
+
47
+ def melspectrogram(wav):
48
+ D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize))
49
+ S = _amp_to_db(_linear_to_mel(np.abs(D))) - hp.ref_level_db
50
+
51
+ if hp.signal_normalization:
52
+ return _normalize(S)
53
+ return S
54
+
55
+ def _lws_processor():
56
+ import lws
57
+ return lws.lws(hp.n_fft, get_hop_size(), fftsize=hp.win_size, mode="speech")
58
+
59
+ def _stft(y):
60
+ if hp.use_lws:
61
+ return _lws_processor(hp).stft(y).T
62
+ else:
63
+ return librosa.stft(y=y, n_fft=hp.n_fft, hop_length=get_hop_size(), win_length=hp.win_size)
64
+
65
+ ##########################################################
66
+ #Those are only correct when using lws!!! (This was messing with Wavenet quality for a long time!)
67
+ def num_frames(length, fsize, fshift):
68
+ """Compute number of time frames of spectrogram
69
+ """
70
+ pad = (fsize - fshift)
71
+ M = (length + pad * 2 - fsize) // fshift + 1 if length % fshift == 0 else (length + pad * 2 - fsize) // fshift + 2
72
+ return M
73
+
74
+
75
+ def pad_lr(x, fsize, fshift):
76
+ """Compute left and right padding
77
+ """
78
+ M = num_frames(len(x), fsize, fshift)
79
+ pad = (fsize - fshift)
80
+ T = len(x) + 2 * pad
81
+ r = (M - 1) * fshift + fsize - T
82
+ return pad, pad + r
83
+ ##########################################################
84
+ #Librosa correct padding
85
+ def librosa_pad_lr(x, fsize, fshift):
86
+ return 0, (x.shape[0] // fshift + 1) * fshift - x.shape[0]
87
+
88
+ # Conversions
89
+ _mel_basis = None
90
+
91
+ def _linear_to_mel(spectogram):
92
+ global _mel_basis
93
+ if _mel_basis is None:
94
+ _mel_basis = _build_mel_basis()
95
+ return np.dot(_mel_basis, spectogram)
96
+
97
+ def _build_mel_basis():
98
+ assert hp.fmax <= hp.sample_rate // 2
99
+ return librosa.filters.mel(sr=hp.sample_rate, n_fft=hp.n_fft, n_mels=hp.num_mels,
100
+ fmin=hp.fmin, fmax=hp.fmax)
101
+
102
+ def _amp_to_db(x):
103
+ min_level = np.exp(hp.min_level_db / 20 * np.log(10))
104
+ return 20 * np.log10(np.maximum(min_level, x))
105
+
106
+ def _db_to_amp(x):
107
+ return np.power(10.0, (x) * 0.05)
108
+
109
+ def _normalize(S):
110
+ if hp.allow_clipping_in_normalization:
111
+ if hp.symmetric_mels:
112
+ return np.clip((2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value,
113
+ -hp.max_abs_value, hp.max_abs_value)
114
+ else:
115
+ return np.clip(hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db)), 0, hp.max_abs_value)
116
+
117
+ assert S.max() <= 0 and S.min() - hp.min_level_db >= 0
118
+ if hp.symmetric_mels:
119
+ return (2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value
120
+ else:
121
+ return hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db))
122
+
123
+ def _denormalize(D):
124
+ if hp.allow_clipping_in_normalization:
125
+ if hp.symmetric_mels:
126
+ return (((np.clip(D, -hp.max_abs_value,
127
+ hp.max_abs_value) + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value))
128
+ + hp.min_level_db)
129
+ else:
130
+ return ((np.clip(D, 0, hp.max_abs_value) * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db)
131
+
132
+ if hp.symmetric_mels:
133
+ return (((D + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value)) + hp.min_level_db)
134
+ else:
135
+ return ((D * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db)
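A short sketch of how these helpers are used downstream (not part of the commit); the wav path is an illustrative assumption and the hyperparameters come from `config`.

import nota_wav2lip.audio as audio
from config import hparams as hp

# Load a mono wav at the configured sample rate and compute the normalized mel spectrogram.
wav = audio.load_wav("sample.wav", hp.audio.sample_rate)   # hypothetical input file
mel = audio.melspectrogram(wav)                            # shape: (num_mels, num_frames)
print(mel.shape)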
nota_wav2lip/demo.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import platform
3
+ import subprocess
4
+ import time
5
+ from pathlib import Path
6
+ from typing import Dict, Iterator, List, Literal, Optional, Union
7
+
8
+ import cv2
9
+ import numpy as np
10
+
11
+ from config import hparams as hp
12
+ from nota_wav2lip.inference import Wav2LipInferenceImpl
13
+ from nota_wav2lip.util import FFMPEG_LOGGING_MODE
14
+ from nota_wav2lip.video import AudioSlicer, VideoSlicer
15
+
16
+
17
+ class Wav2LipModelComparisonDemo:
18
+ def __init__(self, device='cpu', result_dir='./temp', model_list: Optional[Union[str, List[str]]]=None):
19
+ if model_list is None:
20
+ model_list: List[str] = ['wav2lip', 'nota_wav2lip']
21
+ if isinstance(model_list, str) and len(model_list) != 0:
22
+ model_list: List[str] = [model_list]
23
+ super().__init__()
24
+ self.video_dict: Dict[str, VideoSlicer] = {}
25
+ self.audio_dict: Dict[str, AudioSlicer] = {}
26
+
27
+ self.model_zoo: Dict[str, Wav2LipInferenceImpl] = {}
28
+ for model_name in model_list:
29
+ assert model_name in hp.inference.model, f"{model_name} not in hp.inference_model: {hp.inference.model}"
30
+ self.model_zoo[model_name] = Wav2LipInferenceImpl(
31
+ model_name, hp_inference_model=hp.inference.model[model_name], device=device
32
+ )
33
+
34
+ self._params_zoo: Dict[str, str] = {
35
+ model_name: self.model_zoo[model_name].params for model_name in self.model_zoo
36
+ }
37
+
38
+ self.result_dir: Path = Path(result_dir)
39
+ self.result_dir.mkdir(exist_ok=True)
40
+
41
+ @property
42
+ def params(self):
43
+ return self._params_zoo
44
+
45
+ def _infer(
46
+ self,
47
+ audio_name: str,
48
+ video_name: str,
49
+ model_type: Literal['wav2lip', 'nota_wav2lip']
50
+ ) -> Iterator[np.ndarray]:
51
+ audio_iterable: AudioSlicer = self.audio_dict[audio_name]
52
+ video_iterable: VideoSlicer = self.video_dict[video_name]
53
+ target_model = self.model_zoo[model_type]
54
+ return target_model.inference_with_iterator(audio_iterable, video_iterable)
55
+
56
+ def update_audio(self, audio_path, name=None):
57
+ _name = name if name is not None else Path(audio_path).stem
58
+ self.audio_dict.update(
59
+ {_name: AudioSlicer(audio_path)}
60
+ )
61
+
62
+ def update_video(self, frame_dir_path, bbox_path, name=None):
63
+ _name = name if name is not None else Path(frame_dir_path).stem
64
+ self.video_dict.update(
65
+ {_name: VideoSlicer(frame_dir_path, bbox_path)}
66
+ )
67
+
68
+ def save_as_video(self, audio_name, video_name, model_type):
69
+
70
+ output_video_path = self.result_dir / 'generated_with_audio.mp4'
71
+ frame_only_video_path = self.result_dir / 'generated.mp4'
72
+ audio_path = self.audio_dict[audio_name].audio_path
73
+
74
+ out = cv2.VideoWriter(str(frame_only_video_path),
75
+ cv2.VideoWriter_fourcc(*'mp4v'),
76
+ hp.face.video_fps,
77
+ (hp.inference.frame.w, hp.inference.frame.h))
78
+ start = time.time()
79
+ for frame in self._infer(audio_name=audio_name, video_name=video_name, model_type=model_type):
80
+ out.write(frame)
81
+ inference_time = time.time() - start
82
+ out.release()
83
+
84
+ command = f"ffmpeg {FFMPEG_LOGGING_MODE['ERROR']} -y -i {audio_path} -i {frame_only_video_path} -strict -2 -q:v 1 {output_video_path}"
85
+ subprocess.call(command, shell=platform.system() != 'Windows')
86
+
87
+ # The number of frames of generated video
88
+ video_frames_num = len(self.audio_dict[audio_name])
89
+ inference_fps = video_frames_num / inference_time
90
+
91
+ return output_video_path, inference_time, inference_fps
nota_wav2lip/gradio.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import threading
2
+ from pathlib import Path
3
+
4
+ from nota_wav2lip.demo import Wav2LipModelComparisonDemo
5
+
6
+
7
+ class Wav2LipModelComparisonGradio(Wav2LipModelComparisonDemo):
8
+ def __init__(
9
+ self,
10
+ device='cpu',
11
+ result_dir='./temp',
12
+ video_label_dict=None,
13
+ audio_label_list=None,
14
+ default_video='v1',
15
+ default_audio='a1'
16
+ ) -> None:
17
+ if audio_label_list is None:
18
+ audio_label_list = {}
19
+ if video_label_dict is None:
20
+ video_label_dict = {}
21
+ super().__init__(device, result_dir)
22
+ self._video_label_dict = {k: Path(v).with_suffix('.mp4') for k, v in video_label_dict.items()}
23
+ self._audio_label_dict = audio_label_list
24
+ self._default_video = default_video
25
+ self._default_audio = default_audio
26
+
27
+ self._lock = threading.Lock() # lock for asserting that concurrency_count == 1
28
+
29
+ def _is_valid_input(self, video_selection, audio_selection):
30
+ assert video_selection in self._video_label_dict, \
31
+ f"Your input ({video_selection}) is not in {self._video_label_dict}!!!"
32
+ assert audio_selection in self._audio_label_dict, \
33
+ f"Your input ({audio_selection}) is not in {self._audio_label_dict}!!!"
34
+
35
+ def generate_original_model(self, video_selection, audio_selection):
36
+ try:
37
+ self._is_valid_input(video_selection, audio_selection)
38
+
39
+ with self._lock:
40
+ output_video_path, inference_time, inference_fps = \
41
+ self.save_as_video(audio_name=audio_selection,
42
+ video_name=video_selection,
43
+ model_type='wav2lip')
44
+
45
+ return str(output_video_path), format(inference_time, ".2f"), format(inference_fps, ".1f")
46
+ except KeyboardInterrupt:
47
+ exit()
48
+ except Exception as e:
49
+ print(e)
50
+ pass
51
+
52
+ def generate_compressed_model(self, video_selection, audio_selection):
53
+ try:
54
+ self._is_valid_input(video_selection, audio_selection)
55
+
56
+ with self._lock:
57
+ output_video_path, inference_time, inference_fps = \
58
+ self.save_as_video(audio_name=audio_selection,
59
+ video_name=video_selection,
60
+ model_type='nota_wav2lip')
61
+
62
+ return str(output_video_path), format(inference_time, ".2f"), format(inference_fps, ".1f")
63
+ except KeyboardInterrupt:
64
+ exit()
65
+ except Exception as e:
66
+ print(e)
67
+ pass
68
+
69
+ def switch_video_samples(self, video_selection):
70
+ try:
71
+ if video_selection not in self._video_label_dict:
72
+ return self._video_label_dict[self._default_video]
73
+ return self._video_label_dict[video_selection]
74
+
75
+ except KeyboardInterrupt:
76
+ exit()
77
+ except Exception as e:
78
+ print(e)
79
+ pass
80
+
81
+ def switch_audio_samples(self, audio_selection):
82
+ try:
83
+ if audio_selection not in self._audio_label_dict:
84
+ return self._audio_label_dict[self._default_audio]
85
+ return self._audio_label_dict[audio_selection]
86
+
87
+ except KeyboardInterrupt:
88
+ exit()
89
+ except Exception as e:
90
+ print(e)
91
+ pass
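A rough sketch of how this class could be wired into a Gradio Blocks UI (the actual wiring lives in `app.py`); the sample names, paths and downloaded checkpoints referenced in `config/nota_wav2lip.yaml` are assumptions.

import gradio as gr
from nota_wav2lip import Wav2LipModelComparisonGradio

# Sample ids ('v1', 'a1') must appear both in the label dicts (for validation/preview)
# and in the registered slicers (for inference).
servicer = Wav2LipModelComparisonGradio(
    device='cpu',
    video_label_dict={'v1': "sample_video_lrs3/Li4-1yyrsTI-00010"},
    audio_label_list={'a1': "sample_video_lrs3/sxnlvwprf_c-00007.wav"},
)
servicer.update_video("sample_video_lrs3/Li4-1yyrsTI-00010",
                      "sample_video_lrs3/Li4-1yyrsTI-00010.json", name='v1')
servicer.update_audio("sample_video_lrs3/sxnlvwprf_c-00007.wav", name='a1')

with gr.Blocks() as demo:
    video_choice = gr.Dropdown(choices=['v1'], value='v1', label="Input video")
    audio_choice = gr.Dropdown(choices=['a1'], value='a1', label="Input audio")
    result = gr.Video(label="Original Wav2Lip")
    time_box = gr.Textbox(label="Inference time (s)")
    fps_box = gr.Textbox(label="Inference FPS")
    gr.Button("Generate").click(
        fn=servicer.generate_original_model,
        inputs=[video_choice, audio_choice],
        outputs=[result, time_box, fps_box],
    )

demo.queue().launch()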
nota_wav2lip/inference.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Iterable, Iterator, List, Tuple
2
+
3
+ import cv2
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn as nn
7
+ from omegaconf import DictConfig
8
+ from tqdm import tqdm
9
+
10
+ from config import hparams as hp
11
+ from nota_wav2lip.models.util import count_params, load_model
12
+
13
+
14
+ class Wav2LipInferenceImpl:
15
+ def __init__(self, model_name: str, hp_inference_model: DictConfig, device='cpu'):
16
+ self.model: nn.Module = load_model(
17
+ model_name,
18
+ device=device,
19
+ **hp_inference_model
20
+ )
21
+ self.device = device
22
+ self._params: str = self._format_param(count_params(self.model))
23
+
24
+ @property
25
+ def params(self):
26
+ return self._params
27
+
28
+ @staticmethod
29
+ def _format_param(num_params: int) -> str:
30
+ params_in_million = num_params / 1e6
31
+ return f"{params_in_million:.1f}M"
32
+
33
+ @staticmethod
34
+ def _reset_batch() -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray], List[List[int]]]:
35
+ return [], [], [], []
36
+
37
+ def get_data_iterator(
38
+ self,
39
+ audio_iterable: Iterable[np.ndarray],
40
+ video_iterable: List[Tuple[np.ndarray, List[int]]]
41
+ ) -> Iterator[Tuple[np.ndarray, np.ndarray, np.ndarray, List[int]]]:
42
+ img_batch, mel_batch, frame_batch, coords_batch = self._reset_batch()
43
+
44
+ for i, m in enumerate(audio_iterable):
45
+ idx = i % len(video_iterable)
46
+ _frame_to_save, coords = video_iterable[idx]
47
+ frame_to_save = _frame_to_save.copy()
48
+ face = frame_to_save[coords[0]:coords[1], coords[2]:coords[3]].copy()
49
+
50
+ face: np.ndarray = cv2.resize(face, (hp.face.img_size, hp.face.img_size))
51
+
52
+ img_batch.append(face)
53
+ mel_batch.append(m)
54
+ frame_batch.append(frame_to_save)
55
+ coords_batch.append(coords)
56
+
57
+ if len(img_batch) >= hp.inference.batch_size:
58
+ img_batch = np.asarray(img_batch)
59
+ mel_batch = np.asarray(mel_batch)
60
+
61
+ img_masked = img_batch.copy()
62
+ img_masked[:, hp.face.img_size // 2:] = 0
63
+
64
+ img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
65
+ mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
66
+
67
+ yield img_batch, mel_batch, frame_batch, coords_batch
68
+ img_batch, mel_batch, frame_batch, coords_batch = self._reset_batch()
69
+
70
+ if len(img_batch) > 0:
71
+ img_batch = np.asarray(img_batch)
72
+ mel_batch = np.asarray(mel_batch)
73
+
74
+ img_masked = img_batch.copy()
75
+ img_masked[:, hp.face.img_size // 2:] = 0
76
+
77
+ img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
78
+ mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
79
+
80
+ yield img_batch, mel_batch, frame_batch, coords_batch
81
+
82
+ @torch.no_grad()
83
+ def inference_with_iterator(
84
+ self,
85
+ audio_iterable: Iterable[np.ndarray],
86
+ video_iterable: List[Tuple[np.ndarray, List[int]]]
87
+ ) -> Iterator[np.ndarray]:
88
+ data_iterator = self.get_data_iterator(audio_iterable, video_iterable)
89
+
90
+ for (img_batch, mel_batch, frames, coords) in \
91
+ tqdm(data_iterator, total=int(np.ceil(float(len(audio_iterable)) / hp.inference.batch_size))):
92
+
93
+ img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(self.device)
94
+ mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(self.device)
95
+
96
+ preds: torch.Tensor = self.forward(mel_batch, img_batch)
97
+
98
+ preds = preds.cpu().numpy().transpose(0, 2, 3, 1) * 255.
99
+ for pred, frame, coord in zip(preds, frames, coords):
100
+ y1, y2, x1, x2 = coord
101
+ pred = cv2.resize(pred.astype(np.uint8), (x2 - x1, y2 - y1))
102
+
103
+ frame[y1:y2, x1:x2] = pred
104
+ yield frame
105
+
106
+ @torch.no_grad()
107
+ def forward(self, audio_sequences: torch.Tensor, face_sequences: torch.Tensor) -> torch.Tensor:
108
+ return self.model(audio_sequences, face_sequences)
109
+
110
+ def __call__(self, *args, **kwargs):
111
+ return self.forward(*args, **kwargs)
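A minimal sketch of frame-by-frame generation with the iterator API (not part of the commit); the checkpoint referenced in `config/nota_wav2lip.yaml` and the sample paths are assumptions.

from config import hparams as hp
from nota_wav2lip.inference import Wav2LipInferenceImpl
from nota_wav2lip.video import AudioSlicer, VideoSlicer

model = Wav2LipInferenceImpl('nota_wav2lip',
                             hp_inference_model=hp.inference.model['nota_wav2lip'],
                             device='cpu')
audio_iter = AudioSlicer("sample_video_lrs3/sxnlvwprf_c-00007.wav")      # assumed wav
video_iter = VideoSlicer("sample_video_lrs3/Li4-1yyrsTI-00010",          # assumed frame dir
                         "sample_video_lrs3/Li4-1yyrsTI-00010.json")     # assumed bbox json

for frame in model.inference_with_iterator(audio_iter, video_iter):
    pass  # each `frame` is a full-resolution BGR uint8 image with the mouth region replaced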
nota_wav2lip/models/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .base import Wav2LipBase
2
+ from .wav2lip import Wav2Lip
3
+ from .wav2lip_compressed import NotaWav2Lip
nota_wav2lip/models/base.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import final
2
+
3
+ import torch
4
+ from torch import nn
5
+
6
+
7
+ class Wav2LipBase(nn.Module):
8
+ def __init__(self) -> None:
9
+ super().__init__()
10
+
11
+ self.audio_encoder = nn.Sequential()
12
+ self.face_encoder_blocks = nn.ModuleList([])
13
+ self.face_decoder_blocks = nn.ModuleList([])
14
+ self.output_block = nn.Sequential()
15
+
16
+ @final
17
+ def forward(self, audio_sequences, face_sequences):
18
+ # audio_sequences = (B, T, 1, 80, 16)
19
+ B = audio_sequences.size(0)
20
+
21
+ input_dim_size = len(face_sequences.size())
22
+ if input_dim_size > 4:
23
+ audio_sequences = torch.cat([audio_sequences[:, i] for i in range(audio_sequences.size(1))], dim=0)
24
+ face_sequences = torch.cat([face_sequences[:, :, i] for i in range(face_sequences.size(2))], dim=0)
25
+
26
+ audio_embedding = self.audio_encoder(audio_sequences) # B, 512, 1, 1
27
+
28
+ feats = []
29
+ x = face_sequences
30
+ for f in self.face_encoder_blocks:
31
+ x = f(x)
32
+ feats.append(x)
33
+
34
+ x = audio_embedding
35
+ for f in self.face_decoder_blocks:
36
+ x = f(x)
37
+ try:
38
+ x = torch.cat((x, feats[-1]), dim=1)
39
+ except Exception as e:
40
+ print(x.size())
41
+ print(feats[-1].size())
42
+ raise e
43
+
44
+ feats.pop()
45
+
46
+ x = self.output_block(x)
47
+
48
+ if input_dim_size > 4:
49
+ x = torch.split(x, B, dim=0) # [(B, C, H, W)]
50
+ outputs = torch.stack(x, dim=2) # (B, C, T, H, W)
51
+
52
+ else:
53
+ outputs = x
54
+
55
+ return outputs
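A quick shape sketch for the shared forward pass (not part of the commit); batch and chunk counts are illustrative.

import torch
from nota_wav2lip.models import Wav2Lip

# With 5D inputs, the T mel chunks and face crops are folded into the batch,
# decoded with skip connections, then unfolded back to (B, C, T, H, W).
model = Wav2Lip().eval()
audio_sequences = torch.randn(2, 5, 1, 80, 16)   # (B, T, 1, 80, 16) mel chunks
face_sequences = torch.randn(2, 6, 5, 96, 96)    # (B, 6, T, 96, 96): channel-concatenated masked and reference faces
with torch.no_grad():
    out = model(audio_sequences, face_sequences)
print(out.shape)                                  # torch.Size([2, 3, 5, 96, 96])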
nota_wav2lip/models/conv.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from torch.nn import functional as F
4
+
5
+
6
+ class Conv2d(nn.Module):
7
+ def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
8
+ super().__init__(*args, **kwargs)
9
+ self.conv_block = nn.Sequential(
10
+ nn.Conv2d(cin, cout, kernel_size, stride, padding),
11
+ nn.BatchNorm2d(cout)
12
+ )
13
+ self.act = nn.ReLU()
14
+ self.residual = residual
15
+
16
+ def forward(self, x):
17
+ out = self.conv_block(x)
18
+ if self.residual:
19
+ out += x
20
+ return self.act(out)
21
+
22
+
23
+ class Conv2dTranspose(nn.Module):
24
+ def __init__(self, cin, cout, kernel_size, stride, padding, output_padding=0, *args, **kwargs):
25
+ super().__init__(*args, **kwargs)
26
+ self.conv_block = nn.Sequential(
27
+ nn.ConvTranspose2d(cin, cout, kernel_size, stride, padding, output_padding),
28
+ nn.BatchNorm2d(cout)
29
+ )
30
+ self.act = nn.ReLU()
31
+
32
+ def forward(self, x):
33
+ out = self.conv_block(x)
34
+ return self.act(out)
nota_wav2lip/models/util.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Type
2
+
3
+ import torch
4
+
5
+ from nota_wav2lip.models import NotaWav2Lip, Wav2Lip, Wav2LipBase
6
+
7
+ MODEL_REGISTRY: Dict[str, Type[Wav2LipBase]] = {
8
+ 'wav2lip': Wav2Lip,
9
+ 'nota_wav2lip': NotaWav2Lip
10
+ }
11
+
12
+ def _load(checkpoint_path, device):
13
+ assert device in ['cpu', 'cuda']
14
+
15
+ print(f"Load checkpoint from: {checkpoint_path}")
16
+ if device == 'cuda':
17
+ return torch.load(checkpoint_path)
18
+ return torch.load(checkpoint_path, map_location=lambda storage, _: storage)
19
+
20
+ def load_model(model_name: str, device, checkpoint, **kwargs) -> Wav2LipBase:
21
+
22
+ cls = MODEL_REGISTRY[model_name.lower()]
23
+ assert issubclass(cls, Wav2LipBase)
24
+
25
+ model = cls(**kwargs)
26
+ checkpoint = _load(checkpoint, device)
27
+ model.load_state_dict(checkpoint)
28
+ model = model.to(device)
29
+ return model.eval()
30
+
31
+ def count_params(model):
32
+ return sum(p.numel() for p in model.parameters())
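A usage sketch for the registry helpers (not part of the commit); the checkpoint filename is a placeholder for whatever `download.py` fetched.

from nota_wav2lip.models.util import count_params, load_model

model = load_model(
    'nota_wav2lip',
    device='cpu',
    checkpoint='checkpoints/lrs3-nota-wav2lip.pth',   # hypothetical checkpoint path
)
print(f"{count_params(model) / 1e6:.1f}M parameters")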
nota_wav2lip/models/wav2lip.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+
4
+ from nota_wav2lip.models.base import Wav2LipBase
5
+ from nota_wav2lip.models.conv import Conv2d, Conv2dTranspose
6
+
7
+
8
+ class Wav2Lip(Wav2LipBase):
9
+ def __init__(self):
10
+ super().__init__()
11
+
12
+ self.face_encoder_blocks = nn.ModuleList([
13
+ nn.Sequential(Conv2d(6, 16, kernel_size=7, stride=1, padding=3)), # 96,96
14
+
15
+ nn.Sequential(Conv2d(16, 32, kernel_size=3, stride=2, padding=1), # 48,48
16
+ Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
17
+ Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True)),
18
+
19
+ nn.Sequential(Conv2d(32, 64, kernel_size=3, stride=2, padding=1), # 24,24
20
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
21
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
22
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True)),
23
+
24
+ nn.Sequential(Conv2d(64, 128, kernel_size=3, stride=2, padding=1), # 12,12
25
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
26
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True)),
27
+
28
+ nn.Sequential(Conv2d(128, 256, kernel_size=3, stride=2, padding=1), # 6,6
29
+ Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
30
+ Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True)),
31
+
32
+ nn.Sequential(Conv2d(256, 512, kernel_size=3, stride=2, padding=1), # 3,3
33
+ Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),),
34
+
35
+ nn.Sequential(Conv2d(512, 512, kernel_size=3, stride=1, padding=0), # 1, 1
36
+ Conv2d(512, 512, kernel_size=1, stride=1, padding=0)),])
37
+
38
+ self.audio_encoder = nn.Sequential(
39
+ Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
40
+ Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
41
+ Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
42
+
43
+ Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
44
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
45
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
46
+
47
+ Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
48
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
49
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
50
+
51
+ Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
52
+ Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
53
+
54
+ Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
55
+ Conv2d(512, 512, kernel_size=1, stride=1, padding=0),)
56
+
57
+ self.face_decoder_blocks = nn.ModuleList([
58
+ nn.Sequential(Conv2d(512, 512, kernel_size=1, stride=1, padding=0),),
59
+
60
+ nn.Sequential(Conv2dTranspose(1024, 512, kernel_size=3, stride=1, padding=0), # 3,3
61
+ Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),),
62
+
63
+ nn.Sequential(Conv2dTranspose(1024, 512, kernel_size=3, stride=2, padding=1, output_padding=1),
64
+ Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),
65
+ Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),), # 6, 6
66
+
67
+ nn.Sequential(Conv2dTranspose(768, 384, kernel_size=3, stride=2, padding=1, output_padding=1),
68
+ Conv2d(384, 384, kernel_size=3, stride=1, padding=1, residual=True),
69
+ Conv2d(384, 384, kernel_size=3, stride=1, padding=1, residual=True),), # 12, 12
70
+
71
+ nn.Sequential(Conv2dTranspose(512, 256, kernel_size=3, stride=2, padding=1, output_padding=1),
72
+ Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
73
+ Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),), # 24, 24
74
+
75
+ nn.Sequential(Conv2dTranspose(320, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
76
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
77
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),), # 48, 48
78
+
79
+ nn.Sequential(Conv2dTranspose(160, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
80
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
81
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),),]) # 96,96
82
+
83
+ self.output_block = nn.Sequential(Conv2d(80, 32, kernel_size=3, stride=1, padding=1),
84
+ nn.Conv2d(32, 3, kernel_size=1, stride=1, padding=0),
85
+ nn.Sigmoid())
nota_wav2lip/models/wav2lip_compressed.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+
4
+ from nota_wav2lip.models.base import Wav2LipBase
5
+ from nota_wav2lip.models.conv import Conv2d, Conv2dTranspose
6
+
7
+
8
+ class NotaWav2Lip(Wav2LipBase):
9
+ def __init__(self, nef=4, naf=8, ndf=8, x_size=96):
10
+ super().__init__()
11
+
12
+ assert x_size in [96, 128]
13
+ self.ker_sz_last = x_size // 32
14
+
15
+ self.face_encoder_blocks = nn.ModuleList([
16
+ nn.Sequential(Conv2d(6, nef, kernel_size=7, stride=1, padding=3)), # 96,96
17
+
18
+ nn.Sequential(Conv2d(nef, nef * 2, kernel_size=3, stride=2, padding=1),), # 48,48
19
+
20
+ nn.Sequential(Conv2d(nef * 2, nef * 4, kernel_size=3, stride=2, padding=1),), # 24,24
21
+
22
+ nn.Sequential(Conv2d(nef * 4, nef * 8, kernel_size=3, stride=2, padding=1),), # 12,12
23
+
24
+ nn.Sequential(Conv2d(nef * 8, nef * 16, kernel_size=3, stride=2, padding=1),), # 6,6
25
+
26
+ nn.Sequential(Conv2d(nef * 16, nef * 32, kernel_size=3, stride=2, padding=1),), # 3,3
27
+
28
+ nn.Sequential(Conv2d(nef * 32, nef * 32, kernel_size=self.ker_sz_last, stride=1, padding=0), # 1, 1
29
+ Conv2d(nef * 32, nef * 32, kernel_size=1, stride=1, padding=0)), ])
30
+
31
+ self.audio_encoder = nn.Sequential(
32
+ Conv2d(1, naf, kernel_size=3, stride=1, padding=1),
33
+
34
+ Conv2d(naf, naf * 2, kernel_size=3, stride=(3, 1), padding=1),
35
+
36
+ Conv2d(naf * 2, naf * 4, kernel_size=3, stride=3, padding=1),
37
+
38
+ Conv2d(naf * 4, naf * 8, kernel_size=3, stride=(3, 2), padding=1),
39
+
40
+ Conv2d(naf * 8, naf * 16, kernel_size=3, stride=1, padding=0),
41
+ Conv2d(naf * 16, naf * 16, kernel_size=1, stride=1, padding=0), )
42
+
43
+ self.face_decoder_blocks = nn.ModuleList([
44
+ nn.Sequential(Conv2d(naf * 16, naf * 16, kernel_size=1, stride=1, padding=0), ),
45
+
46
+ nn.Sequential(Conv2dTranspose(nef * 32 + naf * 16, ndf * 16, kernel_size=self.ker_sz_last, stride=1, padding=0),),
47
+ # 3,3 # 512+512 = 1024
48
+
49
+ nn.Sequential(
50
+ Conv2dTranspose(nef * 32 + ndf * 16, ndf * 16, kernel_size=3, stride=2, padding=1, output_padding=1),), # 6, 6
51
+ # 512+512 = 1024
52
+
53
+ nn.Sequential(
54
+ Conv2dTranspose(nef * 16 + ndf * 16, ndf * 12, kernel_size=3, stride=2, padding=1, output_padding=1),), # 12, 12
55
+ # 256+512 = 768
56
+
57
+ nn.Sequential(
58
+ Conv2dTranspose(nef * 8 + ndf * 12, ndf * 8, kernel_size=3, stride=2, padding=1, output_padding=1),), # 24, 24
59
+ # 128+384 = 512
60
+
61
+ nn.Sequential(
62
+ Conv2dTranspose(nef * 4 + ndf * 8, ndf * 4, kernel_size=3, stride=2, padding=1, output_padding=1),), # 48, 48
63
+ # 64+256 = 320
64
+
65
+ nn.Sequential(
66
+ Conv2dTranspose(nef * 2 + ndf * 4, ndf * 2, kernel_size=3, stride=2, padding=1, output_padding=1),), # 96,96
67
+ # 32+128 = 160
68
+ ])
69
+
70
+ self.output_block = nn.Sequential(Conv2d(nef + ndf * 2, ndf, kernel_size=3, stride=1, padding=1), # 16+64 = 80
71
+ nn.Conv2d(ndf, 3, kernel_size=1, stride=1, padding=0),
72
+ nn.Sigmoid())
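A quick comparison of the two generators' parameter counts (not part of the commit); no checkpoints are needed for this check.

from nota_wav2lip.models import NotaWav2Lip, Wav2Lip
from nota_wav2lip.models.util import count_params

# Instantiate both architectures with their default widths and compare sizes.
for name, cls in [('wav2lip', Wav2Lip), ('nota_wav2lip', NotaWav2Lip)]:
    print(f"{name}: {count_params(cls()) / 1e6:.2f}M parameters")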
nota_wav2lip/preprocess/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from nota_wav2lip.preprocess.core import get_preprocessed_data
2
+ from nota_wav2lip.preprocess.lrs3_download import get_cropped_face_from_lrs3_label
nota_wav2lip/preprocess/core.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import platform
3
+ import subprocess
4
+ from pathlib import Path
5
+
6
+ import cv2
7
+ import numpy as np
8
+ from loguru import logger
9
+ from tqdm import tqdm
10
+
11
+ import face_detection
12
+ from nota_wav2lip.util import FFMPEG_LOGGING_MODE
13
+
14
+ detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D, flip_input=False, device='cpu')
15
+ PADDING = [0, 10, 0, 0]
16
+
17
+
18
+ def get_smoothened_boxes(boxes, T):
19
+ for i in range(len(boxes)):
20
+ window = boxes[len(boxes) - T:] if i + T > len(boxes) else boxes[i:i + T]
21
+ boxes[i] = np.mean(window, axis=0)
22
+ return boxes
23
+
24
+
25
+ def face_detect(images, pads, no_smooth=False, batch_size=1):
26
+
27
+ predictions = []
28
+ images_array = [cv2.imread(str(image)) for image in images]
29
+ for i in tqdm(range(0, len(images_array), batch_size)):
30
+ predictions.extend(detector.get_detections_for_batch(np.array(images_array[i:i + batch_size])))
31
+
32
+ results = []
33
+ pady1, pady2, padx1, padx2 = pads
34
+ for rect, image_array in zip(predictions, images_array):
35
+ if rect is None:
36
+ cv2.imwrite('temp/faulty_frame.jpg', image_array) # check this frame where the face was not detected.
37
+ raise ValueError('Face not detected! Ensure the video contains a face in all the frames.')
38
+
39
+ y1 = max(0, rect[1] - pady1)
40
+ y2 = min(image_array.shape[0], rect[3] + pady2)
41
+ x1 = max(0, rect[0] - padx1)
42
+ x2 = min(image_array.shape[1], rect[2] + padx2)
43
+ results.append([x1, y1, x2, y2])
44
+
45
+ boxes = np.array(results)
46
+ bbox_format = "(y1, y2, x1, x2)"
47
+ if not no_smooth:
48
+ boxes = get_smoothened_boxes(boxes, T=5)
49
+ outputs = {
50
+ 'bbox': {str(image_path): tuple(map(int, (y1, y2, x1, x2))) for image_path, (x1, y1, x2, y2) in zip(images, boxes)},
51
+ 'format': bbox_format
52
+ }
53
+ return outputs
54
+
55
+
56
+ def save_video_frame(video_path, output_dir=None):
57
+ video_path = Path(video_path)
58
+ output_dir = output_dir if output_dir is not None else video_path.with_suffix('')
59
+ output_dir.mkdir(exist_ok=True)
60
+ return subprocess.call(
61
+ f"ffmpeg {FFMPEG_LOGGING_MODE['ERROR']} -y -i {video_path} -r 25 -f image2 {output_dir}/%05d.jpg",
62
+ shell=platform.system() != 'Windows'
63
+ )
64
+
65
+
66
+ def save_audio_file(video_path, output_path=None):
67
+ video_path = Path(video_path)
68
+ output_path = output_path if output_path is not None else video_path.with_suffix('.wav')
69
+ subprocess.call(
70
+ f"ffmpeg {FFMPEG_LOGGING_MODE['ERROR']} -y -i {video_path} -vn -acodec pcm_s16le -ar 16000 -ac 1 {output_path}",
71
+ shell=platform.system() != 'Windows'
72
+ )
73
+
74
+
75
+ def save_bbox_file(video_path, bbox_dict, output_path=None):
76
+ video_path = Path(video_path)
77
+ output_path = output_path if output_path is not None else video_path.with_suffix('.json')
78
+
79
+ with open(output_path, 'w') as f:
80
+ json.dump(bbox_dict, f, indent=4)
81
+
82
+ def get_preprocessed_data(video_path: Path):
83
+ video_path = Path(video_path)
84
+
85
+ image_sequence_dir = video_path.with_suffix('')
86
+ audio_path = video_path.with_suffix('.wav')
87
+ face_bbox_json_path = video_path.with_suffix('.json')
88
+
89
+ logger.info(f"Save 25 FPS video frames as image files ... will be saved at {video_path}")
90
+ save_video_frame(video_path=video_path, output_dir=image_sequence_dir)
91
+
92
+ logger.info(f"Save the audio as wav file ... will be saved at {audio_path}")
93
+ save_audio_file(video_path=video_path, output_path=audio_path) # bonus
94
+
95
+ # Load images, extract bboxes and save the coords (to directly use as array indices)
96
+ logger.info(f"Extract face boxes and save the coords with json format ... will be saved at {face_bbox_json_path}")
97
+ results = face_detect(sorted(image_sequence_dir.glob("*.jpg")), pads=PADDING)
98
+ save_bbox_file(video_path, results, output_path=face_bbox_json_path)
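A one-call sketch of the preprocessing entry point (not part of the commit); the input path is illustrative.

from nota_wav2lip.preprocess import get_preprocessed_data

# Produces the three inputs that inference.py expects, next to the source video:
#   data/my_video/      25-FPS frames as 00001.jpg, 00002.jpg, ...
#   data/my_video.wav   16 kHz mono audio track
#   data/my_video.json  per-frame face bboxes in (y1, y2, x1, x2) format
get_preprocessed_data("data/my_video.mp4")   # hypothetical video path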
nota_wav2lip/preprocess/ffmpeg.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ FFMPEG_LOGGING_MODE = {
2
+ 'DEBUG': "",
3
+ 'INFO': "-v quiet -stats",
4
+ 'ERROR': "-hide_banner -loglevel error",
5
+ }
nota_wav2lip/preprocess/lrs3_download.py ADDED
@@ -0,0 +1,259 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import platform
2
+ import subprocess
3
+ from pathlib import Path
4
+ from typing import Dict, List, Tuple, TypedDict, Union
5
+
6
+ import cv2
7
+ import numpy as np
8
+ import yt_dlp
9
+ from loguru import logger
10
+ from tqdm import tqdm
11
+
12
+ from nota_wav2lip.util import FFMPEG_LOGGING_MODE
13
+
14
+
15
+ class LabelInfo(TypedDict):
16
+ text: str
17
+ conf: int
18
+ url: str
19
+ bbox_xywhn: Dict[int, Tuple[float, float, float, float]]
20
+
21
+ def frame_to_time(frame_id: int, fps=25) -> str:
22
+ seconds = frame_id / fps
23
+
24
+ hours = int(seconds // 3600)
25
+ seconds -= 3600 * hours
26
+
27
+ minutes = int(seconds // 60)
28
+ seconds -= 60 * minutes
29
+
30
+ seconds_int = int(seconds)
31
+ seconds_milli = int((seconds - int(seconds)) * 1e3)
32
+
33
+ return f"{hours:02d}:{minutes:02d}:{seconds_int:02d}.{seconds_milli:03d}" # HH:MM:SS.mmm
34
+
35
+ def save_audio_file(input_path, start_frame_id, to_frame_id, output_path=None):
36
+ input_path = Path(input_path)
37
+ output_path = output_path if output_path is not None else input_path.with_suffix('.wav')
38
+
39
+ ss = frame_to_time(start_frame_id)
40
+ to = frame_to_time(to_frame_id)
41
+ subprocess.call(
42
+ f"ffmpeg {FFMPEG_LOGGING_MODE['ERROR']} -y -i {input_path} -vn -acodec pcm_s16le -ss {ss} -to {to} -ar 16000 -ac 1 {output_path}",
43
+ shell=platform.system() != 'Windows'
44
+ )
45
+
46
+ def merge_video_audio(video_path, audio_path, output_path):
47
+ subprocess.call(
48
+ f"ffmpeg {FFMPEG_LOGGING_MODE['ERROR']} -y -i {video_path} -i {audio_path} -strict experimental {output_path}",
49
+ shell=platform.system() != 'Windows'
50
+ )
51
+
52
+ def parse_lrs3_label(label_path) -> LabelInfo:
53
+ label_text = Path(label_path).read_text()
54
+ label_splitted = label_text.split('\n')
55
+
56
+ # Label validation
57
+ assert label_splitted[0].startswith("Text:")
58
+ assert label_splitted[1].startswith("Conf:")
59
+ assert label_splitted[2].startswith("Ref:")
60
+ assert label_splitted[4].startswith("FRAME")
61
+
62
+ label_info = LabelInfo(bbox_xywhn={})
63
+ label_info['text'] = label_splitted[0][len("Text: "):].strip()
64
+ label_info['conf'] = int(label_splitted[1][len("Conf: "):])
65
+ label_info['url'] = label_splitted[2][len("Ref: "):].strip()
66
+
67
+ for label_line in label_splitted[5:]:
68
+ bbox_splitted = [x.strip() for x in label_line.split('\t')]
69
+ if len(bbox_splitted) != 5:
70
+ continue
71
+ frame_index = int(bbox_splitted[0])
72
+ bbox_xywhn = tuple(map(float, bbox_splitted[1:]))
73
+ label_info['bbox_xywhn'][frame_index] = bbox_xywhn
74
+
75
+ return label_info
76
+
77
+ def _get_cropped_bbox(bbox_info_xywhn, original_width, original_height):
78
+
79
+ bbox_info = bbox_info_xywhn
80
+ x = bbox_info[0] * original_width
81
+ y = bbox_info[1] * original_height
82
+ w = bbox_info[2] * original_width
83
+ h = bbox_info[3] * original_height
84
+
85
+ x_min = max(0, int(x - 0.5 * w))
86
+ y_min = max(0, int(y))
87
+ x_max = min(original_width, int(x + 1.5 * w))
88
+ y_max = min(original_height, int(y + 1.5 * h))
89
+
90
+ cropped_width = x_max - x_min
91
+ cropped_height = y_max - y_min
92
+
93
+ if cropped_height > cropped_width:
94
+ offset = cropped_height - cropped_width
95
+ offset_low = min(x_min, offset // 2)
96
+ offset_high = min(offset - offset_low, original_width - x_max)
97
+ x_min -= offset_low
98
+ x_max += offset_high
99
+ else:
100
+ offset = cropped_width - cropped_height
101
+ offset_low = min(y_min, offset // 2)
102
+ offset_high = min(offset - offset_low, original_height - y_max)
103
+ y_min -= offset_low
104
+ y_max += offset_high
105
+
106
+ return x_min, y_min, x_max, y_max
107
+
108
+ def _get_smoothened_boxes(bbox_dict, bbox_smoothen_window):
109
+ boxes = [np.array(bbox_dict[frame_id]) for frame_id in sorted(bbox_dict)]
110
+ for i in range(len(boxes)):
111
+ window = boxes[len(boxes) - bbox_smoothen_window:] if i + bbox_smoothen_window > len(boxes) else boxes[i:i + bbox_smoothen_window]
112
+ boxes[i] = np.mean(window, axis=0)
113
+
114
+ for idx, frame_id in enumerate(sorted(bbox_dict)):
115
+ bbox_dict[frame_id] = (np.rint(boxes[idx])).astype(int).tolist()
116
+ return bbox_dict
117
+
118
+ def download_video_from_youtube(youtube_ref, output_path):
119
+ ydl_url = f"https://www.youtube.com/watch?v={youtube_ref}"
120
+ ydl_opts = {
121
+ 'format': 'bestvideo[ext=mp4][height<=720]+bestaudio[ext=m4a]/best[ext=mp4][height<=720]',
122
+ 'outtmpl': str(output_path),
123
+ }
124
+
125
+ with yt_dlp.YoutubeDL(ydl_opts) as ydl:
126
+ ydl.download([ydl_url])
127
+
128
+ def resample_video(input_path, output_path):
129
+ subprocess.call(
130
+ f"ffmpeg {FFMPEG_LOGGING_MODE['INFO']} -y -i {input_path} -r 25 -preset veryfast {output_path}",
131
+ shell=platform.system() != 'Windows'
132
+ )
133
+
134
+ def _get_smoothen_xyxy_bbox(
135
+ label_bbox_xywhn: Dict[int, Tuple[float, float, float, float]],
136
+ original_width: int,
137
+ original_height: int,
138
+ bbox_smoothen_window: int = 5
139
+ ) -> Dict[int, Tuple[float, float, float, float]]:
140
+
141
+ label_bbox_xyxy: Dict[int, Tuple[float, float, float, float]] = {}
142
+ for frame_id in sorted(label_bbox_xywhn):
143
+ frame_bbox_xywhn = label_bbox_xywhn[frame_id]
144
+ bbox_xyxy = _get_cropped_bbox(frame_bbox_xywhn, original_width, original_height)
145
+ label_bbox_xyxy[frame_id] = bbox_xyxy
146
+
147
+ label_bbox_xyxy = _get_smoothened_boxes(label_bbox_xyxy, bbox_smoothen_window=bbox_smoothen_window)
148
+ return label_bbox_xyxy
149
+
150
+ def get_start_end_frame_id(
151
+ label_bbox_xywhn: Dict[int, Tuple[float, float, float, float]],
152
+ ) -> Tuple[int, int]:
153
+ frame_ids = list(label_bbox_xywhn.keys())
154
+ start_frame_id = min(frame_ids)
155
+ to_frame_id = max(frame_ids)
156
+ return start_frame_id, to_frame_id
157
+
158
+ def crop_video_with_bbox(
159
+ input_path,
160
+ label_bbox_xywhn: Dict[int, Tuple[float, float, float, float]],
161
+ start_frame_id,
162
+ to_frame_id,
163
+ output_path,
164
+ bbox_smoothen_window = 5,
165
+ frame_width = 224,
166
+ frame_height = 224,
167
+ fps = 25,
168
+ interpolation = cv2.INTER_CUBIC,
169
+ ):
170
+ def frame_generator(cap):
171
+ if not cap.isOpened():
172
+ raise IOError("Error: Could not open video.")
173
+
174
+ while True:
175
+ ret, frame = cap.read()
176
+ if not ret:
177
+ break
178
+ yield frame
179
+
180
+ cap.release()
181
+
182
+ cap = cv2.VideoCapture(str(input_path))
183
+ original_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
184
+ original_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
185
+ label_bbox_xyxy = _get_smoothen_xyxy_bbox(label_bbox_xywhn, original_width, original_height, bbox_smoothen_window=bbox_smoothen_window)
186
+
187
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
188
+ out = cv2.VideoWriter(str(output_path), fourcc, fps, (frame_width, frame_height))
189
+
190
+ for frame_id, frame in tqdm(enumerate(frame_generator(cap))):
191
+ if start_frame_id <= frame_id <= to_frame_id:
192
+ x_min, y_min, x_max, y_max = label_bbox_xyxy[frame_id]
193
+
194
+ frame_cropped = frame[y_min:y_max, x_min:x_max]
195
+ frame_cropped = cv2.resize(frame_cropped, (frame_width, frame_height), interpolation=interpolation)
196
+ out.write(frame_cropped)
197
+
198
+ out.release()
199
+
200
+
201
+ def get_cropped_face_from_lrs3_label(
202
+ label_text_path: Union[Path, str],
203
+ video_root_dir: Union[Path, str],
204
+ bbox_smoothen_window: int = 5,
205
+ frame_width: int = 224,
206
+ frame_height: int = 224,
207
+ fps: int = 25,
208
+ interpolation = cv2.INTER_CUBIC,
209
+ ignore_cache: bool = False,
210
+ ):
211
+ label_text_path = Path(label_text_path)
212
+ label_info = parse_lrs3_label(label_text_path)
213
+ start_frame_id, to_frame_id = get_start_end_frame_id(label_info['bbox_xywhn'])
214
+
215
+ video_root_dir = Path(video_root_dir)
216
+ video_cache_dir = video_root_dir / ".cache"
217
+ video_cache_dir.mkdir(parents=True, exist_ok=True)
218
+
219
+ output_video: Path = video_cache_dir / f"{label_info['url']}.mp4"
220
+ output_resampled_video: Path = output_video.with_name(f"{output_video.stem}-25fps.mp4")
221
+ output_cropped_audio: Path = output_video.with_name(f"{output_video.stem}-{label_text_path.stem}-cropped.wav")
222
+ output_cropped_video: Path = output_video.with_name(f"{output_video.stem}-{label_text_path.stem}-cropped.mp4")
223
+ output_cropped_with_audio: Path = video_root_dir / output_video.with_name(f"{output_video.stem}-{label_text_path.stem}.mp4").name
224
+
225
+ if not output_video.exists() or ignore_cache:
226
+ youtube_ref = label_info['url']
227
+ logger.info(f"Download Youtube video(https://www.youtube.com/watch?v={youtube_ref}) ... will be saved at {output_video}")
228
+ download_video_from_youtube(youtube_ref, output_path=output_video)
229
+
230
+ if not output_resampled_video.exists() or ignore_cache:
231
+ logger.info(f"Resampling video to 25 FPS ... will be saved at {output_resampled_video}")
232
+ resample_video(input_path=output_video, output_path=output_resampled_video)
233
+
234
+ if not output_cropped_audio.exists() or ignore_cache:
235
+ logger.info(f"Cut audio file with the given timestamps ... will be saved at {output_cropped_audio}")
236
+ save_audio_file(
237
+ output_resampled_video,
238
+ start_frame_id=start_frame_id,
239
+ to_frame_id=to_frame_id,
240
+ output_path=output_cropped_audio
241
+ )
242
+
243
+ logger.info(f"Naive crop the face region with the given frame labels ... will be saved at {output_cropped_video}")
244
+ crop_video_with_bbox(
245
+ output_resampled_video,
246
+ label_info['bbox_xywhn'],
247
+ start_frame_id,
248
+ to_frame_id,
249
+ output_path=output_cropped_video,
250
+ bbox_smoothen_window=bbox_smoothen_window,
251
+ frame_width=frame_width,
252
+ frame_height=frame_height,
253
+ fps=fps,
254
+ interpolation=interpolation
255
+ )
256
+
257
+ if not output_cropped_with_audio.exists() or ignore_cache:
258
+ logger.info(f"Merge an audio track with the cropped face sequence ... will be saved at {output_cropped_with_audio}")
259
+ merge_video_audio(output_cropped_video, output_cropped_audio, output_cropped_with_audio)
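A usage sketch for rebuilding an LRS3 sample from its label file (not part of the commit); the label path and output directory are illustrative, and the source video is fetched from YouTube.

from nota_wav2lip.preprocess import get_cropped_face_from_lrs3_label

get_cropped_face_from_lrs3_label(
    "lrs3_v0.4/test/Li4-1yyrsTI/00010.txt",   # hypothetical LRS3 label file
    video_root_dir="sample_video_lrs3",
)
# -> sample_video_lrs3/<youtube-id>-00010.mp4 (224x224 face crop at 25 FPS, with audio)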
nota_wav2lip/util.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ FFMPEG_LOGGING_MODE = {
2
+ 'DEBUG': "",
3
+ 'INFO': "-v quiet -stats",
4
+ 'ERROR': "-hide_banner -loglevel error",
5
+ }
nota_wav2lip/video.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from pathlib import Path
3
+ from typing import List, Tuple, Union
4
+
5
+ import cv2
6
+ import numpy as np
7
+
8
+ import nota_wav2lip.audio as audio
9
+ from config import hparams as hp
10
+
11
+
12
+ class VideoSlicer:
13
+ def __init__(self, frame_dir: Union[Path, str], bbox_path: Union[Path, str]):
14
+ self.fps = hp.face.video_fps
15
+ self.frame_dir = frame_dir
16
+ self.frame_path_list = sorted(Path(self.frame_dir).glob("*.jpg"))
17
+ self.frame_array_list: List[np.ndarray] = [cv2.imread(str(image)) for image in self.frame_path_list]
18
+
19
+ with open(bbox_path, 'r') as f:
20
+ metadata = json.load(f)
21
+ self.bbox: List[List[int]] = [metadata['bbox'][key] for key in sorted(metadata['bbox'].keys())]
22
+ self.bbox_format = metadata['format']
23
+ assert len(self.bbox) == len(self.frame_array_list)
24
+
25
+ def __len__(self):
26
+ return len(self.frame_array_list)
27
+
28
+ def __getitem__(self, idx) -> Tuple[np.ndarray, List[int]]:
29
+ bbox = self.bbox[idx]
30
+ frame_original: np.ndarray = self.frame_array_list[idx]
31
+ # return frame_original[bbox[0]:bbox[1], bbox[2]:bbox[3], :]
32
+ return frame_original, bbox
33
+
34
+
35
+ class AudioSlicer:
36
+ def __init__(self, audio_path: Union[Path, str]):
37
+ self.fps = hp.face.video_fps
38
+ self.mel_chunks = self._audio_chunk_generator(audio_path)
39
+ self._audio_path = audio_path
40
+
41
+ @property
42
+ def audio_path(self):
43
+ return self._audio_path
44
+
45
+ def __len__(self):
46
+ return len(self.mel_chunks)
47
+
48
+ def _audio_chunk_generator(self, audio_path):
49
+ wav: np.ndarray = audio.load_wav(audio_path, hp.audio.sample_rate)
50
+ mel: np.ndarray = audio.melspectrogram(wav)
51
+
52
+ if np.isnan(mel.reshape(-1)).sum() > 0:
53
+ raise ValueError('Mel contains nan! Using a TTS voice? Add a small epsilon noise to the wav file and try again')
54
+
55
+ mel_chunks: List[np.ndarray] = []
56
+ mel_idx_multiplier = 80. / self.fps
57
+
58
+ i = 0
59
+ while True:
60
+ start_idx = int(i * mel_idx_multiplier)
61
+ if start_idx + hp.face.mel_step_size > len(mel[0]):
62
+ mel_chunks.append(mel[:, len(mel[0]) - hp.face.mel_step_size:])
63
+ return mel_chunks
64
+ mel_chunks.append(mel[:, start_idx: start_idx + hp.face.mel_step_size])
65
+ i += 1
66
+
67
+ def __getitem__(self, idx: int) -> np.ndarray:
68
+ return self.mel_chunks[idx]
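Finally, a small sketch of the two slicers (not part of the commit); the preprocessed paths are illustrative. Both index at the video frame rate, so frame `i` pairs with mel chunk `i`.

from nota_wav2lip.video import AudioSlicer, VideoSlicer

video = VideoSlicer("sample_video_lrs3/Li4-1yyrsTI-00010",
                    "sample_video_lrs3/Li4-1yyrsTI-00010.json")
audio = AudioSlicer("sample_video_lrs3/sxnlvwprf_c-00007.wav")

frame, bbox = video[0]    # full BGR frame and its (y1, y2, x1, x2) face box
mel_chunk = audio[0]      # (num_mels, mel_step_size) slice for the same timestep
print(frame.shape, bbox, mel_chunk.shape, len(audio))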