Spaces:
Running
on
Zero
Running
on
Zero
File size: 5,508 Bytes
1db4f6f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 |
import random
import os
import time
from queue import Queue
from threading import Thread
import symusic
import transformers
import gradio as gr
# Start-up banner: three newlines first so the message is visually separated
# from whatever was printed before it.
print('\n\n\n')
print('Loading model...')

# Alternative pipeline kept from the original source for reference:
#   pipe = transformers.pipeline("text-generation", model="openai-community/gpt2")

# Module-level text-generation pipeline; `bot_fn` calls it and also uses its
# `tokenizer` attribute to decode streamed tokens.
pipe = transformers.pipeline(
    task="text-generation",
    model="dx2102/llama-midi",
    device_map="auto",
    torch_dtype="bfloat16",
)

print('Done')
# Prompt used when the user does not supply one of their own.
#
# Layout: a header line, ONE BLANK LINE, then one note per line written as
# "pitch duration wait". The blank line is significant:
#   * `bot_fn` recognises a user-supplied prompt by the prefix
#     "pitch duration wait\n\n";
#   * `postprocess` and the piano-roll code keep only the text after the last
#     blank line, so without the separator the header itself would be parsed
#     as a note line;
#   * `default_prefix_len` below subtracts two non-note lines.
default_prefix = '''pitch duration wait

71 1310 0
48 330 350
55 330 350
64 1310 690
74 660 690
69 1310 0
48 330 350
57 330 350
66 1310 690
67 330 350
69 330 350
71 1310 0
48 330 350
55 330 350
64 1310 690
74 660 690
69 1970 0
48 330 350
'''

# Number of notes in the default prompt. Every line ends with '\n'; the header
# and the blank separator are the two lines that are not notes.
default_prefix_len = default_prefix.count('\n') - 2
def postprocess(txt, path):
    """Parse the model's text output into notes and save them as a MIDI file.

    Only the text after the last blank line of ``txt`` is parsed. Each line is
    expected to contain exactly three integers, ``pitch duration wait``.
    ``duration`` and the running start time (the sum of the preceding ``wait``
    values) are divided by 1000 before being handed to symusic as seconds.

    The model may emit invalid lines. Such a line is reported on stdout and
    skipped on its own, so the notes that follow it are still kept; this is the
    same per-line policy the piano-roll rendering in ``get_midi_fn`` uses, so
    the MIDI file and the picture show the same notes.

    Args:
        txt: Text containing the note lines (any leading text is separated
            from the notes by a blank line).
        path: Destination path of the MIDI file.

    Returns:
        None. Failures while building or writing the score are printed rather
        than raised, so a caller that needs the file must check that ``path``
        exists afterwards.
    """
    body = txt.split('\n\n')[-1]
    notes = []
    now = 0  # start of the next note, in the same unit as `wait`
    for line in body.split('\n'):
        if not line.strip():
            # Blank (typically trailing) lines carry no note; skip quietly.
            continue
        try:
            pitch, duration, wait = [int(x) for x in line.split()]
            # Eg. Note(time=7.47, duration=5.25, pitch=43, velocity=64, ttype='Quarter')
            note = symusic.core.NoteSecond(
                time=now / 1000,
                duration=duration / 1000,
                pitch=pitch,
                velocity=80,
            )
        except Exception as e:
            # Broad on purpose: besides ValueError from parsing, the note
            # constructor may reject values; its exception type is not visible
            # here. Either way only this one line is dropped.
            print('Ignored error:', e)
            continue
        notes.append(note)
        now += wait
    try:
        track = symusic.core.TrackSecond()
        track.notes = symusic.core.NoteSecondList(notes)
        score = symusic.Score(ttype='Second')
        score.tracks.append(track)
        score.dump_midi(path)
    except Exception as e:
        # Deliberately best-effort: report and carry on.
        print('Ignored error:', e)
