File size: 6,548 Bytes
1db4f6f
 
 
 
 
 
 
 
 
8087323
 
1db4f6f
 
ae4c64d
 
1db4f6f
 
 
 
c0fa5a2
1db4f6f
 
 
 
 
3b22b47
1db4f6f
c148ff9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1db4f6f
 
 
219120c
1db4f6f
 
219120c
 
 
1db4f6f
 
 
ea123d1
 
219120c
 
 
 
 
 
 
 
1db4f6f
219120c
1db4f6f
 
 
219120c
1db4f6f
 
 
219120c
 
 
1db4f6f
 
219120c
 
1db4f6f
219120c
 
1db4f6f
 
219120c
 
 
1db4f6f
 
 
71161a2
1db4f6f
71161a2
 
1db4f6f
 
 
 
 
 
3b22b47
8e429aa
 
 
 
3b22b47
 
1db4f6f
 
 
 
8087323
1db4f6f
 
 
61ade90
1db4f6f
 
 
 
 
 
 
 
 
 
 
 
 
 
c069084
1db4f6f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8e429aa
1db4f6f
 
1677854
1db4f6f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
837b686
1db4f6f
 
a97898c
 
1db4f6f
 
 
 
 
 
1677854
1ac1ff6
1db4f6f
1677854
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
import random
import os
import time
from queue import Queue
from threading import Thread

import symusic

import transformers

import spaces
import gradio as gr

# Scratch space for the .mid / .wav / .mp3 / .svg files produced by the handlers below.
os.makedirs('./temp', exist_ok=True)

print('\n\n\n')
print('Loading model...')
# Load the checkpoint once at import time; every request shares this pipeline.
pipe = transformers.pipeline(
    task="text-generation",
    model="dx2102/llama-midi",
    device_map="auto",  # let transformers decide device placement
    torch_dtype="bfloat16",
)
print('Done')

# Example prompt in the model's plain-text note format (offered in the Examples
# list): a column-header line, a blank line, then one note per line as
# "pitch duration wait velocity instrument". postprocess() treats duration and
# wait as milliseconds (wait = gap until the next note's onset), multiplies
# velocity by 4, and uses a numeric instrument as the MIDI program number.
example_prefix = '''pitch duration wait velocity instrument

71 1310 0 20 0
48 330 350 20 0
55 330 350 20 0
64 1310 690 20 0
74 660 690 20 0
69 1310 0 20 0
48 330 350 20 0
57 330 350 20 0
66 1310 690 20 0
67 330 350 20 0
69 330 350 20 0
71 1310 0 20 0
48 330 350 20 0
55 330 350 20 0
64 1310 690 20 0
74 660 690 20 0
69 1970 0 20 0
48 330 350 20 0
'''



def _clamp_midi(value):
    """Clamp *value* into the 7-bit MIDI data range 0..127."""
    return max(0, min(127, value))


def _parse_note_line(line):
    """Parse one ``pitch duration wait velocity instrument`` line.

    Returns a ``(pitch, duration, wait, velocity, instrument)`` tuple whose
    first four fields are ints and whose ``instrument`` is the raw token
    (``'drum'`` or a numeric program), or ``None`` when the line is not a
    well-formed note: the "Generating..." status line, the score title, the
    column header, a blank line, or a line cut off by the length limit.
    """
    fields = line.split()
    if len(fields) != 5:
        return None
    *numeric, instrument = fields
    try:
        pitch, duration, wait, velocity = (int(field) for field in numeric)
        if instrument != 'drum':
            int(instrument)  # non-drum instruments must be a program number
    except ValueError:
        return None
    return pitch, duration, wait, velocity, instrument


def postprocess(txt, path):
    """Convert the model's text output into a MIDI file written to *path*.

    Every line of *txt* shaped like ``pitch duration wait velocity instrument``
    becomes one note; every other line (status message, score title, column
    header, blank or truncated line) is skipped, so the whole chat message can
    be passed in as-is. ``duration`` and ``wait`` are milliseconds, ``wait``
    being the gap until the next note's onset; ``velocity`` is multiplied by 4.
    Pitch, velocity and program are clamped to MIDI's 0..127 range. Notes are
    grouped into one track per instrument token; ``'drum'`` makes a drum track.

    Raises:
        Whatever symusic raises if the score cannot be built or written, so
        callers can report it instead of serving a file that does not exist.
    """
    tracks = {}
    now = 0  # onset of the next note, in milliseconds
    skipped = 0
    for line in txt.split('\n'):
        note = _parse_note_line(line)
        if note is None:
            if line.strip():
                skipped += 1
            continue
        pitch, duration, wait, velocity, instrument = note
        track = tracks.get(instrument)
        if track is None:
            track = tracks[instrument] = symusic.core.TrackSecond()
            if instrument == 'drum':
                track.is_drum = True
            else:
                track.program = _clamp_midi(int(instrument))
        # Eg. Note(time=7.47, duration=5.25, pitch=43, velocity=64, ttype='Quarter')
        track.notes.append(symusic.core.NoteSecond(
            time=now / 1000,
            duration=duration / 1000,
            pitch=_clamp_midi(pitch),
            velocity=_clamp_midi(velocity * 4),
        ))
        now += wait

    if skipped:
        print(f'Postprocess: Skipped {skipped} non-note lines')
    note_count = sum(len(track.notes) for track in tracks.values())
    print(f'Postprocess: Got {note_count} notes')

    score = symusic.Score(ttype='Second')
    score.tracks.extend(tracks.values())
    score.dump_midi(path)



