Spaces:
Runtime error
Runtime error
File size: 2,088 Bytes
b38913b 9f69156 b38913b 9f69156 b0f3abf 9f69156 b0f3abf 9f69156 b38913b 9a1761d b0f3abf b38913b 9a1761d b38913b 9f69156 50eb146 b38913b 9a1761d b38913b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 |
import os
import shutil
import subprocess
import tempfile

import gradio as gr
from huggingface_hub import hf_hub_download
def download_model(repo_id, model_name):
    """Fetch ``model_name`` from the Hugging Face Hub repository ``repo_id``.

    Returns whatever local path ``hf_hub_download`` reports for the file.
    """
    return hf_hub_download(filename=model_name, repo_id=repo_id)
def run_inference(model_name, prompt_path):
    """Run Open-Sora text-to-video inference and return the generated video path.

    Downloads the selected checkpoint, writes a temporary copy of the matching
    inference config that overrides ``prompt_path`` and ``save_dir``, runs
    ``scripts/inference.py`` through ``torchrun`` on one process, and returns
    the first ``.mp4`` written to the per-request output directory.

    Args:
        model_name: Checkpoint file name in the ``hpcai-tech/Open-Sora`` repo;
            must be one of the keys of ``config_mapping`` below.
        prompt_path: Path to the prompt text file to use for generation.

    Returns:
        Path of the generated ``.mp4`` file.

    Raises:
        gr.Error: if the model selection is missing/unknown, the inference
            subprocess exits non-zero, or no video is produced.
    """
    repo_id = "hpcai-tech/Open-Sora"
    # Config names encode frames x height x width; pair each checkpoint with
    # the config of the same geometry as its file name.
    config_mapping = {
        "OpenSora-v1-16x256x256.pth": "configs/opensora/inference/16x256x256.py",
        "OpenSora-v1-HQ-16x256x256.pth": "configs/opensora/inference/16x256x256.py",
        "OpenSora-v1-HQ-16x512x512.pth": "configs/opensora/inference/16x512x512.py",
    }
    config_path = config_mapping.get(model_name)
    if config_path is None:
        raise gr.Error(f"Unknown or missing model selection: {model_name!r}")
    ckpt_path = download_model(repo_id, model_name)

    with open(config_path, "r", encoding="utf-8") as file:
        config_content = file.read()

    # A fresh output directory per request keeps concurrent runs from
    # overwriting each other; it lives under the working directory so the
    # web UI is allowed to serve the resulting file.
    os.makedirs("./outputs", exist_ok=True)
    save_dir = tempfile.mkdtemp(prefix="gradio_", dir="./outputs")

    # The config is plain Python, so assignments appended at the end override
    # the defaults above them regardless of how those lines are formatted.
    # repr() yields a correctly quoted/escaped literal, so a user-supplied
    # path cannot break or inject code into the executed config.
    config_content += (
        "\n"
        f"prompt_path = {prompt_path!r}\n"
        f"save_dir = {save_dir!r}\n"
    )

    # Keep the ``.py`` suffix: the config loader used by the inference script
    # (mmengine Config.fromfile) selects its parser from the file extension.
    with tempfile.NamedTemporaryFile(
        "w", suffix=".py", delete=False, encoding="utf-8"
    ) as temp_file:
        temp_file.write(config_content)
        temp_config_path = temp_file.name

    cmd = [
        "torchrun", "--standalone", "--nproc_per_node", "1",
        "scripts/inference.py", temp_config_path,
        "--ckpt-path", ckpt_path,
    ]
    try:
        result = subprocess.run(cmd, capture_output=True, text=True)
    finally:
        # The temp config is a single file (not a directory), so os.remove
        # is the correct cleanup; run it even if launching torchrun fails.
        os.remove(temp_config_path)

    if result.returncode != 0:
        # Surface the tail of stderr in the UI rather than returning a path
        # to a video that was never written.
        raise gr.Error(
            f"Inference failed (exit code {result.returncode}):\n"
            f"{result.stderr[-2000:]}"
        )

    videos = sorted(name for name in os.listdir(save_dir) if name.endswith(".mp4"))
    if not videos:
        raise gr.Error(f"Inference finished but no .mp4 was written to {save_dir}")
    # NOTE(review): only the first video (by name) is shown; a prompt file
    # with several prompts produces more samples in save_dir.
    return os.path.join(save_dir, videos[0])
def main():
    """Build the Open-Sora Gradio interface and start the web server.

    The interface takes a checkpoint name and a prompt-file path, calls
    ``run_inference`` with them, and displays the returned video.
    """
    model_choices = [
        "OpenSora-v1-16x256x256.pth",
        "OpenSora-v1-HQ-16x256x256.pth",
        "OpenSora-v1-HQ-16x512x512.pth",
    ]
    gr.Interface(
        fn=run_inference,
        inputs=[
            # Pre-select a checkpoint so submitting without touching the
            # dropdown does not pass None as the model name.
            gr.Dropdown(
                choices=model_choices,
                value=model_choices[0],
                label="Model Selection",
            ),
            gr.Textbox(label="Prompt Path", value="./assets/texts/t2v_samples.txt"),
        ],
        outputs=gr.Video(label="Output Video"),
        title="Open-Sora Inference",
        description="Run Open-Sora Inference with Custom Parameters",
    ).launch()
if __name__ == "__main__":
    # Start the UI only when run as a script, not when imported.
    main()
|