# Hugging Face Spaces status: Running on Zero
| import random | |
| import os | |
| import time | |
| from queue import Queue | |
| from threading import Thread | |
| import symusic | |
| import transformers | |
| import gradio as gr | |
# Load the MIDI language model once at import time; every request reuses `pipe`.
print('\n\n\n')
print('Loading model...')
# For a quick smoke test without the MIDI model, "openai-community/gpt2" can be
# used as the model id instead.
pipe = transformers.pipeline(
    task="text-generation",
    model="dx2102/llama-midi",
    device_map="auto",  # let transformers/accelerate choose device placement
    torch_dtype="bfloat16",
)
print('Done')
# Default prompt used when the user's text does not start with the
# "pitch duration wait" header. Format: the header, a blank line, then one note
# per line as three integers. The blank line matters: bot_fn checks for
# "pitch duration wait\n\n" and postprocess reads only the text after the last
# blank line.
default_prefix = '''pitch duration wait

71 1310 0
48 330 350
55 330 350
64 1310 690
74 660 690
69 1310 0
48 330 350
57 330 350
66 1310 690
67 330 350
69 330 350
71 1310 0
48 330 350
55 330 350
64 1310 690
74 660 690
69 1970 0
48 330 350
'''
# Number of notes in the default prefix (non-empty lines after the header),
# counted from the data instead of a hard-coded newline offset.
default_prefix_len = sum(
    1 for line in default_prefix.split('\n\n', 1)[-1].splitlines() if line.strip()
)
def _parse_notes(txt):
    """Parse model text into a list of ``(start_ms, duration_ms, pitch)`` tuples.

    Only the text after the last blank line is read, which drops the
    ``pitch duration wait`` header and any chat preamble before it. Each line
    must be exactly three integers ``pitch duration wait``. ``wait`` is the gap
    in milliseconds until the next note starts.

    Invalid lines are skipped one by one instead of aborting the whole parse, so
    a single garbled line does not drop every note after it. Invalid means blank,
    truncated, or non-numeric; a pitch outside the MIDI range 0-127; or a
    negative duration or wait. Skipped lines do not advance time.
    """
    notes = []
    now = 0
    for line in txt.split('\n\n')[-1].split('\n'):
        if not line.strip():
            continue
        try:
            pitch, duration, wait = [int(x) for x in line.split()]
        except ValueError:
            # Not three integers: typically the model's truncated last line.
            print('Ignored invalid line:', repr(line))
            continue
        if not (0 <= pitch <= 127 and duration >= 0 and wait >= 0):
            print('Ignored out-of-range line:', repr(line))
            continue
        notes.append((now, duration, pitch))
        now += wait
    return notes


def postprocess(txt, path):
    """Save the text representation of a piece to a MIDI file at ``path``.

    Args:
        txt: Model output (possibly with chat preamble); see ``_parse_notes``
            for the accepted line format.
        path: Destination ``.mid`` path. Its parent directory is created if
            missing.

    Raises:
        Whatever symusic or the filesystem raise while building or writing the
        file. This is deliberately not swallowed any more, so callers can report
        it (the callers in this app wrap it in ``gr.Error``) instead of handing
        out a file that was never written.
    """
    # symusic works in seconds; the text format is in milliseconds.
    notes = [
        symusic.core.NoteSecond(
            time=start / 1000,
            duration=duration / 1000,
            pitch=pitch,
            velocity=80,
        )
        for start, duration, pitch in _parse_notes(txt)
    ]
    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)
    track = symusic.core.TrackSecond()
    track.notes = symusic.core.NoteSecondList(notes)
    score = symusic.Score(ttype='Second')
    score.tracks.append(track)
    score.dump_midi(path)
def _temp_stem():
    """Return a new ``./temp/<random int>`` path stem, creating ``./temp`` first.

    Without the directory, the MIDI export and the timidity/lame calls fail.
    """
    os.makedirs('./temp', exist_ok=True)
    return f'./temp/{random.randint(0, 1000_000_000)}'


def _last_message(history):
    """Return the newest chat message's text; raise gr.Error if the chat is empty."""
    if not history:
        raise gr.Error('Nothing to convert yet: generate some music first.')
    return history[-1]["content"]


def _render_piano_roll(txt, path):
    """Plot every note line of ``txt`` as a horizontal segment and save it to ``path``.

    Only the text after the last blank line is read, one ``pitch duration wait``
    triple per line (x axis: start time in ms, y axis: pitch). Some lines are
    skipped and do not advance time: lines that are not three integers, lines
    whose pitch is outside the MIDI range 0-127, and lines with a negative
    duration or wait.
    """
    from matplotlib.figure import Figure

    # A standalone Figure rather than pyplot: pyplot keeps every figure in a
    # global registry until explicitly closed, and that shared state is not
    # safe across concurrent Gradio requests. This Figure is freed on return.
    fig = Figure(figsize=(12, 4))
    ax = fig.subplots()
    now = 0
    for line in txt.split('\n\n')[-1].split('\n'):
        try:
            pitch, duration, wait = [int(x) for x in line.split()]
        except ValueError:
            continue
        if not (0 <= pitch <= 127 and duration >= 0 and wait >= 0):
            continue
        ax.plot([now, now + duration], [pitch, pitch], color='black')
        now += wait
    fig.savefig(path)