# Gradio UI: a chat-style view that streams generated note text, plus buttons
# that convert the latest message to audio, MIDI and a piano-roll picture.
# NOTE(review): the original indentation was lost; the nesting below (outputs
# placed after the second button row, handlers defined inside the Blocks
# context) is a reconstruction -- confirm against the deployed layout.
with gr.Blocks() as demo:
    chatbot_box = gr.Chatbot(type="messages", render_markdown=False, sanitize_html=False)
    prefix_box = gr.Textbox(value=default_prefix, label="prefix")
    with gr.Row():
        submit_btn = gr.Button("Submit")
        clear_btn = gr.Button("Clear")
    with gr.Row():
        get_audio_btn = gr.Button("Convert to audio")
        get_midi_btn = gr.Button("Convert to MIDI")
    audio_box = gr.Audio()
    midi_box = gr.File()
    piano_roll_box = gr.Image()

    def user_fn(user_message, history: list):
        """Empty the textbox and append the user's message to the chat history.

        Returns the new textbox value ("") and the extended history. A missing
        history (e.g. right after "Clear" stored ``None``) is treated as empty.
        """
        return "", (history or []) + [{"role": "user", "content": user_message}]

    def bot_fn(history: list):
        """Generate a continuation and stream it into a new assistant message.

        The last message in ``history`` is used as the prompt when it starts
        with "pitch duration wait" followed by a blank line; otherwise
        ``default_prefix`` is used. Generation runs in a background thread and
        this generator yields the updated history after every decoded token.
        """
        prefix = history[-1]["content"]
        history.append({"role": "assistant", "content": ""})
        if prefix.startswith("pitch duration wait\n\n"):
            history[-1]["content"] += "Generating with the given prefix...\n"
        else:
            history[-1]["content"] += f"Generating from scratch with a default prefix of {default_prefix_len} notes...\n"
            prefix = default_prefix

        # Hands decoded text from the generation thread to this generator;
        # None marks the end of the stream.
        queue = Queue()

        class MyStreamer:
            """Streamer object passed to the pipeline; forwards text to `queue`."""

            def put(self, tokens):
                for token in tokens.flatten():
                    text = pipe.tokenizer.decode(token.item())
                    if text == '<|begin_of_text|>':
                        continue
                    queue.put(text)

            def end(self):
                queue.put(None)

        def background_fn():
            try:
                result = pipe(
                    prefix,
                    streamer=MyStreamer(),
                    max_length=1000,
                    temperature=0.95,
                )
                print('Generated text:')
                print(result[0]['generated_text'])
                print()
            finally:
                # If generation raises before the streamer's end() is called,
                # the loop below would block on queue.get() forever. Always
                # send the sentinel; an extra one after a normal end() is
                # never read because the queue is local to this call.
                queue.put(None)

        Thread(target=background_fn).start()
        while True:
            text = queue.get()
            if text is None:
                break
            history[-1]["content"] += text
            yield history

    prefix_box.submit(user_fn, [prefix_box, chatbot_box], [prefix_box, chatbot_box], queue=False).then(
        bot_fn, chatbot_box, chatbot_box
    )
    submit_btn.click(user_fn, [prefix_box, chatbot_box], [prefix_box, chatbot_box], queue=False).then(
        bot_fn, chatbot_box, chatbot_box
    )
    clear_btn.click(lambda: None, None, chatbot_box, queue=False)

    def _write_midi(history, path):
        """Write the last chat message to ``path`` as MIDI or raise gr.Error.

        `postprocess` prints and swallows its own failures, so besides catching
        exceptions (e.g. an empty history) the output file is checked as well.
        """
        os.makedirs('./temp', exist_ok=True)
        try:
            postprocess(history[-1]["content"], path)
        except Exception as e:
            raise gr.Error(f'Error: {type(e)}, {e}')
        if not os.path.exists(path):
            raise gr.Error('Error: the MIDI file could not be written')

    def _run_tool(args, expected_output):
        """Run an external command; raise gr.Error if its output file is missing."""
        import subprocess
        try:
            # Argument list instead of a shell string: no shell parsing involved.
            completed = subprocess.run(args)
        except OSError as e:
            raise gr.Error(f'Error: could not run {args[0]}: {e}') from e
        if not os.path.exists(expected_output):
            raise gr.Error(
                f'Error: {args[0]} exited with code {completed.returncode} and produced no output'
            )

    def get_audio_fn(history):
        """Render the last chat message to an MP3 file and return its path."""
        i = random.randint(0, 1000_000_000)
        mid_path = f'./temp/{i}.mid'
        wav_path = f'./temp/{i}.wav'
        mp3_path = f'./temp/{i}.mp3'
        _write_midi(history, mid_path)
        # turn midi into audio with timidity
        _run_tool(['timidity', mid_path, '-Ow', '-o', wav_path], wav_path)
        # wav to mp3
        _run_tool(['lame', '-b', '320', wav_path, mp3_path], mp3_path)
        return mp3_path

    get_audio_btn.click(get_audio_fn, chatbot_box, audio_box, queue=False)

    def get_midi_fn(history):
        """Save the last chat message as MIDI and draw its piano roll.

        Returns the paths of the MIDI file and of the SVG piano roll.
        """
        i = random.randint(0, 1000_000_000)
        mid_path = f'./temp/{i}.mid'
        svg_path = f'./temp/{i}.svg'
        # turn the text into midi
        _write_midi(history, mid_path)
        # also render the piano roll
        import matplotlib.pyplot as plt
        fig = plt.figure(figsize=(12, 4))
        try:
            now = 0
            for line in history[-1]["content"].split('\n\n')[-1].split('\n'):
                try:
                    pitch, duration, wait = [int(x) for x in line.split()]
                except ValueError:
                    # Not a "pitch duration wait" line; skip it.
                    continue
                plt.plot([now, now + duration], [pitch, pitch], color='black')
                now += wait
            plt.savefig(svg_path)
        finally:
            # pyplot keeps every figure alive until it is closed; without this
            # each request would leave one more figure in memory.
            plt.close(fig)
        return mid_path, svg_path

    get_midi_btn.click(get_midi_fn, inputs=chatbot_box, outputs=[midi_box, piano_roll_box], queue=False)

demo.launch()
|