with gr.Blocks() as demo:
    # --- Layout -------------------------------------------------------------
    # Chat transcript: the user's prefix followed by the streamed model output,
    # displayed with markdown rendering and HTML sanitizing both turned off.
    chatbot_box = gr.Chatbot(
        type="messages",
        render_markdown=False,
        sanitize_html=False,
    )
    prefix_box = gr.TextArea(label="Score title / text prefix", value="Bach")
    with gr.Row():
        submit_btn = gr.Button("Generate")
        clear_btn = gr.Button("Clear history")
    with gr.Row():
        get_audio_btn = gr.Button("Convert to audio")
        get_midi_btn = gr.Button("Convert to MIDI")
    # Output widgets filled by the conversion handlers.
    audio_box = gr.Audio()
    midi_box = gr.File()
    piano_roll_box = gr.Image()
    # One-click prompts: a handful of score titles plus a literal note prefix.
    example_box = gr.Examples(
        [[title] for title in (
            "Chopin", "Bach", "Beethoven", "Debussy", "Nocturne", "Clair De Lune",
            "Guitar", "Boogie Woogie", "Fugue I in C major, BWV 846",
            "Beethoven Symphony No. 7 (2nd movement) Piano solo",
        )] + [[example_prefix]],
        inputs=prefix_box,
    )

    def user_fn(user_message, history: list):
        """Post *user_message* as a user turn and clear the input box.

        Returns ``("", new_history)``; ``new_history`` is a fresh list made of
        *history* plus the new user message, so *history* itself is not mutated.
        """
        user_turn = {"role": "user", "content": user_message}
        return "", [*history, user_turn]

    @spaces.GPU
    def bot_fn(history: list):
        """Stream the model's continuation of the latest message into the chat.

        Uses ``history[-1]["content"]`` as the prompt, appends an assistant
        message, and yields the updated *history* after each decoded token.
        Generation runs on a background thread that feeds decoded text through
        a queue; ``None`` on the queue means generation is over. If generation
        raises, the exception is forwarded through the queue and re-raised here
        as ``gr.Error`` rather than leaving this generator blocked on
        ``queue.get()`` forever.
        """
        prefix = history[-1]["content"]
        history.append({"role": "assistant", "content": ""})
        history[-1]["content"] += "Generating with the given prefix...\n"
        queue = Queue()

        class MyStreamer:
            # transformers calls `put` with token-id tensors (prompt included,
            # then each new token) and `end` once generation stops.
            def put(self, tokens):
                for token in tokens.flatten():
                    # skip_special_tokens decodes markers such as
                    # <|begin_of_text|> / end-of-text to "", keeping them out of
                    # the transcript and the note parser.
                    text = pipe.tokenizer.decode(token.item(), skip_special_tokens=True)
                    if text:
                        queue.put(text)

            def end(self):
                queue.put(None)

        def background_fn():
            try:
                # NOTE(review): temperature only matters if sampling is enabled
                # in the model's generation config - confirm.
                result = pipe(
                    prefix,
                    streamer=MyStreamer(),
                    max_length=3000,
                    temperature=0.95,
                )
                print('Generated text:')
                print(result[0]['generated_text'])
                print()
            except Exception as e:
                queue.put(e)  # re-raised by the consumer loop below
            finally:
                # Wake the consumer even if `end` was never reached; an extra
                # sentinel after a normal finish is simply never read.
                queue.put(None)

        Thread(target=background_fn).start()
        while True:
            item = queue.get()
            if item is None:
                break
            if isinstance(item, Exception):
                raise gr.Error(f'Generation failed: {type(item).__name__}: {item}') from item
            history[-1]["content"] += item
            yield history

    # Enter in the prefix box and the Generate button both first post the prefix
    # as a user turn (unqueued, so the box clears immediately), then stream the
    # model's reply with bot_fn.
    prefix_submit_event = prefix_box.submit(user_fn, [prefix_box, chatbot_box], [prefix_box, chatbot_box], queue=False).then(
        bot_fn, chatbot_box, chatbot_box
    )
    submit_event = submit_btn.click(user_fn, [prefix_box, chatbot_box], [prefix_box, chatbot_box], queue=False).then(
        bot_fn, chatbot_box, chatbot_box
    )
    # Clearing wipes the transcript and cancels a running generation, whichever
    # of the two triggers started it.
    clear_btn.click(lambda: None, inputs=[], outputs=chatbot_box, cancels=[prefix_submit_event, submit_event], queue=False)

    def get_audio_fn(history):
        """Render the latest chat message to an MP3 and return its path.

        The message is converted to MIDI by postprocess(), synthesized to WAV
        with the `timidity` CLI, then encoded to a 320 kbps MP3 with `lame`.
        Any failure (bad text, missing tool, non-zero exit) is reported to the
        user as ``gr.Error`` instead of returning a path that was never written.
        """
        import subprocess

        if not history:
            raise gr.Error('Nothing to convert yet: generate a score first.')
        i = random.randint(0, 1000_000_000)
        mid_path = f'./temp/{i}.mid'
        wav_path = f'./temp/{i}.wav'
        mp3_path = f'./temp/{i}.mp3'
        try:
            postprocess(history[-1]["content"], mid_path)
        except Exception as e:
            raise gr.Error(f'Error: {type(e)}, {e}') from e
        try:
            # Argument lists (no shell) with check=True so failures raise.
            subprocess.run(['timidity', mid_path, '-Ow', '-o', wav_path], check=True)
            subprocess.run(['lame', '-b', '320', wav_path, mp3_path], check=True)
        except (OSError, subprocess.CalledProcessError) as e:
            raise gr.Error(f'Audio conversion failed: {e}') from e
        finally:
            # The WAV is only an intermediate; don't let ./temp accumulate them.
            if os.path.exists(wav_path):
                os.remove(wav_path)
        return mp3_path

    get_audio_btn.click(get_audio_fn, chatbot_box, audio_box, queue=False)

    def get_midi_fn(history):
        """Convert the latest chat message to MIDI and draw its piano roll.

        Returns ``(midi_path, svg_path)``: the MIDI file written by
        postprocess() and an SVG piano roll in which each non-drum note is a
        black horizontal segment from onset to onset + duration (x in
        milliseconds, y = pitch) with a faint dot at its onset. Conversion
        failures are reported to the user as ``gr.Error``.
        """
        from matplotlib.figure import Figure

        if not history:
            raise gr.Error('Nothing to convert yet: generate a score first.')
        content = history[-1]["content"]
        i = random.randint(0, 1000_000_000)
        midi_path = f'./temp/{i}.mid'
        svg_path = f'./temp/{i}.svg'
        # turn the text into midi
        try:
            postprocess(content, midi_path)
        except Exception as e:
            raise gr.Error(f'Error: {type(e)}, {e}') from e

        # Parse "pitch duration wait velocity instrument" lines; anything else
        # (status line, title, header, truncated tail) is skipped.
        starts, ends, pitches = [], [], []
        now = 0  # onset of the current note, in milliseconds
        for line in content.split('\n'):
            fields = line.split()
            if len(fields) != 5:
                continue
            try:
                pitch, duration, wait, _velocity = (int(x) for x in fields[:4])
                is_drum = fields[4] == 'drum'
                if not is_drum:
                    int(fields[4])  # instrument must be 'drum' or a program number
            except ValueError:
                continue
            if not is_drum:
                starts.append(now)
                ends.append(now + duration)
                pitches.append(pitch)
            # Drum hits are not drawn, but their wait still advances time so the
            # notes after them are placed correctly.
            now += wait

        # A standalone Figure (not pyplot) shares no global figure registry
        # between concurrent requests and leaves nothing open after the call.
        fig = Figure(figsize=(12, 4))
        ax = fig.subplots()
        if pitches:
            ax.hlines(pitches, starts, ends, colors='black', alpha=1)
            ax.scatter(starts, pitches, s=6, color='black', alpha=0.3)
        fig.savefig(svg_path)
        return midi_path, svg_path

    get_midi_btn.click(get_midi_fn, inputs=chatbot_box, outputs=[midi_box, piano_roll_box], queue=False)

# Log the installed Gradio version (after a blank line), then start serving.
print('', gr.__version__, sep='\n')
demo.launch(show_api=False)