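"""Gradio demo for Open-Sora v1 text-to-video inference.

Downloads the selected checkpoint from the Hugging Face Hub, rewrites the
matching inference config to use the user's prompt, and shells out to the
repository's scripts/inference.py. Run this from the root of the Open-Sora
repository so the relative config and script paths resolve.
"""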
import gradio as gr
from huggingface_hub import hf_hub_download
import subprocess
import tempfile
import os

def download_model(repo_id, model_name):
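    # hf_hub_download fetches the checkpoint from the Hub (with local
    # caching) and returns the path to the downloaded file.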
    model_path = hf_hub_download(repo_id=repo_id, filename=model_name)
    return model_path

def run_inference(model_name, prompt_text):
    repo_id = "hpcai-tech/Open-Sora"
    
    # Map each checkpoint to the inference config matching its frame count
    # and resolution (the HQ 256 checkpoint still samples at 256x256)
    config_mapping = {
        "OpenSora-v1-16x256x256.pth": "configs/opensora/inference/16x256x256.py",
        "OpenSora-v1-HQ-16x256x256.pth": "configs/opensora/inference/16x256x256.py",
        "OpenSora-v1-HQ-16x512x512.pth": "configs/opensora/inference/16x512x512.py"
    }
    
    config_path = config_mapping[model_name]
    ckpt_path = download_model(repo_id, model_name)

    # Save prompt_text to a temporary text file the config can reference
    with tempfile.NamedTemporaryFile('w', delete=False, suffix=".txt") as prompt_file:
        prompt_file.write(prompt_text)

    # Point the config's prompt_path at our prompt file. This string replace
    # assumes the stock config still contains the default prompt_path line.
    with open(config_path, 'r') as file:
        config_content = file.read()
    config_content = config_content.replace(
        'prompt_path = "./assets/texts/t2v_samples.txt"',
        f'prompt_path = "{prompt_file.name}"'
    )
    
    # Write the updated configuration to a temporary file; keep a .py suffix,
    # since config loaders typically resolve the format from the extension.
    with tempfile.NamedTemporaryFile('w', delete=False, suffix=".py") as temp_file:
        temp_file.write(config_content)
        temp_config_path = temp_file.name

    # Run single-GPU inference with torchrun; the relative script path assumes
    # the Open-Sora repo root is the current working directory.
    cmd = [
        "torchrun", "--standalone", "--nproc_per_node", "1",
        "scripts/inference.py", temp_config_path,
        "--ckpt-path", ckpt_path
    ]
    result = subprocess.run(cmd, capture_output=True, text=True)

    # Clean up the temporary files
    os.remove(temp_config_path)
    os.remove(prompt_file.name)

    # Return a single string, since the interface has a single text output
    if result.returncode == 0:
        return f"Inference completed successfully.\n{result.stdout}"
    return f"Error occurred:\n{result.stderr}"

def main():
    gr.Interface(
        fn=run_inference,
        inputs=[
            gr.Dropdown(choices=[
                "OpenSora-v1-16x256x256.pth",
                "OpenSora-v1-HQ-16x256x256.pth",
                "OpenSora-v1-HQ-16x512x512.pth"
            ], label="Model Selection"),
            gr.Textbox(label="Prompt Text", placeholder="Enter prompt text here")
        ],
        outputs=gr.Textbox(label="Inference Log"),
        title="Open-Sora Inference",
        description="Run Open-Sora Inference with Custom Parameters",
    ).launch(share=True)  # share=True creates a public link

if __name__ == "__main__":
    main()
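
# To use: run this script from the Open-Sora repo root and open the local or
# public (share=True) URL that Gradio prints.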