dx2102 commited on
Commit
1db4f6f
·
verified ·
1 Parent(s): a4ba5ee

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +176 -0
app.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import os
3
+ import time
4
+ from queue import Queue
5
+ from threading import Thread
6
+
7
+ import symusic
8
+
9
+ import transformers
10
+ import gradio as gr
11
+
12
+ print('\n\n\n')
13
+ print('Loading model...')
14
+ # pipe = transformers.pipeline("text-generation", model="openai-community/gpt2")
15
+ pipe = transformers.pipeline(
16
+ "text-generation",
17
+ model="dx2102/llama-midi",
18
+ torch_dtype="bfloat16",
19
+ device_map="auto",
20
+ )
21
+ print('Done')
22
+
23
+ default_prefix = '''pitch duration wait
24
+
25
+ 71 1310 0
26
+ 48 330 350
27
+ 55 330 350
28
+ 64 1310 690
29
+ 74 660 690
30
+ 69 1310 0
31
+ 48 330 350
32
+ 57 330 350
33
+ 66 1310 690
34
+ 67 330 350
35
+ 69 330 350
36
+ 71 1310 0
37
+ 48 330 350
38
+ 55 330 350
39
+ 64 1310 690
40
+ 74 660 690
41
+ 69 1970 0
42
+ 48 330 350
43
+ '''
44
+ default_prefix_len = default_prefix.count('\n') - 2
45
+
46
+
47
+ def postprocess(txt, path):
48
+ # saves the text representation to a midi file
49
+ txt = txt.split('\n\n')[-1]
50
+
51
+ notes = []
52
+ now = 0
53
+ # we need to ignore the invalid output by the model
54
+ try:
55
+ for line in txt.split('\n'):
56
+ pitch, duration, wait = [int(x) for x in line.split()]
57
+ # Eg. Note(time=7.47, duration=5.25, pitch=43, velocity=64, ttype='Quarter')
58
+ notes.append(symusic.core.NoteSecond(
59
+ time=now/1000,
60
+ duration=duration/1000,
61
+ pitch=int(pitch),
62
+ velocity=80,
63
+ ))
64
+ now += wait
65
+ except Exception as e:
66
+ print('Ignored error:', e)
67
+
68
+ try:
69
+ track = symusic.core.TrackSecond()
70
+ track.notes = symusic.core.NoteSecondList(notes)
71
+ score = symusic.Score(ttype='Second')
72
+ score.tracks.append(track)
73
+ score.dump_midi(path)
74
+ except Exception as e:
75
+ print('Ignored error:', e)
76
+
77
+ with gr.Blocks() as demo:
78
+ chatbot_box = gr.Chatbot(type="messages", render_markdown=False, sanitize_html=False)
79
+ prefix_box = gr.Textbox(value=default_prefix, label="prefix")
80
+ with gr.Row():
81
+ submit_btn = gr.Button("Submit")
82
+ clear_btn = gr.Button("Clear")
83
+ with gr.Row():
84
+ get_audio_btn = gr.Button("Convert to audio")
85
+ get_midi_btn = gr.Button("Convert to MIDI")
86
+ audio_box = gr.Audio()
87
+ midi_box = gr.File()
88
+ piano_roll_box = gr.Image()
89
+
90
+ def user_fn(user_message, history: list):
91
+ return "", history + [{"role": "user", "content": user_message}]
92
+
93
+ def bot_fn(history: list):
94
+ prefix = history[-1]["content"]
95
+ history.append({"role": "assistant", "content": ""})
96
+ if prefix.startswith("pitch duration wait\n\n"):
97
+ history[-1]["content"] += "Generating with the given prefix...\n"
98
+ else:
99
+ history[-1]["content"] += f"Generating from scratch with a default prefix of {default_prefix_len} notes...\n"
100
+ prefix = default_prefix
101
+ queue = Queue()
102
+ class MyStreamer:
103
+ def put(self, tokens):
104
+ for token in tokens.flatten():
105
+ text = pipe.tokenizer.decode(token.item())
106
+ if text == '<|begin_of_text|>':
107
+ continue
108
+ queue.put(text)
109
+ def end(self):
110
+ queue.put(None)
111
+ def background_fn():
112
+ result = pipe(
113
+ prefix,
114
+ streamer=MyStreamer(),
115
+ max_length=1000,
116
+ temperature=0.95,
117
+ )
118
+ print('Generated text:')
119
+ print(result[0]['generated_text'])
120
+ print()
121
+ Thread(target=background_fn).start()
122
+ while True:
123
+ text = queue.get()
124
+ if text is None:
125
+ break
126
+ history[-1]["content"] += text
127
+ yield history
128
+
129
+ prefix_box.submit(user_fn, [prefix_box, chatbot_box], [prefix_box, chatbot_box], queue=False).then(
130
+ bot_fn, chatbot_box, chatbot_box
131
+ )
132
+ submit_btn.click(user_fn, [prefix_box, chatbot_box], [prefix_box, chatbot_box], queue=False).then(
133
+ bot_fn, chatbot_box, chatbot_box
134
+ )
135
+ clear_btn.click(lambda: None, None, chatbot_box, queue=False)
136
+
137
+ def get_audio_fn(history):
138
+ i = random.randint(0, 1000_000_000)
139
+ path = f'./temp/{i}.mid'
140
+ try:
141
+ postprocess(history[-1]["content"], path)
142
+ except Exception as e:
143
+ raise gr.Error(f'Error: {type(e)}, {e}')
144
+ # turn midi into audio with timidity
145
+ os.system(f'timidity ./temp/{i}.mid -Ow -o ./temp/{i}.wav')
146
+ # wav to mp3
147
+ os.system(f'lame -b 320 ./temp/{i}.wav ./temp/{i}.mp3')
148
+ return f'./temp/{i}.mp3'
149
+
150
+ get_audio_btn.click(get_audio_fn, chatbot_box, audio_box, queue=False)
151
+
152
+ def get_midi_fn(history):
153
+ i = random.randint(0, 1000_000_000)
154
+ # turn the text into midi
155
+ try:
156
+ postprocess(history[-1]["content"], f'./temp/{i}.mid')
157
+ except Exception as e:
158
+ raise gr.Error(f'Error: {type(e)}, {e}')
159
+ # also render the piano roll
160
+ import matplotlib.pyplot as plt
161
+ plt.figure(figsize=(12, 4))
162
+ now = 0
163
+ for line in history[-1]["content"].split('\n\n')[-1].split('\n'):
164
+ try:
165
+ pitch, duration, wait = [int(x) for x in line.split()]
166
+ except Exception as e:
167
+ continue
168
+ plt.plot([now, now+duration], [pitch, pitch], color='black')
169
+ now += wait
170
+ plt.savefig(f'./temp/{i}.svg')
171
+ return f'./temp/{i}.mid', f'./temp/{i}.svg'
172
+
173
+ get_midi_btn.click(get_midi_fn, inputs=chatbot_box, outputs=[midi_box, piano_roll_box], queue=False)
174
+
175
+
176
+ demo.launch()