# llama-midi / app.py — Hugging Face Space by dx2102 (commit 1ac1ff6)
import random
import os
import time
from queue import Queue
from threading import Thread
import symusic
import transformers
import spaces
import gradio as gr
# Working directory for generated artifacts (.mid, .wav, .mp3, .svg).
os.makedirs('./temp', exist_ok=True)
print('\n\n\n')
print('Loading model...')
# Text-generation pipeline for the llama-midi model; loaded once at startup.
# device_map="auto" lets accelerate place weights on whatever device is available.
pipe = transformers.pipeline(
    "text-generation",
    model="dx2102/llama-midi",
    torch_dtype="bfloat16",
    device_map="auto",
)
print('Done')
# A ready-made note-table prefix shown in the Examples box. Format matches the
# model's training data: a header line, then one note per line as
# "pitch duration wait velocity instrument" (durations/waits in milliseconds).
example_prefix = '''pitch duration wait velocity instrument
71 1310 0 20 0
48 330 350 20 0
55 330 350 20 0
64 1310 690 20 0
74 660 690 20 0
69 1310 0 20 0
48 330 350 20 0
57 330 350 20 0
66 1310 690 20 0
67 330 350 20 0
69 330 350 20 0
71 1310 0 20 0
48 330 350 20 0
55 330 350 20 0
64 1310 690 20 0
74 660 690 20 0
69 1970 0 20 0
48 330 350 20 0
'''
def postprocess(txt, path):
    """Parse the model's generated text into a MIDI file written to `path`.

    `txt` is expected to be: a score-title line, the "pitch duration ..."
    header line, a blank line, then one note per line as
    "pitch duration wait velocity instrument" (times in milliseconds).
    Malformed lines (e.g. a truncated final line from hitting max_length)
    are skipped individually, so one bad line no longer discards every
    note that follows it.
    """
    tracks = {}  # instrument token -> symusic TrackSecond
    now = 0      # running time cursor, in milliseconds
    # Remove the first three lines. 1: Score title, 2: The "pitch duration..." hint, 3: Empty line
    for line in txt.split('\n')[3:]:
        try:
            pitch, duration, wait, velocity, instrument = line.split()
            pitch, duration, wait, velocity = (
                int(pitch), int(duration), int(wait), int(velocity))
            if instrument not in tracks:
                track = symusic.core.TrackSecond()
                if instrument == 'drum':
                    track.is_drum = True
                else:
                    track.program = int(instrument)
                # Only register the track once its setup succeeded, so a bad
                # instrument token doesn't leave an empty track behind.
                tracks[instrument] = track
            # Eg. Note(time=7.47, duration=5.25, pitch=43, velocity=64, ttype='Quarter')
            tracks[instrument].notes.append(symusic.core.NoteSecond(
                time=now/1000,
                duration=duration/1000,
                pitch=pitch,
                # Scale up the model's quiet velocities, clamped to the MIDI max.
                velocity=min(velocity * 4, 127),
            ))
            now += wait
        except Exception as e:
            # we need to ignore the invalid output by the model
            print('Postprocess: Ignored error:', e)
    print(f'Postprocess: Got {sum(len(track.notes) for track in tracks.values())} notes')
    try:
        score = symusic.Score(ttype='Second')
        score.tracks.extend(tracks.values())
        score.dump_midi(path)
    except Exception as e:
        print('Postprocess: Ignored postprocessing error:', e)
# UI layout: a chat transcript, a text prefix box, action buttons, and output
# widgets for the rendered audio, the .mid file, and a piano-roll image.
with gr.Blocks() as demo:
    # render_markdown/sanitize_html are disabled so the raw note table is shown verbatim.
    chatbot_box = gr.Chatbot(type="messages", render_markdown=False, sanitize_html=False)
    prefix_box = gr.TextArea(value="Bach", label="Score title / text prefix")
    with gr.Row():
        submit_btn = gr.Button("Generate")
        clear_btn = gr.Button("Clear history")
    with gr.Row():
        get_audio_btn = gr.Button("Convert to audio")
        get_midi_btn = gr.Button("Convert to MIDI")
    audio_box = gr.Audio()
    midi_box = gr.File()
    piano_roll_box = gr.Image()
    # Clickable example prefixes; each fills prefix_box when selected.
    example_box = gr.Examples(
        [
            ["Chopin"], ["Bach"], ["Beethoven"], ["Debussy"], ["Nocturne"], ["Clair De Lune"], ["Guitar"], ["Boogie Woogie"],
            ["Fugue I in C major, BWV 846"], ["Beethoven Symphony No. 7 (2nd movement) Piano solo"], [example_prefix],
        ],
        inputs=prefix_box,
    )
def user_fn(user_message, history: list):
    """Record the user's message in the chat history and clear the input box.

    Returns ("", new_history): the empty string resets the textbox, and the
    history gains one {"role": "user", ...} entry (the input list is not mutated).
    """
    new_history = [*history, {"role": "user", "content": user_message}]
    return "", new_history
