Jonny001 commited on
Commit
1a975b3
·
verified ·
1 Parent(s): 0ce7f46

Update roop/core.py

Browse files
Files changed (1) hide show
  1. roop/core.py +78 -74
roop/core.py CHANGED
@@ -2,17 +2,13 @@
2
 
3
  import os
4
  import sys
5
- # single thread doubles cuda performance - needs to be set before torch import
6
- if any(arg.startswith('--execution-provider') for arg in sys.argv):
7
- os.environ['OMP_NUM_THREADS'] = '1'
8
- # reduce tensorflow log level
9
- os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
10
- import warnings
11
- from typing import List
12
  import platform
13
  import signal
14
  import shutil
15
  import argparse
 
 
 
16
  import torch
17
  import onnxruntime
18
  import tensorflow
@@ -22,34 +18,43 @@ import roop.metadata
22
  import roop.ui as ui
23
  from roop.predicter import predict_image, predict_video
24
  from roop.processors.frame.core import get_frame_processors_modules
25
- from roop.utilities import has_image_extension, is_image, is_video, detect_fps, create_video, extract_frames, get_temp_frame_paths, restore_audio, create_temp, move_temp, clean_temp, normalize_output_path
 
 
 
26
 
27
- if 'ROCMExecutionProvider' in roop.globals.execution_providers:
28
- del torch
 
 
29
 
30
  warnings.filterwarnings('ignore', category=FutureWarning, module='insightface')
31
  warnings.filterwarnings('ignore', category=UserWarning, module='torchvision')
32
 
33
 
34
  def parse_args() -> None:
 
35
  signal.signal(signal.SIGINT, lambda signal_number, frame: destroy())
36
- program = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog, max_help_position=100))
37
- program.add_argument('-s', '--source', help='select an source image', dest='source_path')
38
- program.add_argument('-t', '--target', help='select an target image or video', dest='target_path')
39
- program.add_argument('-o', '--output', help='select output file or directory', dest='output_path')
40
- program.add_argument('--frame-processor', help='frame processors (choices: face_swapper, face_enhancer, ...)', dest='frame_processor', default=['face_swapper'], nargs='+')
41
- program.add_argument('--keep-fps', help='keep original fps', dest='keep_fps', action='store_true', default=False)
42
- program.add_argument('--keep-audio', help='keep original audio', dest='keep_audio', action='store_true', default=True)
43
- program.add_argument('--keep-frames', help='keep temporary frames', dest='keep_frames', action='store_true', default=False)
44
- program.add_argument('--many-faces', help='process every face', dest='many_faces', action='store_true', default=False)
45
- program.add_argument('--video-encoder', help='adjust output video encoder', dest='video_encoder', default='libx264', choices=['libx264', 'libx265', 'libvpx-vp9'])
46
- program.add_argument('--video-quality', help='adjust output video quality', dest='video_quality', type=int, default=18, choices=range(52), metavar='[0-51]')
47
- program.add_argument('--max-memory', help='maximum amount of RAM in GB', dest='max_memory', type=int, default=suggest_max_memory())
48
- program.add_argument('--execution-provider', help='available execution provider (choices: cpu, ...)', dest='execution_provider', default=['cpu'], choices=suggest_execution_providers(), nargs='+')
49
- program.add_argument('--execution-threads', help='number of execution threads', dest='execution_threads', type=int, default=suggest_execution_threads())
50
- program.add_argument('-v', '--version', action='version', version=f'{roop.metadata.name} {roop.metadata.version}')
51
-
52
- args = program.parse_args()
 
 
 
53
 
54
  roop.globals.source_path = args.source_path
55
  roop.globals.target_path = args.target_path
@@ -68,45 +73,50 @@ def parse_args() -> None:
68
 
69
 
70
  def encode_execution_providers(execution_providers: List[str]) -> List[str]:
71
- return [execution_provider.replace('ExecutionProvider', '').lower() for execution_provider in execution_providers]
 
72
 
73
 
74
  def decode_execution_providers(execution_providers: List[str]) -> List[str]:
75
- return [provider for provider, encoded_execution_provider in zip(onnxruntime.get_available_providers(), encode_execution_providers(onnxruntime.get_available_providers()))
76
- if any(execution_provider in encoded_execution_provider for execution_provider in execution_providers)]
 
77
 
78
 
79
  def suggest_max_memory() -> int:
 
80
  if platform.system().lower() == 'darwin':
81
  return 10
82
  return 14
83
 
84
 
85
  def suggest_execution_providers() -> List[str]:
 
86
  return encode_execution_providers(onnxruntime.get_available_providers())
87
 
88
 
89
  def suggest_execution_threads() -> int:
