Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| from huggingface_hub import hf_hub_download | |
| import subprocess | |
| import tempfile | |
| import shutil | |
def download_model(repo_id, model_name):
    """Fetch one file from a Hugging Face Hub repository.

    Args:
        repo_id: Hub repository identifier passed to ``hf_hub_download``.
        model_name: Name of the file inside the repository to fetch.

    Returns:
        Whatever ``hf_hub_download`` returns for that file (used by the
        caller as the checkpoint path).
    """
    return hf_hub_download(repo_id=repo_id, filename=model_name)
def run_inference(model_name, prompt_path):
    """Run Open-Sora inference for the selected checkpoint.

    The function picks the configuration file associated with
    ``model_name``, downloads the checkpoint, writes a temporary copy of
    the configuration in which the default prompt file is replaced by
    ``prompt_path``, and launches ``scripts/inference.py`` through
    ``torchrun``.

    Args:
        model_name: Checkpoint file name; must be a key of the internal
            model-to-config mapping.
        prompt_path: Path of the prompt file to substitute into the
            configuration.

    Returns:
        A ``(status, detail)`` tuple of strings. On success ``detail`` is
        the subprocess stdout; on failure it is the subprocess stderr, or
        a message describing why the model name was rejected.
    """
    # Local import: the file's import block does not provide ``os``.
    import os

    repo_id = "hpcai-tech/Open-Sora"
    default_prompt_line = 'prompt_path = "./assets/texts/t2v_samples.txt"'

    # Map model names to their respective configuration files.
    # NOTE(review): the two HQ checkpoints point at configs whose resolution
    # does not match the checkpoint name (16x256x256 -> 16x512x512.py,
    # 16x512x512 -> 64x512x512.py). Kept as-is; confirm this is intended.
    config_mapping = {
        "OpenSora-v1-16x256x256.pth": "configs/opensora/inference/16x256x256.py",
        "OpenSora-v1-HQ-16x256x256.pth": "configs/opensora/inference/16x512x512.py",
        "OpenSora-v1-HQ-16x512x512.pth": "configs/opensora/inference/64x512x512.py",
    }

    # Reject unknown / missing selections up front (e.g. the dropdown was
    # left empty) instead of failing with a bare KeyError.
    config_path = config_mapping.get(model_name)
    if config_path is None:
        known_models = ", ".join(config_mapping)
        return (
            "Error occurred:",
            f"Unknown model {model_name!r}; expected one of: {known_models}",
        )

    # Download the selected model.
    ckpt_path = download_model(repo_id, model_name)

    with open(config_path, "r", encoding="utf-8") as file:
        config_content = file.read()

    # ``prompt_path`` is user input that ends up inside a Python config file,
    # so emit it as an escaped string literal rather than pasting it raw.
    config_content = config_content.replace(
        default_prompt_line, f"prompt_path = {prompt_path!r}"
    )

    # The ``.py`` suffix is presumably required by the config loader used in
    # scripts/inference.py (loaders of this kind pick the parser from the
    # file extension) -- TODO confirm against the inference script.
    with tempfile.NamedTemporaryFile(
        "w", suffix=".py", delete=False, encoding="utf-8"
    ) as temp_file:
        temp_file.write(config_content)
        temp_config_path = temp_file.name

    cmd = [
        "torchrun", "--standalone", "--nproc_per_node", "1",
        "scripts/inference.py", temp_config_path,
        "--ckpt-path", ckpt_path,
    ]
    try:
        result = subprocess.run(cmd, capture_output=True, text=True)
    finally:
        # The temporary config is a regular file: ``shutil.rmtree`` would
        # raise NotADirectoryError on it, so remove it as a file, and do so
        # even when launching the subprocess fails.
        os.remove(temp_config_path)

    if result.returncode == 0:
        return "Inference completed successfully.", result.stdout
    return "Error occurred:", result.stderr
def main():
    """Build the Gradio interface around ``run_inference`` and launch it."""
    checkpoint_names = [
        "OpenSora-v1-16x256x256.pth",
        "OpenSora-v1-HQ-16x256x256.pth",
        "OpenSora-v1-HQ-16x512x512.pth",
    ]

    # Inputs: which checkpoint to run and which prompt file to feed it.
    model_selector = gr.Dropdown(choices=checkpoint_names, label="Model Selection")
    prompt_box = gr.Textbox(label="Prompt Path", value="./assets/texts/t2v_samples.txt")

    # Outputs: the two strings returned by run_inference.
    status_view = gr.Text(label="Status")
    output_view = gr.Text(label="Output")

    demo = gr.Interface(
        fn=run_inference,
        inputs=[model_selector, prompt_box],
        outputs=[status_view, output_view],
        title="Open-Sora Inference",
        description="Run Open-Sora Inference with Custom Parameters",
    )
    demo.launch()


# Start the UI only when this file is executed as a script.
if __name__ == "__main__":
    main()