import os
import subprocess
import tempfile

import gradio as gr
from huggingface_hub import hf_hub_download

def download_model(repo_id, model_name):
    # Fetch the checkpoint from the Hugging Face Hub (cached locally after the first call).
    model_path = hf_hub_download(repo_id=repo_id, filename=model_name)
    return model_path

def run_inference(model_name, prompt_path):
    repo_id = "hpcai-tech/Open-Sora"

    # Inference config used with each released checkpoint (paths relative to the repo root).
    config_mapping = {
        "OpenSora-v1-16x256x256.pth": "configs/opensora/inference/16x256x256.py",
        "OpenSora-v1-HQ-16x256x256.pth": "configs/opensora/inference/16x512x512.py",
        "OpenSora-v1-HQ-16x512x512.pth": "configs/opensora/inference/64x512x512.py"
    }

    config_path = config_mapping[model_name]
    ckpt_path = download_model(repo_id, model_name)

    # Patch the default prompt file in the config with the user-supplied path.
    with open(config_path, 'r') as file:
        config_content = file.read()
    config_content = config_content.replace('prompt_path = "./assets/texts/t2v_samples.txt"', f'prompt_path = "{prompt_path}"')

    # Write the patched config to a temporary .py file so the config loader can parse it.
    with tempfile.NamedTemporaryFile('w', suffix='.py', delete=False) as temp_file:
        temp_file.write(config_content)
        temp_config_path = temp_file.name

    # Launch single-GPU inference on the patched config.
    cmd = [
        "torchrun", "--standalone", "--nproc_per_node", "1",
        "scripts/inference.py", temp_config_path,
        "--ckpt-path", ckpt_path
    ]
    result = subprocess.run(cmd, capture_output=True, text=True)
    os.remove(temp_config_path)  # clean up the temporary config file
    if result.returncode != 0:
        raise gr.Error(f"Inference failed: {result.stderr}")

    # Assuming the output video is saved in a known path, e.g., "./output/video.mp4"
    output_video_path = "./output/video.mp4"
    return output_video_path

def main():
    gr.Interface(
        fn=run_inference,
        inputs=[
            gr.Dropdown(choices=[
                "OpenSora-v1-16x256x256.pth",
                "OpenSora-v1-HQ-16x256x256.pth",
                "OpenSora-v1-HQ-16x512x512.pth"
            ], label="Model Selection"),
            gr.Textbox(label="Prompt Path", value="./assets/texts/t2v_samples.txt")
        ],
        outputs=gr.Video(label="Output Video"),
        title="Open-Sora Inference",
        description="Run Open-Sora Inference with Custom Parameters"
    ).launch()

if __name__ == "__main__":
    main()