"""Gradio demo: stream note text from the dx2102/llama-midi model and convert it to MIDI, MP3 and a piano roll."""
import random
import os
import subprocess
import time
from queue import Queue
from threading import Thread

import symusic
import transformers
import gradio as gr

os.makedirs('./temp', exist_ok=True)

print('\n\n\n')
print('Loading model...')
# Small alternative model, handy for smoke-testing the UI without the real checkpoint:
# pipe = transformers.pipeline("text-generation", model="openai-community/gpt2")
pipe = transformers.pipeline(
    "text-generation",
    model="dx2102/llama-midi",
    torch_dtype="bfloat16",
    device_map="auto",
)
print('Done')

# Text format used by the model: a "pitch duration wait" header, a blank line,
# then one note per line. duration and wait are in milliseconds (postprocess
# divides them by 1000 for symusic's seconds-based notes); wait is the gap from
# this note's onset to the next note's onset.
default_prefix = '''pitch duration wait

71 1310 0
48 330 350
55 330 350
64 1310 690
74 660 690
69 1310 0
48 330 350
57 330 350
66 1310 690
67 330 350
69 330 350
71 1310 0
48 330 350
55 330 350
64 1310 690
74 660 690
69 1970 0
48 330 350
'''
# Number of note lines: total newlines minus the header line and the blank line.
default_prefix_len = default_prefix.count('\n') - 2


def _parse_notes(txt):
    """Return the valid (pitch, duration, wait) integer triples found in *txt*.

    Only the text after the last blank line is used; the header and any chat
    preamble come before it. Lines are skipped individually, so one bad line
    does not discard the notes after it. A line is skipped if it is not exactly
    three integers, if pitch is outside the MIDI range 0-127, or if duration or
    wait is negative. The model may emit malformed lines, and the last line may
    be cut off when generation hits max_length.
    """
    notes = []
    for line in txt.split('\n\n')[-1].split('\n'):
        try:
            pitch, duration, wait = (int(x) for x in line.split())
        except ValueError:
            continue
        if not 0 <= pitch <= 127 or duration < 0 or wait < 0:
            continue
        notes.append((pitch, duration, wait))
    return notes


def postprocess(txt, path):
    """Save the model's text representation *txt* as a MIDI file at *path*.

    Invalid note lines are ignored (see _parse_notes). Errors while building or
    writing the MIDI file propagate, so callers can report them.
    """
    notes = []
    now = 0  # onset of the current note, in milliseconds
    for pitch, duration, wait in _parse_notes(txt):
        notes.append(symusic.core.NoteSecond(
            time=now / 1000,
            duration=duration / 1000,
            pitch=pitch,
            velocity=80,
        ))
        now += wait

    track = symusic.core.TrackSecond()
    track.notes = symusic.core.NoteSecondList(notes)
    score = symusic.Score(ttype='Second')
    score.tracks.append(track)
    score.dump_midi(path)


