File size: 5,508 Bytes
1db4f6f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
import random
import os
import time
from queue import Queue
from threading import Thread

import symusic

import transformers
import gradio as gr

print('\n\n\n')
print('Loading model...')
# Text-generation pipeline shared by every request. The checkpoint defaults to
# dx2102/llama-midi; set LLAMA_MIDI_MODEL to try another one without editing
# this file (e.g. LLAMA_MIDI_MODEL=openai-community/gpt2 for a quick UI smoke test).
pipe = transformers.pipeline(
    "text-generation",
    model=os.environ.get("LLAMA_MIDI_MODEL", "dx2102/llama-midi"),
    torch_dtype="bfloat16",
    device_map="auto",
)
print('Done')

# Prompt used when the user's message is not itself a note table: a header
# line, a blank separator, then one "pitch duration wait" row per note.
# Durations and waits are integers that postprocess() divides by 1000 to get seconds.
default_prefix = '''pitch duration wait

71 1310 0
48 330 350
55 330 350
64 1310 690
74 660 690
69 1310 0
48 330 350
57 330 350
66 1310 690
67 330 350
69 330 350
71 1310 0
48 330 350
55 330 350
64 1310 690
74 660 690
69 1970 0
48 330 350
'''
# Number of note rows, i.e. every line after the header and its blank separator.
_, _default_note_rows = default_prefix.split('\n\n', 1)
default_prefix_len = len(_default_note_rows.splitlines())
del _default_note_rows


def postprocess(txt, path):
    # saves the text representation to a midi file
    txt = txt.split('\n\n')[-1]

    notes = []
    now = 0
    # we need to ignore the invalid output by the model
    try:
        for line in txt.split('\n'):
            pitch, duration, wait = [int(x) for x in line.split()]
            # Eg. Note(time=7.47, duration=5.25, pitch=43, velocity=64, ttype='Quarter')
            notes.append(symusic.core.NoteSecond(
                time=now/1000,
                duration=duration/1000,
                pitch=int(pitch),
                velocity=80,
            ))
            now += wait
    except Exception as e:
        print('Ignored error:', e)

    try:
        track = symusic.core.TrackSecond()
        track.notes = symusic.core.NoteSecondList(notes)
        score = symusic.Score(ttype='Second')
        score.tracks.append(track)
        score.dump_midi(path)
    except Exception as e:
        print('Ignored error:', e)