90
- if 'DmlExecutionProvider' in roop.globals.execution_providers:
91
- return 1
92
- if 'ROCMExecutionProvider' in roop.globals.execution_providers:
93
  return 1
94
  return 8
95
 
96
 
97
  def limit_resources() -> None:
98
- # prevent tensorflow memory leak
 
99
  gpus = tensorflow.config.experimental.list_physical_devices('GPU')
100
  for gpu in gpus:
101
  tensorflow.config.experimental.set_virtual_device_configuration(gpu, [
102
  tensorflow.config.experimental.VirtualDeviceConfiguration(memory_limit=1024)
103
  ])
104
- # limit memory usage
 
105
  if roop.globals.max_memory:
106
  memory = roop.globals.max_memory * 1024 ** 3
107
  if platform.system().lower() == 'darwin':
108
  memory = roop.globals.max_memory * 1024 ** 6
109
- if platform.system().lower() == 'windows':
110
  import ctypes
111
  kernel32 = ctypes.windll.kernel32
112
  kernel32.SetProcessWorkingSetSize(-1, ctypes.c_size_t(memory), ctypes.c_size_t(memory))
@@ -116,11 +126,13 @@ def limit_resources() -> None:
116
 
117
 
118
  def release_resources() -> None:
 
119
  if 'CUDAExecutionProvider' in roop.globals.execution_providers:
120
  torch.cuda.empty_cache()
121
 
122
 
123
  def pre_check() -> bool:
 
124
  if sys.version_info < (3, 9):
125
  update_status('Python version is not supported - please upgrade to 3.9 or higher.')
126
  return False
@@ -131,16 +143,19 @@ def pre_check() -> bool:
131
 
132
 
133
  def update_status(message: str, scope: str = 'ROOP.CORE') -> None:
 
134
  print(f'[{scope}] {message}')
135
  if not roop.globals.headless:
136
  ui.update_status(message)
137
 
138
 
139
  def start() -> None:
 
140
  for frame_processor in get_frame_processors_modules(roop.globals.frame_processors):
141
  if not frame_processor.pre_start():
142
  return
143
- # process image to image
 
144
  if has_image_extension(roop.globals.target_path):
145
  if predict_image(roop.globals.target_path):
146
  destroy()
@@ -150,66 +165,55 @@ def start() -> None:
150
  frame_processor.process_image(roop.globals.source_path, roop.globals.output_path, roop.globals.output_path)
151
  frame_processor.post_process()
152
  release_resources()
153
- if is_image(roop.globals.target_path):
154
- update_status('Processing to image succeed!')
155
- else:
156
- update_status('Processing to image failed!')
157
  return
158
- # process image to videos
 
159
  if predict_video(roop.globals.target_path):
160
  destroy()
 
161
  update_status('Creating temp resources...')
162
  create_temp(roop.globals.target_path)
163
  update_status('Extracting frames...')
164
  extract_frames(roop.globals.target_path)
165
  temp_frame_paths = get_temp_frame_paths(roop.globals.target_path)
 
166
  for frame_processor in get_frame_processors_modules(roop.globals.frame_processors):
167
  update_status('Progressing...', frame_processor.NAME)
168
  frame_processor.process_video(roop.globals.source_path, temp_frame_paths)
169
  frame_processor.post_process()
170
  release_resources()
171
- # handles fps
 
172
  if roop.globals.keep_fps:
173
- update_status('Detecting fps...')
174
  fps = detect_fps(roop.globals.target_path)
175
- update_status(f'Creating video with {fps} fps...')
176
  create_video(roop.globals.target_path, fps)
177
  else:
178
- update_status('Creating video with 30.0 fps...')
179
  create_video(roop.globals.target_path)
180
- # handle audio
 
181
  if roop.globals.keep_audio:
182
- if roop.globals.keep_fps:
183
- update_status('Restoring audio...')
184
- else:
185
- update_status('Restoring audio might cause issues as fps are not kept...')
186
- restore_audio(roop.globals.target_path, roop.globals.output_path)
187
- else:
188
- move_temp(roop.globals.target_path, roop.globals.output_path)
189
- # clean and validate
190
- clean_temp(roop.globals.target_path)
191
- if is_video(roop.globals.target_path):
192
- update_status('Processing to video succeed!')
193
- else:
194
- update_status('Processing to video failed!')
195
 
196
 
197
  def destroy() -> None:
198
- if roop.globals.target_path:
199
- clean_temp(roop.globals.target_path)
200
- quit()
 
201
 
202
 
203
- def run() -> None:
204
  parse_args()
205
- if not pre_check():
206
- return
207
- for frame_processor in get_frame_processors_modules(roop.globals.frame_processors):
208
- if not frame_processor.pre_check():
209
- return
210
- limit_resources()
211
- if roop.globals.headless:
212
  start()