with gr.Blocks() as demo:
    chatbot_box = gr.Chatbot(type="messages", render_markdown=False, sanitize_html=False)
    prefix_box = gr.Textbox(value=default_prefix, label="prefix")
    with gr.Row():
        submit_btn = gr.Button("Submit")
        clear_btn = gr.Button("Clear")
    with gr.Row():
        get_audio_btn = gr.Button("Convert to audio")
        get_midi_btn = gr.Button("Convert to MIDI")
    audio_box = gr.Audio()
    midi_box = gr.File()
    piano_roll_box = gr.Image()

    def user_fn(user_message, history: list):
        """Append the submitted text as a user message and clear the textbox."""
        return "", history + [{"role": "user", "content": user_message}]

    def bot_fn(history: list):
        """Stream the model's continuation of the newest user message into the chat.

        If the user text starts with the ``pitch duration wait`` header followed
        by a blank line, it is used as the prompt. Otherwise ``default_prefix``
        is used. Generation runs in a background thread. This generator yields
        the updated history after each decoded token and raises ``gr.Error`` if
        generation fails.
        """
        prefix = history[-1]["content"]
        history.append({"role": "assistant", "content": ""})
        if prefix.startswith("pitch duration wait\n\n"):
            history[-1]["content"] += "Generating with the given prefix...\n"
        else:
            history[-1]["content"] += f"Generating from scratch with a default prefix of {default_prefix_len} notes...\n"
            prefix = default_prefix
        queue = Queue()
        errors = []

        class MyStreamer:
            # Minimal streamer object passed to the pipeline: put() receives
            # batches of token ids to decode and forward; end() signals that
            # generation finished. The <|begin_of_text|> token is presumably
            # part of the echoed prompt and is dropped from the chat.
            def put(self, tokens):
                for token in tokens.flatten():
                    text = pipe.tokenizer.decode(token.item())
                    if text == '<|begin_of_text|>':
                        continue
                    queue.put(text)

            def end(self):
                queue.put(None)

        def background_fn():
            try:
                result = pipe(
                    prefix,
                    streamer=MyStreamer(),
                    max_length=1000,
                    temperature=0.95,
                )
                print('Generated text:')
                print(result[0]['generated_text'])
                print()
            except Exception as e:
                print('Generation failed:', repr(e))
                errors.append(e)
            finally:
                # Always wake the consumer. If pipe() raises, end() is never
                # called, and without this queue.get() below would block forever.
                queue.put(None)

        Thread(target=background_fn).start()
        while True:
            text = queue.get()
            if text is None:
                break
            history[-1]["content"] += text
            yield history
        if errors:
            raise gr.Error(f'Generation failed: {type(errors[0]).__name__}: {errors[0]}')

    # Pressing Enter in the textbox and clicking Submit do the same thing.
    for trigger in (prefix_box.submit, submit_btn.click):
        trigger(user_fn, [prefix_box, chatbot_box], [prefix_box, chatbot_box], queue=False).then(
            bot_fn, chatbot_box, chatbot_box
        )
    clear_btn.click(lambda: None, None, chatbot_box, queue=False)

    def get_audio_fn(history):
        """Render the newest message to MIDI, then to mp3; return the mp3 path."""
        import subprocess

        txt = _last_message(history)
        stem = _temp_stem()
        try:
            postprocess(txt, f'{stem}.mid')
        except Exception as e:
            raise gr.Error(f'Error: {type(e)}, {e}') from e
        try:
            # turn midi into audio with timidity, then wav to mp3 with lame.
            # Argument lists (no shell) and check=True: a missing tool or a
            # failed conversion surfaces as an error instead of a dead path.
            subprocess.run(['timidity', f'{stem}.mid', '-Ow', '-o', f'{stem}.wav'], check=True)
            subprocess.run(['lame', '-b', '320', f'{stem}.wav', f'{stem}.mp3'], check=True)
        except (OSError, subprocess.CalledProcessError) as e:
            raise gr.Error(f'Audio conversion failed: {e}') from e
        return f'{stem}.mp3'

    get_audio_btn.click(get_audio_fn, chatbot_box, audio_box, queue=False)

    def get_midi_fn(history):
        """Write the newest message as a MIDI file plus an SVG piano roll; return both paths."""
        txt = _last_message(history)
        stem = _temp_stem()
        # turn the text into midi
        try:
            postprocess(txt, f'{stem}.mid')
        except Exception as e:
            raise gr.Error(f'Error: {type(e)}, {e}') from e
        # also render the piano roll
        _render_piano_roll(txt, f'{stem}.svg')
        return f'{stem}.mid', f'{stem}.svg'

    get_midi_btn.click(get_midi_fn, inputs=chatbot_box, outputs=[midi_box, piano_roll_box], queue=False)

demo.launch()