kadirnar commited on
Commit
b205b1d
·
verified ·
1 Parent(s): 1a68b3d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -10
app.py CHANGED
@@ -31,6 +31,7 @@ def download_model(repo_id, model_name):
31
  model_path = hf_hub_download(repo_id=repo_id, filename=model_name)
32
  return model_path
33
 
 
34
 
35
  @spaces.GPU
36
  def run_inference(model_name, prompt_text):
@@ -51,7 +52,6 @@ def run_inference(model_name, prompt_text):
51
  prompt_file.write(prompt_text)
52
  prompt_file.close()
53
 
54
- # Read and update the configuration file
55
  with open(config_path, 'r') as file:
56
  config_content = file.read()
57
  config_content = config_content.replace('prompt_path = "./assets/texts/t2v_samples.txt"', f'prompt_path = "{prompt_file.name}"')
@@ -60,24 +60,23 @@ def run_inference(model_name, prompt_text):
60
  temp_file.write(config_content)
61
  temp_config_path = temp_file.name
62
 
63
-
64
  cmd = [
65
  "torchrun", "--standalone", "--nproc_per_node", "1",
66
  "scripts/inference.py", temp_config_path,
67
  "--ckpt-path", ckpt_path
68
  ]
69
- result = subprocess.run(cmd, capture_output=True, text=True)
 
 
 
 
 
 
70
  # Clean up the temporary files
71
  os.remove(temp_file.name)
72
  os.remove(prompt_file.name)
73
 
74
- if result.returncode == 0:
75
- # Assuming the output video is saved at a known location, for example "./output/video.mp4"
76
- output_video_path = "./output/video.mp4"
77
- return output_video_path
78
- else:
79
- print("Error occurred:", result.stderr)
80
- return None # You might want to handle errors differently
81
 
82
  def main():
83
  gr.Interface(
 
31
  model_path = hf_hub_download(repo_id=repo_id, filename=model_name)
32
  return model_path
33
 
34
+ import glob
35
 
36
  @spaces.GPU
37
  def run_inference(model_name, prompt_text):
 
52
  prompt_file.write(prompt_text)
53
  prompt_file.close()
54
 
 
55
  with open(config_path, 'r') as file:
56
  config_content = file.read()
57
  config_content = config_content.replace('prompt_path = "./assets/texts/t2v_samples.txt"', f'prompt_path = "{prompt_file.name}"')
 
60
  temp_file.write(config_content)
61
  temp_config_path = temp_file.name
62
 
 
63
  cmd = [
64
  "torchrun", "--standalone", "--nproc_per_node", "1",
65
  "scripts/inference.py", temp_config_path,
66
  "--ckpt-path", ckpt_path
67
  ]
68
+ subprocess.run(cmd)
69
+
70
+ # Find the latest generated file in the save directory
71
+ save_dir = "./output" # Or the appropriate directory where files are saved
72
+ list_of_files = glob.glob(f'{save_dir}/*') # You might need to adjust the pattern
73
+ latest_file = max(list_of_files, key=os.path.getctime)
74
+
75
  # Clean up the temporary files
76
  os.remove(temp_file.name)
77
  os.remove(prompt_file.name)
78
 
79
+ return latest_file
 
 
 
 
 
 
80
 
81
  def main():
82
  gr.Interface(