kadirnar commited on
Commit
b0f3abf
·
verified ·
1 Parent(s): 50eb146

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -3
app.py CHANGED
@@ -8,9 +8,19 @@ def download_model(repo_id, model_name):
8
  model_path = hf_hub_download(repo_id=repo_id, filename=model_name)
9
  return model_path
10
 
11
- def run_inference(config_path, model_name, prompt_path):
12
  repo_id = "hpcai-tech/Open-Sora"
13
 
 
 
 
 
 
 
 
 
 
 
14
  # Download the selected model
15
  ckpt_path = download_model(repo_id, model_name)
16
 
@@ -29,7 +39,7 @@ def run_inference(config_path, model_name, prompt_path):
29
  ]
30
  result = subprocess.run(cmd, capture_output=True, text=True)
31
 
32
- shutil.rmtree(temp_config_path)
33
 
34
  if result.returncode == 0:
35
  return "Inference completed successfully.", result.stdout
@@ -40,7 +50,6 @@ def main():
40
  gr.Interface(
41
  fn=run_inference,
42
  inputs=[
43
- gr.Textbox(label="Configuration Path", value="configs/opensora/inference/16x256x256.py"),
44
  gr.Dropdown(choices=[
45
  "OpenSora-v1-16x256x256.pth",
46
  "OpenSora-v1-HQ-16x256x256.pth",
 
8
  model_path = hf_hub_download(repo_id=repo_id, filename=model_name)
9
  return model_path
10
 
11
+ def run_inference(model_name, prompt_path):
12
  repo_id = "hpcai-tech/Open-Sora"
13
 
14
+ # Map model names to their respective configuration files
15
+ config_mapping = {
16
+ "OpenSora-v1-16x256x256.pth": "configs/opensora/inference/16x256x256.py",
17
+ "OpenSora-v1-HQ-16x256x256.pth": "configs/opensora/inference/16x512x512.py",
18
+ "OpenSora-v1-HQ-16x512x512.pth": "configs/opensora/inference/64x512x512.py"
19
+ }
20
+
21
+ # Get the configuration path based on the model name
22
+ config_path = config_mapping[model_name]
23
+
24
  # Download the selected model
25
  ckpt_path = download_model(repo_id, model_name)
26
 
 
39
  ]
40
  result = subprocess.run(cmd, capture_output=True, text=True)
41
 
42
+ shutil.rmtree(temp_file.name)
43
 
44
  if result.returncode == 0:
45
  return "Inference completed successfully.", result.stdout
 
50
  gr.Interface(
51
  fn=run_inference,
52
  inputs=[
 
53
  gr.Dropdown(choices=[
54
  "OpenSora-v1-16x256x256.pth",
55
  "OpenSora-v1-HQ-16x256x256.pth",