File size: 6,709 Bytes
d0ffe9c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 |
import logging
import subprocess
from math import ceil
from pathlib import Path
from typing import Annotated, Optional
import typer
from animatediff import get_dir
from .ffmpeg import FfmpegEncoder, VideoCodec, codec_extn
from .ncnn import RifeNCNNOptions
# Module-level logger.
logger = logging.getLogger(__name__)

# Root of the bundled RIFE assets: model subdirectories are looked up under it,
# and the rife-ncnn-vulkan executable sits directly inside it.
rife_dir = get_dir("data/rife")
rife_ncnn_vulkan = rife_dir.joinpath("rife-ncnn-vulkan")

# Typer sub-application that hosts the RIFE command(s).
app: typer.Typer = typer.Typer(
    name="rife",
    context_settings={"help_option_names": ["-h", "--help"]},
    rich_markup_mode="rich",
    pretty_exceptions_show_locals=False,
    help="RIFE motion flow interpolation (MORE FPS!)",
)
def rife_interpolate(
    input_frames_dir: str,
    output_frames_dir: str,
    frame_multiplier: int = 2,
    rife_model: str = "rife-v4.6",
    spatial_tta: bool = False,
    temporal_tta: bool = False,
    uhd: bool = False,
):
    """Interpolate a directory of frames with the bundled rife-ncnn-vulkan binary.

    Runs rife-ncnn-vulkan on ``input_frames_dir``, writing interpolated frames
    to ``output_frames_dir``. It then shifts every numerically-named PNG in
    ``output_frames_dir`` down by one (``00000001.png`` -> ``00000000.png``).

    Args:
        input_frames_dir: Directory containing the source frames.
        output_frames_dir: Directory RIFE writes the interpolated frames into.
        frame_multiplier: Passed to ``RifeNCNNOptions.get_args`` and used to
            derive ``time_step = 1 / frame_multiplier``. Must be >= 1.
        rife_model: Model subdirectory of ``data/rife`` (must contain
            ``flownet.bin``).
        spatial_tta: Forwarded unchanged to ``RifeNCNNOptions``.
        temporal_tta: Forwarded unchanged to ``RifeNCNNOptions``.
        uhd: Forwarded unchanged to ``RifeNCNNOptions``.

    Raises:
        ValueError: If ``frame_multiplier`` is less than 1.
        FileNotFoundError: If the model directory has no ``flownet.bin``.
        RuntimeError: If rife-ncnn-vulkan exits non-zero; the message
            includes the process output.
    """
    if frame_multiplier < 1:
        # Guard the division below with a meaningful error instead of ZeroDivisionError.
        raise ValueError(f"frame_multiplier must be >= 1, got {frame_multiplier}")

    rife_model_dir = rife_dir.joinpath(rife_model)
    if not rife_model_dir.joinpath("flownet.bin").exists():
        raise FileNotFoundError(f"RIFE model dir {rife_model_dir} does not have a model in it!")

    rife_opts = RifeNCNNOptions(
        model_path=rife_model_dir,
        input_path=input_frames_dir,
        output_path=output_frames_dir,
        time_step=1 / frame_multiplier,
        spatial_tta=spatial_tta,
        temporal_tta=temporal_tta,
        uhd=uhd,
    )
    rife_args = rife_opts.get_args(frame_multiplier=frame_multiplier)

    # actually run RIFE
    logger.info("Running RIFE, this may take a little while...")
    # Merge stderr into stdout so there is exactly one pipe to drain. Reading only
    # stderr while stdout is an unread PIPE can deadlock once stdout's buffer fills.
    # The lines are kept so a failure reports what RIFE actually printed.
    rife_output: list[str] = []
    with subprocess.Popen(
        [rife_ncnn_vulkan, *rife_args], stdout=subprocess.PIPE, stderr=subprocess.STDOUT
    ) as proc:
        for raw_line in proc.stdout:
            # errors="replace": never let undecodable tool output abort the run.
            line = raw_line.decode("utf-8", errors="replace").strip()
            if line:
                logger.debug(line)
                rife_output.append(line)
        proc.wait()
    if proc.returncode != 0:
        raise RuntimeError(f"RIFE failed with code {proc.returncode}:\n" + "\n".join(rife_output))

    # Shift numerically-named output frames down by one so numbering starts one
    # lower (presumably RIFE numbers its output from 1; verify). Frames are sorted
    # numerically, not lexicographically, and renamed in ascending order. That way
    # each target name has already been vacated and no frame is overwritten.
    numbered_frames = sorted(
        (p for p in Path(output_frames_dir).glob("*.png") if p.stem.isdecimal()),
        key=lambda p: int(p.stem),
    )
    for frame in numbered_frames:
        frame.rename(frame.with_stem(f"{int(frame.stem) - 1:08d}"))
def _display_path(path: Path) -> Path:
    """Return ``path`` relative to the current working directory when possible.

    ``Path.relative_to`` raises ``ValueError`` when ``path`` is not located under
    the cwd (e.g. frames elsewhere on disk). In that case fall back to the
    absolute path rather than crashing after all the work has finished.
    """
    absolute = path.absolute()
    try:
        return absolute.relative_to(Path.cwd())
    except ValueError:
        return absolute


