"""Gradio demo that streams symbolic music text from the dx2102/llama-midi model.

The generated text contains one note per line in the form
``pitch duration wait velocity instrument``. The UI can convert the latest
generation into a MIDI file, a rendered audio file and a piano-roll image.
"""

import random
import os
import time
from queue import Queue
from threading import Thread

import symusic

import transformers

import spaces
import gradio as gr

# Scratch directory that receives the generated .mid/.wav/.mp3/.svg files.
os.makedirs('./temp', exist_ok=True)

# Load the text-generation pipeline once, at import time, so every request
# handler below shares the same `pipe` object.
print('\n\n\n')
print('Loading model...')
pipe = transformers.pipeline(
    "text-generation", 
    model="dx2102/llama-midi",
    # revision="c303c108399aba837146e893375849b918f413b3",
    torch_dtype="bfloat16",
    device="cuda",
)
print('Done')

# Sample prompt in the note-list text format: a header line naming the five
# columns, a blank line, then one note per line. It is offered as the first
# entry of the examples list in the UI.
example_prefix = '''pitch duration wait velocity instrument

71 1310 0 20 0
48 330 350 20 0
55 330 350 20 0
64 1310 690 20 0
74 660 690 20 0
69 1310 0 20 0
48 330 350 20 0
57 330 350 20 0
66 1310 690 20 0
67 330 350 20 0
69 330 350 20 0
71 1310 0 20 0
48 330 350 20 0
55 330 350 20 0
64 1310 690 20 0
74 660 690 20 0
69 1970 0 20 0
48 330 350 20 0
'''



def postprocess(txt, path):
    """Convert model-generated note text into a MIDI file written to ``path``.

    Only the text after the last blank line of ``txt`` is parsed, which drops
    any prompt/prefix. Each remaining line is expected to look like
    ``pitch duration wait velocity instrument``: the first four fields are
    integers (``duration`` and ``wait`` are divided by 1000 to obtain seconds)
    and ``instrument`` is either a program number or the literal ``drum``.
    Notes are grouped into one track per instrument; ``wait`` advances the
    running clock after each accepted note.

    The model output is untrusted, so the conversion is best-effort: a line
    that cannot be turned into a note is logged and skipped without advancing
    the clock, and a failure while building or writing the score is logged
    instead of raised.

    Args:
        txt: Generated text, optionally preceded by a prefix.
        path: Destination path of the MIDI file.

    Returns:
        None.
    """
    # remove prefix
    note_text = txt.split('\n\n')[-1]
    tracks = {}

    now = 0
    for line in note_text.split('\n'):
        # Blank lines (e.g. the trailing newline) carry no note; skip them
        # quietly instead of logging them as errors.
        if not line.strip():
            continue
        # we need to ignore the invalid output by the model
        try:
            pitch, duration, wait, velocity, instrument = line.split()
            pitch, duration, wait, velocity = (
                int(x) for x in (pitch, duration, wait, velocity)
            )
            track = tracks.get(instrument)
            if track is None:
                # Configure the track completely BEFORE registering it, so an
                # invalid instrument name never leaves a half-initialised
                # track behind for later lines to reuse.
                track = symusic.core.TrackSecond()
                if instrument != 'drum':
                    track.program = int(instrument)
                else:
                    track.is_drum = True
                tracks[instrument] = track
            # Eg. Note(time=7.47, duration=5.25, pitch=43, velocity=64, ttype='Quarter')
            track.notes.append(symusic.core.NoteSecond(
                time=now / 1000,
                duration=duration / 1000,
                pitch=pitch,
                # The text velocity is scaled by 4; MIDI velocities are 7-bit,
                # so clamp the result to 0..127.
                velocity=max(0, min(127, velocity * 4)),
            ))
            now += wait
        except Exception as e:
            print(f'Postprocess: Ignored line: "{line}" because of error:', e)

    print(f'Postprocess: Got {sum(len(track.notes) for track in tracks.values())} notes')

    try:
        score = symusic.Score(ttype='Second')
        score.tracks.extend(tracks.values())
        score.dump_midi(path)
    except Exception as e:
        print('Postprocess: Ignored postprocessing error:', e)