@spaces.GPU
def bot_fn(history: list):
    # Generator: stream a continuation of the last user message into a new
    # assistant turn, yielding the updated history after every token so the
    # Gradio chat re-renders incrementally.
    prefix = history[-1]["content"]
    history.append({"role": "assistant", "content": ""})
    history[-1]["content"] += "Generating with the given prefix...\n"
    # Hand-off channel between the generation worker thread and this generator.
    queue = Queue()
    class MyStreamer:
        # Minimal transformers streamer protocol: put() receives token ids,
        # end() signals completion.
        def put(self, tokens):
            for token in tokens.flatten():
                text = pipe.tokenizer.decode(token.item())
                if text == '<|begin_of_text|>':
                    # skip the BOS marker so it never appears in the chat
                    continue
                queue.put(text)
        def end(self):
            # None is the end-of-stream sentinel consumed by the loop below.
            queue.put(None)
    def background_fn():
        # Run the (blocking) pipeline call on a worker thread; tokens flow
        # back through `queue` while this function is still executing.
        result = pipe(
            prefix,
            streamer=MyStreamer(),
            max_length=3000,
            temperature=0.95,
        )
        print('Generated text:')
        print(result[0]['generated_text'])
        print()
    Thread(target=background_fn).start()
    # Drain the queue until the sentinel, appending each text fragment to the
    # assistant message and yielding the partial history to the UI.
    while True:
        text = queue.get()
        if text is None:
            break
        history[-1]["content"] += text
        yield history
# Pressing Enter in the prefix box or clicking "Generate" both do the same
# two-step chain: user_fn echoes the prefix into the chat immediately
# (queue=False), then bot_fn streams the model's continuation.
prefix_box.submit(user_fn, [prefix_box, chatbot_box], [prefix_box, chatbot_box], queue=False).then(
    bot_fn, chatbot_box, chatbot_box
)
submit_event = submit_btn.click(user_fn, [prefix_box, chatbot_box], [prefix_box, chatbot_box], queue=False).then(
    bot_fn, chatbot_box, chatbot_box
)
# "Clear history" empties the chat and cancels any in-flight generation chain.
clear_btn.click(lambda: None, inputs=[], outputs=chatbot_box, cancels=[submit_event], queue=False)
def get_audio_fn(history):
    """Render the last assistant message to an mp3 and return its path.

    Pipeline: text -> .mid (postprocess) -> .wav (timidity) -> .mp3 (lame).
    Raises gr.Error if the midi conversion fails or the mp3 was not produced
    (e.g. timidity/lame missing), instead of silently returning a dead path.
    """
    i = random.randint(0, 1000_000_000)
    path = f'./temp/{i}.mid'
    try:
        postprocess(history[-1]["content"], path)
    except Exception as e:
        raise gr.Error(f'Error: {type(e)}, {e}')
    # turn midi into audio with timidity
    os.system(f'timidity ./temp/{i}.mid -Ow -o ./temp/{i}.wav')
    # wav to mp3
    os.system(f'lame -b 320 ./temp/{i}.wav ./temp/{i}.mp3')
    # os.system exit codes are unreliable across shells; check the artifact itself.
    if not os.path.exists(f'./temp/{i}.mp3'):
        raise gr.Error('Audio rendering failed: timidity/lame did not produce an mp3')
    return f'./temp/{i}.mp3'
get_audio_btn.click(get_audio_fn, chatbot_box, audio_box, queue=False)
def get_midi_fn(history):
    """Convert the last assistant message to a .mid file plus a piano-roll SVG.

    Returns (midi_path, svg_path). Raises gr.Error if the text could not be
    turned into MIDI.
    """
    i = random.randint(0, 1000_000_000)
    # turn the text into midi
    try:
        postprocess(history[-1]["content"], f'./temp/{i}.mid')
    except Exception as e:
        raise gr.Error(f'Error: {type(e)}, {e}')
    # also render the piano roll
    import matplotlib.pyplot as plt
    fig = plt.figure(figsize=(12, 4))
    now = 0
    # The note table is the last blank-line-separated paragraph of the message.
    for line in history[-1]["content"].split('\n\n')[-1].split('\n'):
        try:
            pitch, duration, wait, velocity, instrument = [int(x) for x in line.split()]
        except ValueError:
            # header line, 'drum' instrument, or a truncated row — skip it
            continue
        plt.plot([now, now + duration], [pitch, pitch], color='black', alpha=1)
        plt.scatter(now, pitch, s=6, color='black', alpha=0.3)
        now += wait
    plt.savefig(f'./temp/{i}.svg')
    # Close the figure explicitly: the original leaked one figure per click,
    # which matplotlib accumulates in memory across requests.
    plt.close(fig)
    return f'./temp/{i}.mid', f'./temp/{i}.svg'
get_midi_btn.click(get_midi_fn, inputs=chatbot_box, outputs=[midi_box, piano_roll_box], queue=False)
print()
# Log the gradio version for debugging Space builds, then start the app.
print(gr.__version__)
# show_api=False hides the auto-generated API docs page.
demo.launch(show_api=False)