with gr.Blocks() as demo:
    chatbot_box = gr.Chatbot(type="messages", render_markdown=False, sanitize_html=False)
    prefix_box = gr.Textbox(value=default_prefix, label="prefix")
    with gr.Row():
        submit_btn = gr.Button("Submit")
        clear_btn = gr.Button("Clear")
    with gr.Row():
        get_audio_btn = gr.Button("Convert to audio")
        get_midi_btn = gr.Button("Convert to MIDI")
    audio_box = gr.Audio()
    midi_box = gr.File()
    piano_roll_box = gr.Image()

    def user_fn(user_message, history: list):
        """Append the submitted prefix as a user message and clear the textbox."""
        # The chat may have been cleared to None by the Clear button.
        return "", (history or []) + [{"role": "user", "content": user_message}]

    def bot_fn(history: list):
        """Stream the model's continuation of the last user message into the chat.

        If the message does not start with the "pitch duration wait" header,
        generation starts from default_prefix instead.
        """
        prefix = history[-1]["content"]
        history.append({"role": "assistant", "content": ""})
        if prefix.startswith("pitch duration wait\n\n"):
            history[-1]["content"] += "Generating with the given prefix...\n"
        else:
            history[-1]["content"] += f"Generating from scratch with a default prefix of {default_prefix_len} notes...\n"
            prefix = default_prefix

        queue = Queue()  # decoded text pieces; None marks the end of the stream
        errors = []  # exception raised by the generation thread, if any

        class MyStreamer:
            # Duck-typed streamer: the pipeline calls put() with tensors of token
            # ids and end() when generation finishes. The prompt tokens
            # presumably arrive first too, hence the BOS skip.
            def put(self, tokens):
                for token in tokens.flatten():
                    text = pipe.tokenizer.decode(token.item())
                    if text == '<|begin_of_text|>':
                        continue
                    queue.put(text)

            def end(self):
                queue.put(None)

        def background_fn():
            try:
                result = pipe(
                    prefix,
                    streamer=MyStreamer(),
                    max_length=1000,
                    temperature=0.95,
                )
            except Exception as e:
                print('Generation failed:', repr(e))
                errors.append(e)
                # end() may never be called after a failure; unblock the consumer.
                queue.put(None)
                return
            print('Generated text:')
            print(result[0]['generated_text'])
            print()

        Thread(target=background_fn, daemon=True).start()
        while True:
            text = queue.get()
            if text is None:
                break
            history[-1]["content"] += text
            yield history
        if errors:
            raise gr.Error(f'Generation failed: {errors[0]}')

    prefix_box.submit(user_fn, [prefix_box, chatbot_box], [prefix_box, chatbot_box], queue=False).then(
        bot_fn, chatbot_box, chatbot_box
    )
    submit_btn.click(user_fn, [prefix_box, chatbot_box], [prefix_box, chatbot_box], queue=False).then(
        bot_fn, chatbot_box, chatbot_box
    )
    clear_btn.click(lambda: None, None, chatbot_box, queue=False)

    def get_audio_fn(history):
        """Convert the latest chat message to MIDI, then to MP3 via timidity and lame.

        Returns the MP3 path. Raises gr.Error if there is nothing to convert or
        any conversion step fails.
        """
        if not history:
            raise gr.Error('Nothing to convert yet: generate some music first.')
        i = random.randint(0, 1000_000_000)
        mid_path = f'./temp/{i}.mid'
        wav_path = f'./temp/{i}.wav'
        mp3_path = f'./temp/{i}.mp3'
        try:
            postprocess(history[-1]["content"], mid_path)
        except Exception as e:
            raise gr.Error(f'Error: {type(e)}, {e}')
        try:
            # turn midi into audio with timidity
            subprocess.run(['timidity', mid_path, '-Ow', '-o', wav_path], check=True)
            # wav to mp3
            subprocess.run(['lame', '-b', '320', wav_path, mp3_path], check=True)
        except (OSError, subprocess.CalledProcessError) as e:
            # OSError covers a missing executable; CalledProcessError a non-zero exit.
            raise gr.Error(f'Audio conversion failed: {e}')
        return mp3_path

    get_audio_btn.click(get_audio_fn, chatbot_box, audio_box, queue=False)

    def get_midi_fn(history):
        """Write the latest chat message to MIDI and render a piano-roll SVG.

        Returns (midi_path, svg_path). Raises gr.Error if there is nothing to
        convert or the MIDI file cannot be written.
        """
        # Imported lazily so matplotlib is only needed when this button is used.
        from matplotlib.figure import Figure

        if not history:
            raise gr.Error('Nothing to convert yet: generate some music first.')
        i = random.randint(0, 1000_000_000)
        content = history[-1]["content"]
        # turn the text into midi
        try:
            postprocess(content, f'./temp/{i}.mid')
        except Exception as e:
            raise gr.Error(f'Error: {type(e)}, {e}')

        # Piano roll: draw one horizontal segment per note. The Figure API avoids
        # pyplot's global figure registry, which keeps every figure alive until
        # closed and is not safe to use from concurrent request threads.
        fig = Figure(figsize=(12, 4))
        ax = fig.subplots()
        now = 0
        for pitch, duration, wait in _parse_notes(content):
            ax.plot([now, now + duration], [pitch, pitch], color='black')
            now += wait
        fig.savefig(f'./temp/{i}.svg')
        return f'./temp/{i}.mid', f'./temp/{i}.svg'

    get_midi_btn.click(get_midi_fn, inputs=chatbot_box, outputs=[midi_box, piano_roll_box], queue=False)

demo.launch()