@app.command(no_args_is_help=True)
def interpolate(
    rife_model: Annotated[
        str,
        typer.Option("--rife-model", "-m", help="RIFE model to use (subdirectory of data/rife/)"),
    ] = "rife-v4.6",
    in_fps: Annotated[
        int,
        typer.Option(
            "--in-fps", "-I", min=1, help="Input frame FPS (8 for AnimateDiff)", show_default=True
        ),
    ] = 8,
    frame_multiplier: Annotated[
        int,
        typer.Option(
            "--frame-multiplier", "-M", min=1, help="Multiply total frame count by this", show_default=True
        ),
    ] = 8,
    out_fps: Annotated[
        int,
        typer.Option("--out-fps", "-F", min=1, help="Target FPS", show_default=True),
    ] = 50,
    codec: Annotated[
        VideoCodec,
        typer.Option("--codec", "-c", help="Output video codec", show_default=True),
    ] = VideoCodec.webm,
    lossless: Annotated[
        bool,
        typer.Option("--lossless", "-L", is_flag=True, help="Use lossless encoding (WebP only)"),
    ] = False,
    spatial_tta: Annotated[
        bool,
        typer.Option("--spatial-tta", "-x", is_flag=True, help="Enable RIFE Spatial TTA mode"),
    ] = False,
    temporal_tta: Annotated[
        bool,
        typer.Option("--temporal-tta", "-z", is_flag=True, help="Enable RIFE Temporal TTA mode"),
    ] = False,
    uhd: Annotated[
        bool,
        typer.Option("--uhd", "-u", is_flag=True, help="Enable RIFE UHD mode"),
    ] = False,
    frames_dir: Annotated[
        Path,
        typer.Argument(path_type=Path, file_okay=False, exists=True, help="Path to source frames directory"),
    ] = ...,
    out_file: Annotated[
        Optional[Path],
        typer.Argument(
            dir_okay=False,
            help="Path to output file (default: <frames_dir>-rife.<codec extension>, next to frames_dir)",
            show_default=False,
        ),
    ] = None,
):
    """Interpolate a directory of frames with RIFE, then encode the result with ffmpeg."""
    # Validate inputs before creating any output directories.
    rife_model_dir = rife_dir.joinpath(rife_model)
    if not rife_model_dir.joinpath("flownet.bin").exists():
        raise FileNotFoundError(f"RIFE model dir {rife_model_dir} does not have a model in it!")
    if not frames_dir.exists():
        raise FileNotFoundError(f"Frames directory {frames_dir} does not exist!")

    # where to put the RIFE interpolated frames (default: frames_dir/../<frames_dir>-rife)
    # TODO: make this configurable?
    rife_frames_dir = frames_dir.parent.joinpath(f"{frames_dir.name}-rife")
    rife_frames_dir.mkdir(exist_ok=True, parents=True)

    # Build the output file path. The default name is "<name>-rife.<extn>", i.e.
    # codec_extn() is used as a dot-less extension, but Path.suffix includes the
    # dot. Normalize to a dotted suffix so the mismatch check is meaningful and
    # with_suffix() (which rejects suffixes without a leading dot) doesn't raise.
    file_extn = codec_extn(codec)
    out_suffix = "." + file_extn.lstrip(".")
    if out_file is None:
        out_file = frames_dir.parent.joinpath(f"{frames_dir.name}-rife{out_suffix}")
    elif out_file.suffix != out_suffix:
        logger.warning("Output file extension does not match codec, changing extension")
        out_file = out_file.with_suffix(out_suffix)

    # build RIFE command and get args
    rife_opts = RifeNCNNOptions(
        model_path=rife_model_dir,
        input_path=frames_dir,
        output_path=rife_frames_dir,
        # NOTE(review): rife_interpolate() passes 1 / frame_multiplier here. The two
        # agree only when in_fps == frame_multiplier (both default to 8); confirm
        # which is intended.
        time_step=1 / in_fps,  # TODO: make this configurable?
        spatial_tta=spatial_tta,
        temporal_tta=temporal_tta,
        uhd=uhd,
    )
    rife_args = rife_opts.get_args(frame_multiplier=frame_multiplier)

    # actually run RIFE
    logger.info("Running RIFE, this may take a little while...")
    # Merge stderr into stdout so there is exactly one pipe to drain. Reading only
    # stderr while stdout is an unread PIPE can deadlock once stdout's buffer fills.
    # The lines are kept so a failure reports what RIFE actually printed.
    rife_output: list[str] = []
    with subprocess.Popen(
        [rife_ncnn_vulkan, *rife_args], stdout=subprocess.PIPE, stderr=subprocess.STDOUT
    ) as proc:
        for raw_line in proc.stdout:
            # errors="replace": never let undecodable tool output abort the run.
            line = raw_line.decode("utf-8", errors="replace").strip()
            if line:
                logger.debug(line)
                rife_output.append(line)
        proc.wait()
    if proc.returncode != 0:
        raise RuntimeError(f"RIFE failed with code {proc.returncode}:\n" + "\n".join(rife_output))

    # now it is ffmpeg time
    logger.info("Creating ffmpeg encoder...")
    encoder = FfmpegEncoder(
        frames_dir=rife_frames_dir,
        out_file=out_file,
        codec=codec,
        # Never feed frames faster than the requested output rate.
        in_fps=min(out_fps, in_fps * frame_multiplier),
        out_fps=out_fps,
        lossless=lossless,
    )
    logger.info("Encoding interpolated frames with ffmpeg...")
    result = encoder.encode()
    logger.debug(f"ffmpeg result: {result}")

    logger.info(f"Find the RIFE frames at: {_display_path(rife_frames_dir)}")
    logger.info(f"Find the output file at: {_display_path(out_file)}")
    logger.info("Done!")
|