# Spaces: Running on Zero
import os
import random
import subprocess
import time
from queue import Full, Queue
from threading import Thread

import gradio as gr
import spaces
import symusic
import transformers
# Output directory for every generated .mid / .wav / .mp3 / .svg file.
os.makedirs("./temp", exist_ok=True)

print("\n\n\n")
print("Loading model...")
# Build the text-generation pipeline once, at import time; request handlers reuse `pipe`.
# To pin an exact model snapshot, pass revision="c303c108399aba837146e893375849b918f413b3".
pipe = transformers.pipeline(
    task="text-generation",
    model="dx2102/llama-midi",
    device="cuda",
    torch_dtype="bfloat16",
)
print("Done")
# Example prompt shown in the UI: a header naming the five columns, then one note per row
# in the "pitch duration wait velocity instrument" format that postprocess() parses.
# The string ends with a trailing newline.
example_prefix = "\n".join([
    "pitch duration wait velocity instrument",
    "71 1310 0 20 0",
    "48 330 350 20 0",
    "55 330 350 20 0",
    "64 1310 690 20 0",
    "74 660 690 20 0",
    "69 1310 0 20 0",
    "48 330 350 20 0",
    "57 330 350 20 0",
    "66 1310 690 20 0",
    "67 330 350 20 0",
    "69 330 350 20 0",
    "71 1310 0 20 0",
    "48 330 350 20 0",
    "55 330 350 20 0",
    "64 1310 690 20 0",
    "74 660 690 20 0",
    "69 1970 0 20 0",
    "48 330 350 20 0",
]) + "\n"
def postprocess(txt, path):
    """Convert model output text into a MIDI file at ``path`` (best effort).

    Only the text after the last blank line is used; anything before it (the prompt /
    score title) is dropped. Each remaining line should read
    ``pitch duration wait velocity instrument``: the first four fields are integers,
    ``duration`` and ``wait`` are divided by 1000 to get seconds for symusic's
    Second-based tracks, and ``velocity`` is multiplied by 4. ``instrument`` is either
    ``drum`` or an integer MIDI program; notes are grouped into one track per distinct
    instrument token.

    Malformed lines are printed and skipped without advancing time, and any error while
    building or writing the score is printed and swallowed, so this function does not
    raise on bad model output (``path`` may then simply not be written).
    """
    # remove prefix
    txt = txt.split('\n\n')[-1]
    tracks = {}
    now = 0  # start time of the next note, in the same units as `wait`
    for line in txt.split('\n'):
        # we need to ignore the invalid output by the model
        try:
            pitch, duration, wait, velocity, instrument = line.split()
            pitch, duration, wait, velocity = [int(x) for x in [pitch, duration, wait, velocity]]
            # Validate the instrument *before* registering a track, so a bad token cannot
            # leave an empty track behind that later lines with the same token would
            # silently reuse with the default program.
            program = None if instrument == 'drum' else int(instrument)
            if instrument not in tracks:
                track = symusic.core.TrackSecond()
                if program is None:
                    track.is_drum = True
                else:
                    track.program = program
                tracks[instrument] = track
            # NOTE(review): velocity * 4 exceeds the MIDI maximum of 127 once the token is
            # >= 32; such notes presumably fail inside symusic and are skipped — confirm.
            tracks[instrument].notes.append(symusic.core.NoteSecond(
                time=now / 1000,
                duration=duration / 1000,
                pitch=pitch,
                velocity=velocity * 4,
            ))
            now += wait
        except Exception as e:
            print(f'Postprocess: Ignored line: "{line}" because of error:', e)
    print(f'Postprocess: Got {sum(len(track.notes) for track in tracks.values())} notes')
    try:
        score = symusic.Score(ttype='Second')
        score.tracks.extend(tracks.values())
        score.dump_midi(path)
    except Exception as e:
        print('Postprocess: Ignored postprocessing error:', e)
def _parse_note_line(line):
    """Parse one "pitch duration wait velocity instrument" line of model output.

    Returns (pitch, duration, wait, velocity, instrument) with the first four as ints and
    instrument as the raw token ('drum' or an integer string), or None when the line does
    not have exactly five fields or a field fails to parse. Uses the same field rules as
    postprocess(), so the piano roll and the MIDI file agree on which lines are notes.
    """
    fields = line.split()
    if len(fields) != 5:
        return None
    *numbers, instrument = fields
    try:
        pitch, duration, wait, velocity = (int(x) for x in numbers)
        if instrument != 'drum':
            int(instrument)  # must be a numeric MIDI program
    except ValueError:
        return None
    return pitch, duration, wait, velocity, instrument