213
- else:
214
- window = ui.init(start, destroy)
215
- window.mainloop()
 
2
 
3
  import os
4
  import sys
 
 
 
 
 
 
 
5
  import platform
6
  import signal
7
  import shutil
8
  import argparse
9
+ import warnings
10
+ from typing import List
11
+
12
  import torch
13
  import onnxruntime
14
  import tensorflow
 
18
  import roop.ui as ui
19
  from roop.predicter import predict_image, predict_video
20
  from roop.processors.frame.core import get_frame_processors_modules
21
+ from roop.utilities import (
22
+ has_image_extension, is_image, is_video, detect_fps, create_video, extract_frames,
23
+ get_temp_frame_paths, restore_audio, create_temp, move_temp, clean_temp, normalize_output_path
24
+ )
25
 
26
+ # Reduce TensorFlow log level and configure threading for torch
27
+ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
28
+ if any(arg.startswith('--execution-provider') for arg in sys.argv):
29
+ os.environ['OMP_NUM_THREADS'] = '1'
30
 
31
  warnings.filterwarnings('ignore', category=FutureWarning, module='insightface')
32
  warnings.filterwarnings('ignore', category=UserWarning, module='torchvision')
33
 
34
 
35
  def parse_args() -> None:
36
+ """Parse command-line arguments and configure global settings."""
37
  signal.signal(signal.SIGINT, lambda signal_number, frame: destroy())
38
+
39
+ parser = argparse.ArgumentParser(
40
+ formatter_class=lambda prog: argparse.HelpFormatter(prog, max_help_position=100)
41
+ )
42
+ parser.add_argument('-s', '--source', help='Path to the source image', dest='source_path')
43
+ parser.add_argument('-t', '--target', help='Path to the target image or video', dest='target_path')
44
+ parser.add_argument('-o', '--output', help='Path to the output file or directory', dest='output_path')
45
+ parser.add_argument('--frame-processor', help='Frame processors (choices: face_swapper, face_enhancer, ...)', dest='frame_processor', default=['face_swapper'], nargs='+')
46
+ parser.add_argument('--keep-fps', help='Keep original FPS', dest='keep_fps', action='store_true', default=False)
47
+ parser.add_argument('--keep-audio', help='Keep original audio', dest='keep_audio', action='store_true', default=True)
48
+ parser.add_argument('--keep-frames', help='Keep temporary frames', dest='keep_frames', action='store_true', default=False)
49
+ parser.add_argument('--many-faces', help='Process every face', dest='many_faces', action='store_true', default=False)
50
+ parser.add_argument('--video-encoder', help='Output video encoder', dest='video_encoder', default='libx264', choices=['libx264', 'libx265', 'libvpx-vp9'])
51
+ parser.add_argument('--video-quality', help='Output video quality', dest='video_quality', type=int, default=18, choices=range(52), metavar='[0-51]')
52
+ parser.add_argument('--max-memory', help='Maximum amount of RAM in GB', dest='max_memory', type=int, default=suggest_max_memory())
53
+ parser.add_argument('--execution-provider', help='Available execution provider (choices: cpu, ...)', dest='execution_provider', default=['cpu'], choices=suggest_execution_providers(), nargs='+')
54
+ parser.add_argument('--execution-threads', help='Number of execution threads', dest='execution_threads', type=int, default=suggest_execution_threads())
55
+ parser.add_argument('-v', '--version', action='version', version=f'{roop.metadata.name} {roop.metadata.version}')
56
+
57
+ args = parser.parse_args()
58
 
59
  roop.globals.source_path = args.source_path
60
  roop.globals.target_path = args.target_path
 
73
 
74
 
75
  def encode_execution_providers(execution_providers: List[str]) -> List[str]:
76
+ """Convert execution providers to their encoded form."""
77
+ return [provider.replace('ExecutionProvider', '').lower() for provider in execution_providers]
78
 
79
 
80
  def decode_execution_providers(execution_providers: List[str]) -> List[str]:
81
+ """Decode execution providers from their encoded form."""
82
+ return [provider for provider, encoded_provider in zip(onnxruntime.get_available_providers(), encode_execution_providers(onnxruntime.get_available_providers()))
83
+ if any(execution_provider in encoded_provider for execution_provider in execution_providers)]
84
 
85
 
86
  def suggest_max_memory() -> int:
87
+ """Suggest maximum memory in GB based on the operating system."""
88
  if platform.system().lower() == 'darwin':
89
  return 10
90
  return 14
91
 
92
 
93
  def suggest_execution_providers() -> List[str]:
94
+ """Suggest available execution providers based on ONNX Runtime."""
95
  return encode_execution_providers(onnxruntime.get_available_providers())
96
 