# Build the UI. Components are declared in the order they should appear, so
# the statement order below is part of the layout.
with gr.Blocks() as demo:
    # Conversation view; shows the raw generated text (markdown rendering off).
    chatbot_box = gr.Chatbot(type="messages", render_markdown=False, sanitize_html=False)
    # Free-text prompt: either a score title or a partial note list.
    prefix_box = gr.TextArea(value="Twinkle Twinkle Little Star", label="Score title / text prefix")
    with gr.Row():
        submit_btn = gr.Button("Generate")
        clear_btn = gr.Button("Clear history")
    with gr.Row():
        get_audio_btn = gr.Button("Convert to audio")
        get_midi_btn = gr.Button("Convert to MIDI")
    # Output widgets filled by the two "Convert" buttons.
    audio_box = gr.Audio()
    midi_box = gr.File()
    piano_roll_box = gr.Image()
    # Clickable sample prompts that populate `prefix_box`.
    example_box = gr.Examples(
        [
            [example_prefix],
            ["Twinkle Twinkle Little Star"], ["Twinkle Twinkle Little Star (Minor Key Version)"],
            ["The Entertainer - Scott Joplin (Piano Solo)"], ["Clair de Lune – Debussy"], ["Nocturne | Frederic Chopin"],
            ["Fugue I in C major, BWV 846"], ["Beethoven Symphony No. 7 (2nd movement) Piano solo"], 
            ["Guitar"], 
            # ["Composer: Chopin"], ["Composer: Bach"], ["Composer: Beethoven"], ["Composer: Debussy"], 
        ],
        inputs=prefix_box,
        examples_per_page=9999,
    )

    def user_fn(user_message, history: list):
        """Record the submitted text as a new user turn.

        Returns a pair: an empty string (used to blank the input box this
        handler is wired to) and a new history list ending with the user's
        message. The list passed in is left unmodified.
        """
        user_turn = {"role": "user", "content": user_message}
        extended_history = history + [user_turn]
        return "", extended_history

    @spaces.GPU
    def bot_fn(history: list):
        """Stream the model's continuation of the latest message in ``history``.

        The last history entry is used as the generation prefix. Generation
        runs in a background thread whose streamer decodes tokens one by one
        and hands them over through a bounded queue; this generator appends
        each piece to a new assistant message and yields the history so the
        UI updates incrementally.

        Args:
            history: Chat messages as ``{"role": ..., "content": ...}`` dicts;
                it is extended in place with the assistant message.

        Yields:
            The updated history after every received piece of text.
        """
        # Imported locally: the module only imports `Queue` from `queue`, and
        # the exception class is needed by name below.
        from queue import Full

        prefix = history[-1]["content"]
        # prevent the model from continuing user's score title
        if prefix != '' and '\n' not in prefix:
            # prefix is a single line --> prefix is the score title
            prefix += '\n'

        history.append({"role": "assistant", "content": "Generating with the given prefix...\n"})
        # Bounded so the producer cannot run far ahead of the UI. `None` is
        # the end-of-stream sentinel.
        token_queue = Queue(maxsize=10)

        class MyStreamer:
            """Streamer object passed to the pipeline; forwards decoded text."""

            def put(self, tokens):
                for token in tokens.flatten():
                    text = pipe.tokenizer.decode(token.item())
                    if text == '<|begin_of_text|>':
                        continue
                    # Raises `Full` if nobody consumes for 5 s (for example
                    # the request was cancelled), which aborts generation.
                    token_queue.put(text, block=True, timeout=5)

            def end(self):
                token_queue.put(None, block=True, timeout=5)

        def background_fn():
            try:
                result = pipe(
                    prefix,
                    streamer=MyStreamer(),
                    max_length=3000,
                    temperature=0.95,
                )
            except Full:
                print("Queue is full. Exiting.")
            else:
                print('Generated text:')
                print(result[0]['generated_text'])
                print()
            finally:
                # Always deliver the sentinel, even when generation failed,
                # so the consumer loop below can never block forever. A
                # duplicate sentinel after a normal `end()` is harmless.
                try:
                    token_queue.put(None, block=True, timeout=5)
                except Full:
                    # The consumer is gone; nobody is waiting for the sentinel.
                    pass

        Thread(target=background_fn).start()
        while True:
            text = token_queue.get()
            if text is None:
                break
            history[-1]["content"] += text
            yield history

    # Generation can be started either from the text box's submit event or
    # from the "Generate" button. Both first record the user turn, then
    # stream the model output into the chatbot.
    enter_event = prefix_box.submit(user_fn, [prefix_box, chatbot_box], [prefix_box, chatbot_box], queue=False).then(
        bot_fn, chatbot_box, chatbot_box
    )
    submit_event = submit_btn.click(user_fn, [prefix_box, chatbot_box], [prefix_box, chatbot_box], queue=False).then(
        bot_fn, chatbot_box, chatbot_box
    )
    # Clearing the history must stop a running generation no matter which of
    # the two triggers started it, otherwise the stream would keep writing
    # into the chatbot that was just cleared.
    clear_btn.click(lambda: None, inputs=[], outputs=chatbot_box, cancels=[enter_event, submit_event], queue=False)

    def get_audio_fn(history):
        """Render the latest chatbot message to an MP3 file and return its path.

        The message text is converted to MIDI with ``postprocess``, the MIDI
        is rendered to WAV with the external ``timidity`` program, and the WAV
        is encoded to MP3 with the external ``lame`` program. All files are
        written to ``./temp`` under a random numeric name.

        Args:
            history: Chat messages; the content of the last one is converted.

        Returns:
            Path of the generated MP3 file.

        Raises:
            gr.Error: If the text conversion raises, or if an external tool
                is missing or exits with a non-zero status.
        """
        # Local import: `subprocess` is not imported at module level.
        import subprocess

        i = random.randint(0, 1000_000_000)
        midi_path = f'./temp/{i}.mid'
        wav_path = f'./temp/{i}.wav'
        mp3_path = f'./temp/{i}.mp3'
        try:
            postprocess(history[-1]["content"], midi_path)
        except Exception as e:
            raise gr.Error(f'Error: {type(e)}, {e}') from e
        try:
            # turn midi into audio with timidity
            subprocess.run(['timidity', midi_path, '-Ow', '-o', wav_path], check=True)
            # wav to mp3
            subprocess.run(['lame', '-b', '320', wav_path, mp3_path], check=True)
        except (OSError, subprocess.CalledProcessError) as e:
            # Report the failure instead of returning a path to a file that
            # was never produced.
            raise gr.Error(f'Error: {type(e)}, {e}') from e
        return mp3_path

    get_audio_btn.click(get_audio_fn, chatbot_box, audio_box, queue=False)

    def get_midi_fn(history):
        """Convert the latest chatbot message to a MIDI file and a piano roll.

        The MIDI file is produced by ``postprocess``. The piano roll re-parses
        the same text: every note becomes a horizontal segment from its start
        time to start + duration at the height of its pitch, plus a marker at
        the onset. Lines that are not valid notes are skipped and do not
        advance the clock.

        Args:
            history: Chat messages; the content of the last one is converted.

        Returns:
            A ``(midi_path, image_path)`` tuple of files under ``./temp``.

        Raises:
            gr.Error: If the text-to-MIDI conversion raises.
        """
        i = random.randint(0, 1000_000_000)
        midi_path = f'./temp/{i}.mid'
        image_path = f'./temp/{i}.svg'
        # turn the text into midi
        try:
            postprocess(history[-1]["content"], midi_path)
        except Exception as e:
            raise gr.Error(f'Error: {type(e)}, {e}') from e
        # also render the piano roll
        import matplotlib.pyplot as plt
        # Draw on an explicit figure/axes rather than pyplot's implicit
        # "current figure", and close it afterwards so figures do not
        # accumulate across requests.
        fig = plt.figure(figsize=(12, 4))
        try:
            ax = fig.add_subplot(111)
            now = 0
            for line in history[-1]["content"].split('\n\n')[-1].split('\n'):
                try:
                    pitch, duration, wait, velocity, instrument = line.split()
                    pitch, duration, wait, velocity = (
                        int(x) for x in (pitch, duration, wait, velocity)
                    )
                    # The instrument is a program number or the literal
                    # 'drum'. Drum lines must be kept so that their `wait`
                    # still advances the clock, keeping the picture in sync
                    # with the MIDI file.
                    if instrument != 'drum':
                        int(instrument)  # validation only
                except ValueError:
                    continue
                ax.plot([now, now + duration], [pitch, pitch], color='black', alpha=1)
                ax.scatter(now, pitch, s=6, color='black', alpha=0.3)
                now += wait
            fig.savefig(image_path)
        finally:
            plt.close(fig)
        return midi_path, image_path

    get_midi_btn.click(get_midi_fn, inputs=chatbot_box, outputs=[midi_box, piano_roll_box], queue=False)

# Start serving the interface assembled above.
demo.launch()