import gradio as gr
from huggingface_hub import hf_hub_download
import subprocess
import tempfile
import shutil
import os
import spaces  # Hugging Face Spaces runtime helpers; imported but not used directly in this script

def download_model(repo_id, model_name):
    # Fetch the requested checkpoint from the Hugging Face Hub (cached locally).
    model_path = hf_hub_download(repo_id=repo_id, filename=model_name)
    return model_path

def run_inference(model_name, prompt_text):
    repo_id = "hpcai-tech/Open-Sora"

    # Map each checkpoint to its inference configuration file
    config_mapping = {
        "OpenSora-v1-16x256x256.pth": "configs/opensora/inference/16x256x256.py",
        "OpenSora-v1-HQ-16x256x256.pth": "configs/opensora/inference/16x512x512.py",
        "OpenSora-v1-HQ-16x512x512.pth": "configs/opensora/inference/64x512x512.py",
    }

    config_path = config_mapping[model_name]
    ckpt_path = download_model(repo_id, model_name)

    # Save prompt_text to a temporary text file so the config can point to it
    prompt_file = tempfile.NamedTemporaryFile(delete=False, suffix=".txt", mode="w")
    prompt_file.write(prompt_text)
    prompt_file.close()

    # Read the configuration file and point its prompt_path at the temporary prompt file
    with open(config_path, "r") as file:
        config_content = file.read()
    config_content = config_content.replace(
        'prompt_path = "./assets/texts/t2v_samples.txt"',
        f'prompt_path = "{prompt_file.name}"',
    )

    # Write the patched configuration to a temporary .py file
    with tempfile.NamedTemporaryFile("w", delete=False, suffix=".py") as temp_file:
        temp_file.write(config_content)
        temp_config_path = temp_file.name

    # Run single-process inference with the patched config and downloaded checkpoint
    cmd = [
        "torchrun", "--standalone", "--nproc_per_node", "1",
        "scripts/inference.py", temp_config_path,
        "--ckpt-path", ckpt_path,
    ]
    result = subprocess.run(cmd, capture_output=True, text=True)
    print("result", result)

    # Clean up the temporary files
    os.remove(temp_config_path)
    os.remove(prompt_file.name)

    if result.returncode == 0:
        # Assuming the output video is saved at a known location, for example "./output/video.mp4"
        output_video_path = "./output/video.mp4"
        return output_video_path
    else:
        print("Error occurred:", result.stderr)
        return None  # You might want to handle errors differently
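

# Minimal, hypothetical alternative to the hard-coded "./output/video.mp4" above:
# pick the most recently written .mp4 under the directory the inference script
# saves to. The "./outputs/samples" default is an assumption, and this helper is
# not wired into run_inference; it only sketches a more robust lookup.
import glob


def find_latest_video(save_dir="./outputs/samples"):
    # Return the newest .mp4 in save_dir, or None if nothing was produced.
    videos = glob.glob(os.path.join(save_dir, "*.mp4"))
    return max(videos, key=os.path.getmtime) if videos else None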

def main():
    gr.Interface(
        fn=run_inference,
        inputs=[
            gr.Dropdown(
                choices=[
                    "OpenSora-v1-16x256x256.pth",
                    "OpenSora-v1-HQ-16x256x256.pth",
                    "OpenSora-v1-HQ-16x512x512.pth",
                ],
                label="Model Selection",
            ),
            gr.Textbox(label="Prompt Text", placeholder="Enter prompt text here"),
        ],
        outputs=gr.Video(label="Output Video"),
        title="Open-Sora Inference",
        description="Run Open-Sora Inference with Custom Parameters",
    ).launch()

if __name__ == "__main__":
    main()