97
 
98
  def suggest_execution_threads() -> int:
99
+ """Suggest the number of execution threads based on execution providers."""
100
+ if 'DmlExecutionProvider' in roop.globals.execution_providers or 'ROCMExecutionProvider' in roop.globals.execution_providers:
 
101
  return 1
102
  return 8
103
 
104
 
105
  def limit_resources() -> None:
106
+ """Limit GPU and RAM resources based on configuration."""
107
+ # Prevent TensorFlow memory leak
108
  gpus = tensorflow.config.experimental.list_physical_devices('GPU')
109
  for gpu in gpus:
110
  tensorflow.config.experimental.set_virtual_device_configuration(gpu, [
111
  tensorflow.config.experimental.VirtualDeviceConfiguration(memory_limit=1024)
112
  ])
113
+
114
+ # Limit memory usage
115
  if roop.globals.max_memory:
116
  memory = roop.globals.max_memory * 1024 ** 3
117
  if platform.system().lower() == 'darwin':
118
  memory = roop.globals.max_memory * 1024 ** 6
119
+ elif platform.system().lower() == 'windows':
120
  import ctypes
121
  kernel32 = ctypes.windll.kernel32
122
  kernel32.SetProcessWorkingSetSize(-1, ctypes.c_size_t(memory), ctypes.c_size_t(memory))
 
126
 
127
 
128
  def release_resources() -> None:
129
+ """Release resources such as GPU cache."""
130
  if 'CUDAExecutionProvider' in roop.globals.execution_providers:
131
  torch.cuda.empty_cache()
132
 
133
 
134
  def pre_check() -> bool:
135
+ """Perform preliminary checks before starting the processing."""
136
  if sys.version_info < (3, 9):
137
  update_status('Python version is not supported - please upgrade to 3.9 or higher.')
138
  return False
 
143
 
144
 
145
  def update_status(message: str, scope: str = 'ROOP.CORE') -> None:
146
+ """Update status message to the console or UI."""
147
  print(f'[{scope}] {message}')
148
  if not roop.globals.headless:
149
  ui.update_status(message)
150
 
151
 
152
  def start() -> None:
153
+ """Start the processing based on the configuration and input."""
154
  for frame_processor in get_frame_processors_modules(roop.globals.frame_processors):
155
  if not frame_processor.pre_start():
156
  return
157
+
158
+ # Process image to image
159
  if has_image_extension(roop.globals.target_path):
160
  if predict_image(roop.globals.target_path):
161
  destroy()
 
165
  frame_processor.process_image(roop.globals.source_path, roop.globals.output_path, roop.globals.output_path)
166
  frame_processor.post_process()
167
  release_resources()
168
+ update_status('Processing to image succeeded!' if is_image(roop.globals.target_path) else 'Processing to image failed!')
 
 
 
169
  return
170
+
171
+ # Process image to video
172
  if predict_video(roop.globals.target_path):
173
  destroy()
174
+
175
  update_status('Creating temp resources...')
176
  create_temp(roop.globals.target_path)
177
  update_status('Extracting frames...')
178
  extract_frames(roop.globals.target_path)
179
  temp_frame_paths = get_temp_frame_paths(roop.globals.target_path)
180
+
181
  for frame_processor in get_frame_processors_modules(roop.globals.frame_processors):
182
  update_status('Progressing...', frame_processor.NAME)
183
  frame_processor.process_video(roop.globals.source_path, temp_frame_paths)
184
  frame_processor.post_process()
185
  release_resources()
186
+
187
+ # Handle FPS
188
  if roop.globals.keep_fps:
189
+ update_status('Detecting FPS...')
190
  fps = detect_fps(roop.globals.target_path)
191
+ update_status(f'Creating video with {fps} FPS...')
192
  create_video(roop.globals.target_path, fps)
193
  else:
194
+ update_status('Creating video with 30.0 FPS...')
195
  create_video(roop.globals.target_path)
196
+
197
+ # Handle audio
198
  if roop.globals.keep_audio:
199
+ update_status('Restoring audio...' if roop.globals.keep_fps else 'Restoring audio and creating final video...')
200
+ restore_audio(roop.globals.target_path)
201
+
202
+ move_temp(roop.globals.target_path)
203
+ clean_temp()
204
+ update_status('Processing succeeded!')
205
+ release_resources()
 
 
 
 
 
 
206
 
207
 
208
  def destroy() -> None:
209
+ """Cleanup and exit the program."""
210
+ update_status('Cleaning up and exiting...')
211
+ clean_temp()
212
+ sys.exit()
213
 
214
 
215
+ if __name__ == '__main__':
216
  parse_args()
217
+ if pre_check():
218
+ limit_resources()
 
 
 
 
 
219
  start()