File size: 5,022 Bytes
98fed26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import os
import sys
import shutil
import uuid
import subprocess
import gradio as gr
import shutil
from glob import glob
from huggingface_hub import snapshot_download, hf_hub_download
from moviepy.editor import VideoFileClip  # Import MoviePy

# Fetch every model checkpoint into ./pretrained_weights at import time, so the
# weights are on disk before the Gradio UI is built and any request can run.
os.makedirs("pretrained_weights", exist_ok=True)

# Subdirectories pre-created inside "pretrained_weights" (one per model that is
# downloaded into its own folder).
subfolders = ["stable-video-diffusion-img2vid-xt"]
for weights_subdir in subfolders:
    os.makedirs(os.path.join("pretrained_weights", weights_subdir), exist_ok=True)

# Stable Video Diffusion base model.
# NOTE(review): this pulls the "stable-video-diffusion-img2vid" repo but stores
# it under a folder named "...-img2vid-xt" — confirm which checkpoint the
# inference script actually expects.
snapshot_download(
    repo_id="stabilityai/stable-video-diffusion-img2vid",
    local_dir="./pretrained_weights/stable-video-diffusion-img2vid-xt",
)

# AniDoc's own weights, unpacked directly into the top of pretrained_weights.
snapshot_download(
    repo_id="Yhmeng1106/anidoc",
    local_dir="./pretrained_weights",
)

# Single CoTracker v2 checkpoint file (presumably used by the --tracking
# option of the inference script — verify).
hf_hub_download(
    repo_id="facebook/cotracker",
    filename="cotracker2.pth",
    local_dir="./pretrained_weights",
)

def generate(control_sequence, ref_image):
    """Colorize a sketch video with AniDoc, guided by a reference image.

    Runs ``scripts_infer/anidoc_inference.py`` in a subprocess. The script
    writes into a fresh per-request output directory, and this returns the
    first ``.mp4`` found there.

    Args:
        control_sequence: Path to the input sketch video (from ``gr.Video``),
            e.g. ``"data_test/sample4.mp4"``.
        ref_image: Path to the reference image (from ``gr.Image`` with
            ``type="filepath"``), e.g. ``"data_test/sample4.png"``.

    Returns:
        Path to the generated video, or ``None`` if the inference script left
        no ``.mp4`` directly inside the output directory.

    Raises:
        gr.Error: If an input is missing, the inference script exits with a
            non-zero status, or any other processing step fails.
    """
    # Validate outside the try block: gr.Error is itself an Exception, so
    # raising it inside would get re-wrapped by the generic handler below.
    if not control_sequence:
        raise gr.Error("Please provide a control sequence video.")
    if not ref_image:
        raise gr.Error("Please provide a reference image.")

    control_image = control_sequence
    # A unique directory per request keeps concurrent jobs from mixing results.
    output_dir = f"results_{uuid.uuid4()}"

    try:
        # The inference script needs the frame count up front. Derive it from
        # fps * duration; the context manager closes the clip even if reading
        # its metadata raises.
        with VideoFileClip(control_image) as video_clip:
            num_frames = int(video_clip.fps * video_clip.duration)

        # Use the interpreter running this app rather than whatever "python"
        # resolves to on PATH, so the script sees the same installed packages.
        subprocess.run(
            [
                sys.executable, "scripts_infer/anidoc_inference.py",
                "--all_sketch",
                "--matching",
                "--tracking",
                "--control_image", f"{control_image}",
                "--ref_image", f"{ref_image}",
                "--output_dir", f"{output_dir}",
                "--max_point", "10",
                "--num_frames", str(num_frames),
            ],
            check=True,
        )

        # Look for the result directly inside output_dir (subfolders are not
        # searched). Sort so the pick is deterministic if several exist.
        output_video = sorted(glob(os.path.join(output_dir, "*.mp4")))
        print(output_video)

        output_video_path = output_video[0] if output_video else None
        print(output_video_path)
        return output_video_path

    except subprocess.CalledProcessError as e:
        raise gr.Error(f"Error during inference: {str(e)}") from e
    except Exception as e:
        raise gr.Error(f"Error processing video: {str(e)}") from e

# Page stylesheet: centers the main column and caps its width.
css = """
div#col-container{
    margin: 0 auto;
    max-width: 982px;
}
"""
# Layout: title/description/badge links on top, then inputs on the left and
# the result plus bundled examples on the right.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# AniDoc: Animation Creation Made Easier")
        gr.Markdown("AniDoc colorizes a sequence of sketches based on a character design reference with high fidelity, even when the sketches significantly differ in pose and scale.")
        gr.HTML("""
        <div style="display:flex;column-gap:4px;">
            <a href="https://github.com/yihao-meng/AniDoc">
                <img src='https://img.shields.io/badge/GitHub-Repo-blue'>
            </a> 
            <a href="https://yihao-meng.github.io/AniDoc_demo/">
                <img src='https://img.shields.io/badge/Project-Page-green'>
            </a>
            <a href="https://arxiv.org/pdf/2412.14173">
                <img src='https://img.shields.io/badge/ArXiv-Paper-red'>
            </a>
            <a href="https://huggingface.co/spaces/fffiloni/AniDoc?duplicate=true">
                <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-sm.svg" alt="Duplicate this Space">
            </a>
            <a href="https://huggingface.co/fffiloni">
                <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/follow-me-on-HF-sm-dark.svg" alt="Follow me on HF">
            </a>
        </div>
        """)
        with gr.Row():
            # Left column: user inputs and the trigger button.
            with gr.Column():
                control_sequence = gr.Video(label="Control Sequence", format="mp4")
                ref_image = gr.Image(label="Reference Image", type="filepath")
                submit_btn = gr.Button("Submit")
            # Right column: output video and the (video, image) sample pairs
            # data_test/sample1..4, which fill the two input components.
            with gr.Column():
                video_result = gr.Video(label="Result")

                gr.Examples(
                    examples=[
                        [f"data_test/sample{idx}.mp4", f"data_test/sample{idx}.png"]
                        for idx in range(1, 5)
                    ],
                    inputs=[control_sequence, ref_image],
                )

    # Wire the button to the inference entry point.
    submit_btn.click(
        fn=generate,
        inputs=[control_sequence, ref_image],
        outputs=[video_result],
    )

# Enable the request queue, then start the server.
queued_demo = demo.queue()
queued_demo.launch(show_api=False, show_error=True, share=True)