import gradio as gr
import torch
import os
import base64
import uuid
import tempfile
import numpy as np
import cv2
import subprocess

from diffusers import AnimateDiffPipeline, MotionAdapter, EulerDiscreteScheduler
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from PIL import Image


SECRET_TOKEN = os.getenv('SECRET_TOKEN', 'default_secret')

# Constants
bases = {
    "ToonYou": "frankjoshua/toonyou_beta6",
    "epiCRealism": "emilianJR/epiCRealism"
}
step_loaded = None
base_loaded = "epiCRealism"
motion_loaded = None

# Ensure model and scheduler are initialized in GPU-enabled function
if not torch.cuda.is_available():
    raise NotImplementedError("No GPU detected!")

device = "cuda"
dtype = torch.float16
pipe = AnimateDiffPipeline.from_pretrained(bases[base_loaded], torch_dtype=dtype).to(device)
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", beta_schedule="linear")

# ----------------------------- VIDEO ENCODING ---------------------------------
# Unfortunately, the Hugging Face Diffusers utils hardcode MP4V as the codec,
# which is not supported by all browsers. This is a critical issue for AiTube,
# so we are forced to implement our own export function (using VP9/WebM instead).
# ------------------------------------------------------------------------------

def export_to_video_file(video_frames, output_video_path=None, fps=10):
    if output_video_path is None:
        output_video_path = tempfile.NamedTemporaryFile(suffix=".webm").name

    # Frames may arrive as float arrays in [0, 1] or as PIL images; normalize to uint8 RGB
    if isinstance(video_frames[0], np.ndarray):
        video_frames = [(frame * 255).astype(np.uint8) for frame in video_frames]
    elif isinstance(video_frames[0], Image.Image):
        video_frames = [np.array(frame) for frame in video_frames]

    # Use VP9 codec
    fourcc = cv2.VideoWriter_fourcc(*'VP90')
    h, w, c = video_frames[0].shape
    video_writer = cv2.VideoWriter(output_video_path, fourcc, fps, (w, h), True)

    for frame in video_frames:
        # Ensure the video frame is in the correct color format
        img = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
        video_writer.write(img)
    video_writer.release()

    return output_video_path
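
# A minimal sanity check for the export path (hypothetical usage; not run in production):
# it builds a few synthetic float frames in [0, 1], as the ndarray branch above expects,
# and writes them out as a short WebM clip.
#
# if __name__ == "__main__":
#     test_frames = [np.full((64, 64, 3), i / 16.0, dtype=np.float32) for i in range(16)]
#     print(export_to_video_file(test_frames, fps=10))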

# ----------------------------- FRAME INTERPOLATION ---------------------------------
# we cannot afford to use AI-based algorithms such as FILM or ST-MFNet,
# those are way too slow for AiTube, which needs things to be as fast as possible
# -----------------------------------------------------------------------------------

def interpolate_video_frames(input_file_path, output_file_path, output_fps=10, desired_duration=2):
    """
    Interpolates frames in a video file to adjust frame rate and duration using ffmpeg's minterpolate.
    
    Parameters:
        input_file_path (str): Path to the input video file.
        output_file_path (str): Path to the output video file.
        output_fps (int): Target frames per second for the output video.
        desired_duration (int): Desired duration of the video in seconds.
    
    Returns:
        str: The file path of the modified video.
    """
    # The duration ratio computed below (actual duration / desired duration) is the
    # inverse of the setpts stretch factor needed to retime the clip
    input_fps = find_input_fps(input_file_path, desired_duration)

    # Construct the ffmpeg command: first stretch the timestamps to the desired
    # duration, then let minterpolate synthesize the missing in-between frames
    cmd = [
        'ffmpeg',
        '-y',  # overwrite the output file if it already exists
        '-i', input_file_path,  # input file
        '-filter:v', f'setpts=PTS/{input_fps},minterpolate=fps={output_fps}',  # stretch, then interpolate
        '-r', str(output_fps),  # output frame rate
        output_file_path  # output file
    ]

    # Execute the command
    try:
        subprocess.run(cmd, check=True)
        print("Video interpolation successful.")
        return output_file_path
    except subprocess.CalledProcessError as e:
        print("Failed to interpolate video. Error:", e)
        # Fall back to the untouched input so the caller still gets a playable file
        return input_file_path
    
def find_input_fps(file_path, desired_duration):
    """
    Determine the input fps that, when stretched to the desired duration, matches the original video length.
    
    Parameters:
        file_path (str): Path to the video file.
        desired_duration (int or float): Desired duration in seconds.
        
    Returns:
        float: Calculated input fps.
    """
    # FFprobe command to find the duration of the video
    ffprobe_cmd = [
        'ffprobe',
        '-v', 'error',
        '-show_entries', 'format=duration',
        '-of', 'default=noprint_wrappers=1:nokey=1',
        file_path
    ]
    
    try:
        result = subprocess.run(ffprobe_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
        duration = float(result.stdout.strip())
        input_fps = duration / desired_duration
    except Exception as e:
        print("Failed to get video duration. Error:", e)
        input_fps = 1.0  # fall back to a 1:1 ratio (no retiming) if the duration can't be read
    
    return input_fps
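
# Worked example (assuming the hardcoded 10 fps used below): 16 frames at 10 fps give a
# 1.6 s clip. Stretching it to desired_duration=8 yields a ratio of 1.6 / 8 = 0.2, so
# setpts=PTS/0.2 multiplies every timestamp by 5 and the clip lasts 8 s, after which
# minterpolate synthesizes intermediate frames to reach the requested output fps.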
    
def generate_image(secret_token, prompt, base, width, height, motion, step, desired_duration, desired_fps):
    if secret_token != SECRET_TOKEN:
        raise gr.Error(
            'Invalid secret token. Please fork the original space if you want to use it for yourself.')


    global step_loaded
    global base_loaded
    global motion_loaded
    # print(prompt, base, step)

    if step_loaded != step:
        repo = "ByteDance/AnimateDiff-Lightning"
        ckpt = f"animatediff_lightning_{step}step_diffusers.safetensors"
        pipe.unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device), strict=False)
        step_loaded = step

    if base_loaded != base:
        pipe.unet.load_state_dict(torch.load(hf_hub_download(bases[base], "unet/diffusion_pytorch_model.bin"), map_location=device), strict=False)
        base_loaded = base

    if motion_loaded != motion:
        pipe.unload_lora_weights()
        if motion != "":
            pipe.load_lora_weights(motion, adapter_name="motion")
            pipe.set_adapters(["motion"], [0.7])
        motion_loaded = motion

    output = pipe(
        prompt=prompt,
        width=width,
        height=height,
        guidance_scale=1.0,
        num_inference_steps=step,
    )
    
    video_uuid = str(uuid.uuid4()).replace("-", "")
    raw_video_path = f"/tmp/{video_uuid}_raw.webm"
    enhanced_video_path = f"/tmp/{video_uuid}_enhanced.webm"
    

    # Note: the fps is hardcoded here; this seems to be a limitation of AnimateDiff
    # (could we change this?)
    #
    # To make things faster, we could skip encoding the video entirely (since it relies on
    # temp files and external processes, which can be slow) and instead return the raw,
    # unencoded frames to the frontend renderer.
    raw_video_path = export_to_video_file(output.frames[0], raw_video_path, fps=10)

    final_video_path = raw_video_path
    
    # Optional frame interpolation
    if desired_duration != 2 or desired_fps != 10:
        final_video_path = interpolate_video_frames(raw_video_path, enhanced_video_path, output_fps=desired_fps, desired_duration=desired_duration)

    # Read the content of the video file and encode it to base64
    with open(final_video_path, "rb") as video_file:
        video_base64 = base64.b64encode(video_file.read()).decode('utf-8')

    # Clean up: otherwise there is always a risk of "ghosting" (e.g. someone seeing a previously
    # generated video if one of the steps goes wrong). We also absolutely need to delete these
    # videos: since we generate randomly named files, we can't afford a "tmp disk full" error.
    try:
        os.remove(raw_video_path)
        if final_video_path != raw_video_path:
            os.remove(final_video_path)
    except Exception as e:
        print("Failed to delete a video path:", e)
    
    # Prepend the appropriate data URI header with MIME type
    video_data_uri = 'data:video/webm;base64,' + video_base64

    return video_data_uri


# Gradio Interface
with gr.Blocks() as demo:
    gr.HTML("""
        <div style="z-index: 100; position: fixed; top: 0px; right: 0px; left: 0px; bottom: 0px; width: 100%; height: 100%; background: white; display: flex; align-items: center; justify-content: center; color: black;">
        <div style="text-align: center; color: black;">
        <p style="color: black;">This space is a headless component of the cloud rendering engine used by AiTube.</p>
        <p style="color: black;">It is not available for public use, but you can use the <a href="https://huggingface.co/spaces/ByteDance/AnimateDiff-Lightning" target="_blank">original space</a>.</p>
        </div>
        </div>""")
    

    secret_token = gr.Text(label='Secret Token', max_lines=1)
    
    with gr.Group():
        with gr.Row():
            prompt = gr.Textbox(
                label='Prompt'
            )
        with gr.Row():
            select_base = gr.Dropdown(
                label='Base model',
                choices=[
                    "ToonYou", 
                    "epiCRealism",
                ],
                value=base_loaded
            )
            width = gr.Slider(
                label='Width',
                minimum=128,
                maximum=2048,
                step=32,
                value=512,
            )
            height = gr.Slider(
                label='Height',
                minimum=128,
                maximum=2048,
                step=32,
                value=256,
            )
            select_motion = gr.Dropdown(
                label='Motion',
                choices=[
                    ("Default", ""),
                    ("Zoom in", "guoyww/animatediff-motion-lora-zoom-in"),
                    ("Zoom out", "guoyww/animatediff-motion-lora-zoom-out"),
                    ("Tilt up", "guoyww/animatediff-motion-lora-tilt-up"),
                    ("Tilt down", "guoyww/animatediff-motion-lora-tilt-down"),
                    ("Pan left", "guoyww/animatediff-motion-lora-pan-left"),
                    ("Pan right", "guoyww/animatediff-motion-lora-pan-right"),
                    ("Roll left", "guoyww/animatediff-motion-lora-rolling-anticlockwise"),
                    ("Roll right", "guoyww/animatediff-motion-lora-rolling-clockwise"),
                ],
                value="",
            )
            select_step = gr.Dropdown(
                label='Inference steps',
                choices=[
                    ('1-Step', 1), 
                    ('2-Step', 2),
                    ('4-Step', 4),
                    ('8-Step', 8)],
                value=4,
            )
            duration_slider = gr.Slider(label="Desired Duration (seconds)", minimum=2, maximum=30, value=2, step=1)
            fps_slider = gr.Slider(label="Desired Frames Per Second", minimum=10, maximum=60, value=10, step=1)
    
            submit = gr.Button()

    output_video_base64 = gr.Text()

    submit.click(
        fn=generate_image,
        inputs=[secret_token, prompt, select_base, width, height, select_motion, select_step, duration_slider, fps_slider],
        outputs=output_video_base64,
    )

demo.queue(max_size=12).launch(show_api=True)
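
# Example client call (a sketch, assuming the space is reachable at a public URL and that
# the first argument matches SECRET_TOKEN; the api_name is an assumption -- inspect
# client.view_api() to confirm it):
#
# from gradio_client import Client
# import base64
#
# client = Client("https://YOUR-SPACE-URL")  # hypothetical URL
# data_uri = client.predict(
#     "default_secret", "a cat walking on the beach", "epiCRealism",
#     512, 256, "", 4, 2, 10,
#     api_name="/predict",
# )
# # Strip the "data:video/webm;base64," header and decode back to a playable file
# with open("result.webm", "wb") as f:
#     f.write(base64.b64decode(data_uri.split(",", 1)[1]))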