# Non-code text that preceded this script (kept as comments so the file parses):
# Spaces:
# Runtime error
# Runtime error
import argparse
import contextlib
import os
import subprocess
import sys
import tempfile
def main(argv=None):
    """Run YuE inference (``infer.py``) on genre tags and lyrics passed as arguments.

    ``infer.py`` is invoked with the genre tags and lyrics as file paths
    (``--genre_txt`` / ``--lyrics_txt``). The strings given on the command
    line are therefore first written to temporary UTF-8 text files.
    ``infer.py`` is resolved relative to this file's directory and run with
    the same Python interpreter that runs this wrapper. After a successful
    run, the entries of the output directory are printed in sorted order.
    The temporary files are always removed, whether or not inference
    succeeds.

    Args:
        argv: Arguments to parse instead of ``sys.argv[1:]``. ``None`` (the
            default) parses the real command line.

    Raises:
        SystemExit: Raised by argparse for missing or invalid arguments.
        subprocess.CalledProcessError: If ``infer.py`` exits with a non-zero
            status.
    """
    parser = argparse.ArgumentParser(description='Run YuE model with direct input')
    parser.add_argument('--genre', type=str, required=True, help='Genre tags for the music')
    parser.add_argument('--lyrics', type=str, required=True, help='Lyrics for the music')
    parser.add_argument('--run_n_segments', type=int, default=2, help='Number of segments to process')
    parser.add_argument('--stage2_batch_size', type=int, default=4, help='Batch size for stage 2')
    parser.add_argument('--max_new_tokens', type=int, default=3000, help='Maximum number of new tokens')
    parser.add_argument('--cuda_idx', type=int, default=0, help='CUDA device index')
    # Formerly hard-coded values, now overridable; the defaults keep the old behavior.
    parser.add_argument('--output_dir', type=str, default='/home/user/app/output',
                        help='Output directory passed to infer.py')
    parser.add_argument('--stage1_model', type=str, default='m-a-p/YuE-s1-7B-anneal-en-cot',
                        help='Stage 1 model passed to infer.py')
    parser.add_argument('--stage2_model', type=str, default='m-a-p/YuE-s2-1B-general',
                        help='Stage 2 model passed to infer.py')
    args = parser.parse_args(argv)

    output_dir = args.output_dir
    # infer.py sits next to this wrapper, independent of the current working directory.
    infer_script = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'infer.py')
    # Use the running interpreter (and thus its environment/packages) rather than
    # whatever "python" resolves to on PATH. sys.executable may be empty or None
    # when the interpreter path cannot be determined, so fall back to "python".
    python = sys.executable or 'python'

    temp_paths = []
    try:
        # Create the temp files inside the try and record each path before
        # writing, so a failure on either file still cleans up everything
        # created so far. An explicit UTF-8 encoding avoids locale-dependent
        # failures on non-ASCII lyrics.
        # NOTE(review): assumes infer.py reads these files as UTF-8 — confirm.
        for text in (args.genre, args.lyrics):
            with tempfile.NamedTemporaryFile(mode='w', encoding='utf-8', delete=False) as tmp:
                temp_paths.append(tmp.name)
                tmp.write(text)
        genre_path, lyrics_path = temp_paths

        subprocess.run([
            python, infer_script,
            '--stage1_model', args.stage1_model,
            '--stage2_model', args.stage2_model,
            '--genre_txt', genre_path,
            '--lyrics_txt', lyrics_path,
            '--run_n_segments', str(args.run_n_segments),
            '--stage2_batch_size', str(args.stage2_batch_size),
            '--output_dir', output_dir,
            '--cuda_idx', str(args.cuda_idx),
            '--max_new_tokens', str(args.max_new_tokens),
        ], check=True)

        print(f"\nOutput directory: {output_dir}")
        print("Generated files:")
        # Sorted so the listing is stable across runs and filesystems.
        for name in sorted(os.listdir(output_dir)):
            print(f"- {os.path.join(output_dir, name)}")
    finally:
        for path in temp_paths:
            # Best-effort cleanup: a file that is already gone is not an error.
            with contextlib.suppress(FileNotFoundError):
                os.unlink(path)
if __name__ == "__main__":
    # Run the CLI only when this file is executed directly, not when imported.
    main()