File size: 2,282 Bytes
b30a088
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8e806e8
 
b30a088
8e806e8
 
 
b30a088
 
8e806e8
b30a088
 
 
 
 
 
8e806e8
b30a088
 
 
 
8e806e8
 
 
 
 
b30a088
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import argparse
import tempfile
import os
import subprocess
import sys

def _parse_args(argv=None):
    """Build the command-line parser and parse *argv* (``sys.argv[1:]`` when None).

    The ``--output_dir``, ``--stage1_model`` and ``--stage2_model`` options
    default to the values that used to be hard-coded, so existing invocations
    behave the same.
    """
    parser = argparse.ArgumentParser(description='Run YuE model with direct input')
    parser.add_argument('--genre', type=str, required=True, help='Genre tags for the music')
    parser.add_argument('--lyrics', type=str, required=True, help='Lyrics for the music')
    parser.add_argument('--run_n_segments', type=int, default=2, help='Number of segments to process')
    parser.add_argument('--stage2_batch_size', type=int, default=4, help='Batch size for stage 2')
    parser.add_argument('--max_new_tokens', type=int, default=3000, help='Maximum number of new tokens')
    parser.add_argument('--cuda_idx', type=int, default=0, help='CUDA device index')
    parser.add_argument('--output_dir', type=str, default='/home/user/app/output',
                        help='Directory infer.py writes its results to')
    parser.add_argument('--stage1_model', type=str, default='m-a-p/YuE-s1-7B-anneal-en-cot',
                        help='Stage 1 model identifier passed to infer.py')
    parser.add_argument('--stage2_model', type=str, default='m-a-p/YuE-s2-1B-general',
                        help='Stage 2 model identifier passed to infer.py')
    return parser.parse_args(argv)


def _write_text(directory, name, text):
    """Write *text* to ``directory/name`` and return that file's path."""
    path = os.path.join(directory, name)
    # Platform default encoding, same as the NamedTemporaryFile(mode='w') used
    # previously. NOTE(review): presumably infer.py reads with open()'s default
    # encoding too; confirm before switching to an explicit encoding.
    with open(path, 'w') as handle:
        handle.write(text)
    return path


def _build_infer_command(args, genre_path, lyrics_path):
    """Return the argument list that runs ``infer.py`` (next to this file) for *args*."""
    script = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'infer.py')
    return [
        # sys.executable runs infer.py under the same interpreter/virtualenv as
        # this wrapper. A bare 'python' depends on PATH and may resolve to a
        # different interpreter, or to none at all.
        sys.executable, script,
        '--stage1_model', args.stage1_model,
        '--stage2_model', args.stage2_model,
        '--genre_txt', genre_path,
        '--lyrics_txt', lyrics_path,
        '--run_n_segments', str(args.run_n_segments),
        '--stage2_batch_size', str(args.stage2_batch_size),
        '--output_dir', args.output_dir,
        '--cuda_idx', str(args.cuda_idx),
        '--max_new_tokens', str(args.max_new_tokens),
    ]


def _print_outputs(output_dir):
    """Print *output_dir* and the full path of every entry in it, sorted by name."""
    print(f"\nOutput directory: {output_dir}")
    print("Generated files:")
    if not os.path.isdir(output_dir):
        # Report a missing directory instead of crashing with FileNotFoundError.
        print("- (output directory does not exist)")
        return
    for name in sorted(os.listdir(output_dir)):
        print(f"- {os.path.join(output_dir, name)}")


def main(argv=None):
    """Run YuE inference (``infer.py``) on genre tags and lyrics given as strings.

    ``infer.py`` takes its genre and lyrics inputs as file paths
    (``--genre_txt`` / ``--lyrics_txt``). This wrapper writes the two strings
    into a private temporary directory and runs ``infer.py`` with the current
    interpreter. It then lists the contents of the output directory.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``. ``None`` (the
            default) keeps the normal command-line behaviour.

    Raises:
        SystemExit: From argparse on invalid or missing arguments, or ``--help``.
        subprocess.CalledProcessError: If ``infer.py`` exits with a non-zero status.
    """
    args = _parse_args(argv)

    # The directory and both input files are removed when the ``with`` exits,
    # including when a write or the subprocess fails, so no input file leaks.
    with tempfile.TemporaryDirectory() as tmp_dir:
        genre_path = _write_text(tmp_dir, 'genre.txt', args.genre)
        lyrics_path = _write_text(tmp_dir, 'lyrics.txt', args.lyrics)
        subprocess.run(_build_infer_command(args, genre_path, lyrics_path), check=True)

    _print_outputs(args.output_dir)

# Run the CLI only when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    main()