with gr.Blocks() as demo:
    # ---- Layout ------------------------------------------------------------
    # The chat transcript also holds the generated note table; both "Convert"
    # buttons read the newest message in it.
    chatbot_box = gr.Chatbot(type="messages", render_markdown=False, sanitize_html=False)
    prefix_box = gr.Textbox(value=default_prefix, label="prefix")
    with gr.Row():
        submit_btn = gr.Button("Submit")
        clear_btn = gr.Button("Clear")
    with gr.Row():
        get_audio_btn = gr.Button("Convert to audio")
        get_midi_btn = gr.Button("Convert to MIDI")
    audio_box = gr.Audio()
    midi_box = gr.File()
    piano_roll_box = gr.Image()

    def _last_message(history):
        """Return the newest chat message's text, or raise a user-facing error if the chat is empty."""
        if not history:
            raise gr.Error('Nothing to convert yet - generate some notes first.')
        return history[-1]["content"]

    def _new_temp_stem():
        """Return a fresh './temp/<random int>' path stem, creating ./temp if it is missing."""
        os.makedirs('./temp', exist_ok=True)
        return f'./temp/{random.randint(0, 1000_000_000)}'

    def user_fn(user_message, history: list):
        """Clear the textbox and append its text to the chat as a user message."""
        return "", (history or []) + [{"role": "user", "content": user_message}]

    def bot_fn(history: list):
        """Stream a model continuation of the last user message into a new assistant message.

        A message that already starts with the note-table header is used as the
        prompt as-is. Anything else is replaced by ``default_prefix``.
        Generation runs on a background thread. Decoded tokens (including the
        echoed prompt) arrive through a queue and are yielded to the UI one at a time.
        """
        prefix = history[-1]["content"]
        history.append({"role": "assistant", "content": ""})
        if prefix.startswith("pitch duration wait\n\n"):
            history[-1]["content"] += "Generating with the given prefix...\n"
        else:
            history[-1]["content"] += f"Generating from scratch with a default prefix of {default_prefix_len} notes...\n"
            prefix = default_prefix
        token_queue = Queue()

        class MyStreamer:
            """Minimal generation streamer: forwards decoded tokens to the queue, then a None sentinel."""

            def put(self, tokens):
                for token in tokens.flatten():
                    text = pipe.tokenizer.decode(token.item())
                    if text == '<|begin_of_text|>':
                        continue
                    token_queue.put(text)

            def end(self):
                token_queue.put(None)

        def background_fn():
            try:
                result = pipe(
                    prefix,
                    streamer=MyStreamer(),
                    max_length=1000,
                    temperature=0.95,
                )
                print('Generated text:')
                print(result[0]['generated_text'])
                print()
            except Exception as e:
                import traceback
                traceback.print_exc()
                # Hand the failure to the consumer loop so it can report it.
                token_queue.put(e)
            finally:
                # Always terminate the consumer loop, even when generation failed
                # before streamer.end() ran. A second None after a normal end()
                # is harmless because the loop stops at the first one.
                token_queue.put(None)

        Thread(target=background_fn).start()
        while True:
            item = token_queue.get()
            if item is None:
                break
            if isinstance(item, Exception):
                raise gr.Error(f'Generation failed: {type(item).__name__}: {item}')
            history[-1]["content"] += item
            yield history

    prefix_box.submit(user_fn, [prefix_box, chatbot_box], [prefix_box, chatbot_box], queue=False).then(
        bot_fn, chatbot_box, chatbot_box
    )
    submit_btn.click(user_fn, [prefix_box, chatbot_box], [prefix_box, chatbot_box], queue=False).then(
        bot_fn, chatbot_box, chatbot_box
    )
    clear_btn.click(lambda: None, None, chatbot_box, queue=False)

    def get_audio_fn(history):
        """Render the newest chat message to MP3: text -> MIDI -> WAV (timidity) -> MP3 (lame)."""
        import subprocess

        content = _last_message(history)
        stem = _new_temp_stem()
        mid_path, wav_path, mp3_path = f'{stem}.mid', f'{stem}.wav', f'{stem}.mp3'
        try:
            postprocess(content, mid_path)
        except Exception as e:
            raise gr.Error(f'Error: {type(e)}, {e}') from e
        try:
            # The commands are passed as argument lists, so no shell is involved.
            # check=True makes a missing binary or a failed conversion raise an
            # error, instead of returning the path of a file that was never written.
            subprocess.run(['timidity', mid_path, '-Ow', '-o', wav_path], check=True)
            subprocess.run(['lame', '-b', '320', wav_path, mp3_path], check=True)
        except (OSError, subprocess.CalledProcessError) as e:
            raise gr.Error(f'Audio conversion failed: {e}') from e
        finally:
            # The WAV is only lame's input; delete it so it does not pile up in ./temp.
            if os.path.exists(wav_path):
                os.remove(wav_path)
        return mp3_path

    get_audio_btn.click(get_audio_fn, chatbot_box, audio_box, queue=False)

    def get_midi_fn(history):
        """Write the newest chat message to a MIDI file and draw a piano-roll SVG of the same rows."""
        # Use matplotlib's object-oriented Figure API instead of pyplot. No figure
        # is registered in pyplot's global state, so nothing accumulates across
        # clicks. matplotlib also recommends this API for web-server use.
        from matplotlib.figure import Figure

        content = _last_message(history)
        stem = _new_temp_stem()
        mid_path, svg_path = f'{stem}.mid', f'{stem}.svg'
        try:
            postprocess(content, mid_path)
        except Exception as e:
            raise gr.Error(f'Error: {type(e)}, {e}') from e

        fig = Figure(figsize=(12, 4))
        ax = fig.subplots()
        now = 0
        for line in content.split('\n\n')[-1].split('\n'):
            try:
                pitch, duration, wait = [int(x) for x in line.split()]
            except ValueError:
                # Header, blank, truncated or otherwise malformed row.
                continue
            # One horizontal segment per note: x = time, y = pitch.
            ax.plot([now, now + duration], [pitch, pitch], color='black')
            now += wait
        fig.savefig(svg_path)
        return mid_path, svg_path

    get_midi_btn.click(get_midi_fn, inputs=chatbot_box, outputs=[midi_box, piano_roll_box], queue=False)


# Start the web server only when run as a script, so importing this module
# (e.g. from tests or tooling) does not block inside launch().
if __name__ == "__main__":
    demo.launch()