with gr.Blocks() as demo:
    chatbot_box = gr.Chatbot(type="messages", render_markdown=False, sanitize_html=False)
    prefix_box = gr.TextArea(value="Twinkle Twinkle Little Star", label="Score title / text prefix")
    with gr.Row():
        submit_btn = gr.Button("Generate")
        clear_btn = gr.Button("Clear history")
    with gr.Row():
        get_audio_btn = gr.Button("Convert to audio")
        get_midi_btn = gr.Button("Convert to MIDI")
    audio_box = gr.Audio()
    midi_box = gr.File()
    piano_roll_box = gr.Image()
    example_box = gr.Examples(
        [
            [example_prefix],
            ["Twinkle Twinkle Little Star"], ["Twinkle Twinkle Little Star (Minor Key Version)"],
            ["The Entertainer - Scott Joplin (Piano Solo)"], ["Clair de Lune – Debussy"], ["Nocturne | Frederic Chopin"],
            ["Fugue I in C major, BWV 846"], ["Beethoven Symphony No. 7 (2nd movement) Piano solo"],
            ["Guitar"],
            # ["Composer: Chopin"], ["Composer: Bach"], ["Composer: Beethoven"], ["Composer: Debussy"],
        ],
        inputs=prefix_box,
        examples_per_page=9999,
    )

    def user_fn(user_message, history: list):
        """Append the submitted text as a user message and clear the input box."""
        return "", (history or []) + [{"role": "user", "content": user_message}]

    def bot_fn(history: list):
        """Stream the model's continuation of the latest user message into the chat.

        The user's text is the generation prefix; a single-line prefix is treated as a
        score title and gets a trailing newline. Generation runs on a background thread
        that pushes decoded tokens through a bounded queue; this generator drains the
        queue and yields the updated history after every token so Gradio renders it live.
        Raises gr.Error if generation fails.
        """
        # NOTE(review): `spaces` is imported but no @spaces.GPU decorator is visible here;
        # if this Space runs on ZeroGPU hardware, confirm whether generation needs one.
        prefix = history[-1]["content"]
        # prevent the model from continuing user's score title
        if prefix != '' and '\n' not in prefix:
            # prefix is a single line --> prefix is the score title
            prefix += '\n'
        history.append({"role": "assistant", "content": ""})
        history[-1]["content"] += "Generating with the given prefix...\n"
        token_queue = Queue(maxsize=10)
        errors = []  # exception raised by the generation thread, if any

        class MyStreamer:
            """Streamer handed to the pipeline: put() gets token ids, end() marks completion."""

            def put(self, tokens):
                for token in tokens.flatten():
                    text = pipe.tokenizer.decode(token.item())
                    if text == '<|begin_of_text|>':
                        continue
                    # Bounded wait: if nobody drains the queue any more (e.g. the event
                    # was cancelled), this raises queue.Full and aborts generation.
                    token_queue.put(text, block=True, timeout=5)

            def end(self):
                # The end-of-stream sentinel is sent from background_fn's `finally`,
                # so the consumer is released even when generation raises.
                pass

        def background_fn():
            try:
                result = pipe(
                    prefix,
                    streamer=MyStreamer(),
                    max_length=3000,
                    temperature=0.95,
                )
                print('Generated text:')
                print(result[0]['generated_text'])
                print()
            except Full:
                # The consumer stopped reading tokens (e.g. "Clear history" cancelled it).
                print("Queue is full. Exiting.")
            except Exception as e:
                # Top-level boundary of this thread: log, and hand the error to the UI.
                print('Generation failed:', e)
                errors.append(e)
            finally:
                # Always send the sentinel so the consumer loop below cannot hang; give up
                # after the timeout if nobody is reading any more.
                try:
                    token_queue.put(None, timeout=5)
                except Full:
                    pass

        Thread(target=background_fn, daemon=True).start()
        while True:
            text = token_queue.get()
            if text is None:
                break
            history[-1]["content"] += text
            yield history
        if errors:
            raise gr.Error(f'Error: {type(errors[0])}, {errors[0]}')

    prefix_submit_event = prefix_box.submit(user_fn, [prefix_box, chatbot_box], [prefix_box, chatbot_box], queue=False).then(
        bot_fn, chatbot_box, chatbot_box
    )
    submit_event = submit_btn.click(user_fn, [prefix_box, chatbot_box], [prefix_box, chatbot_box], queue=False).then(
        bot_fn, chatbot_box, chatbot_box
    )
    # Clearing cancels an in-flight generation from either trigger (button or Enter).
    clear_btn.click(lambda: None, inputs=[], outputs=chatbot_box, cancels=[submit_event, prefix_submit_event], queue=False)

    def get_audio_fn(history):
        """Convert the latest chat message to MIDI, then render it to an MP3 for playback.

        Uses the external `timidity` (MIDI -> WAV) and `lame` (WAV -> MP3) programs.
        Returns the MP3 path under ./temp. Raises gr.Error when there is nothing to
        convert or any stage fails to produce its output.
        """
        if not history:
            raise gr.Error('Nothing to convert yet: generate a score first.')
        i = random.randint(0, 1000_000_000)
        mid_path = f'./temp/{i}.mid'
        wav_path = f'./temp/{i}.wav'
        mp3_path = f'./temp/{i}.mp3'
        try:
            postprocess(history[-1]["content"], mid_path)
        except Exception as e:
            raise gr.Error(f'Error: {type(e)}, {e}')
        # Argument lists (no shell). Exit codes are not relied on; the existence of the
        # final output file is checked instead.
        try:
            # turn midi into audio with timidity
            subprocess.run(['timidity', mid_path, '-Ow', '-o', wav_path], check=False)
            # wav to mp3
            subprocess.run(['lame', '-b', '320', wav_path, mp3_path], check=False)
        except OSError as e:  # e.g. a converter binary is not installed
            raise gr.Error(f'Error: {type(e)}, {e}')
        finally:
            # The WAV is only an intermediate; drop it so ./temp does not fill with large files.
            try:
                os.remove(wav_path)
            except OSError:
                pass
        if not os.path.exists(mp3_path):
            raise gr.Error('Audio conversion failed (see server log).')
        return mp3_path

    get_audio_btn.click(get_audio_fn, chatbot_box, audio_box, queue=False)

    def get_midi_fn(history):
        """Convert the latest chat message to a MIDI file and render a piano-roll image.

        Returns (midi_path, svg_path), both under ./temp. Raises gr.Error when there is
        nothing to convert or the MIDI file could not be written.
        """
        from matplotlib.figure import Figure

        if not history:
            raise gr.Error('Nothing to convert yet: generate a score first.')
        i = random.randint(0, 1000_000_000)
        midi_path = f'./temp/{i}.mid'
        svg_path = f'./temp/{i}.svg'
        # turn the text into midi
        try:
            postprocess(history[-1]["content"], midi_path)
        except Exception as e:
            raise gr.Error(f'Error: {type(e)}, {e}')
        # postprocess() prints and swallows write errors, so confirm the file exists.
        if not os.path.exists(midi_path):
            raise gr.Error('Could not write the MIDI file (see server log).')
        # Also render the piano roll. A standalone Figure (instead of pyplot) is not
        # registered in pyplot's global figure list, so nothing leaks per call and
        # concurrent requests do not draw onto each other's "current" figure.
        fig = Figure(figsize=(12, 4))
        ax = fig.subplots()
        now = 0
        for line in history[-1]["content"].split('\n\n')[-1].split('\n'):
            note = _parse_note_line(line)
            if note is None:
                continue
            pitch, duration, wait, _velocity, instrument = note
            if instrument != 'drum':
                ax.plot([now, now + duration], [pitch, pitch], color='black', alpha=1)
                ax.scatter(now, pitch, s=6, color='black', alpha=0.3)
            # Drum hits are not drawn, but their wait still advances time (as it does in
            # postprocess()) so later notes stay aligned with the MIDI file.
            now += wait
        fig.savefig(svg_path)
        return midi_path, svg_path

    get_midi_btn.click(get_midi_fn, inputs=chatbot_box, outputs=[midi_box, piano_roll_box], queue=False)

demo.launch()