Commit 
							
							·
						
						8c8ea80
	
1
								Parent(s):
							
							56ab42f
								
Update app.py
Browse files
    	
        app.py
    CHANGED
    
    | @@ -84,7 +84,7 @@ def create_msg(name, data): | |
| 84 | 
             
                return {"name": name, "data": data}
         | 
| 85 |  | 
| 86 |  | 
| 87 | 
            -
            def run( | 
| 88 | 
             
                mid_seq = []
         | 
| 89 | 
             
                gen_events = int(gen_events)
         | 
| 90 | 
             
                max_len = gen_events
         | 
| @@ -92,55 +92,32 @@ def run(model_name, tab, instruments, drum_kit, mid, midi_events, gen_events, te | |
| 92 | 
             
                disable_patch_change = False
         | 
| 93 | 
             
                disable_channels = None
         | 
| 94 | 
             
                if tab == 0:
         | 
| 95 | 
            -
                     | 
| 96 | 
            -
             | 
| 97 | 
            -
                    patches = {}
         | 
| 98 | 
            -
                    for instr in instruments:
         | 
| 99 | 
            -
                        patches[i] = patch2number[instr]
         | 
| 100 | 
            -
                        i = (i + 1) if i != 8 else 10
         | 
| 101 | 
            -
                    if drum_kit != "None":
         | 
| 102 | 
            -
                        patches[9] = drum_kits2number[drum_kit]
         | 
| 103 | 
            -
                    for i, (c, p) in enumerate(patches.items()):
         | 
| 104 | 
            -
                        mid.append(tokenizer.event2tokens(["patch_change", 0, 0, i, c, p]))
         | 
| 105 | 
            -
                    mid_seq = mid
         | 
| 106 | 
            -
                    mid = np.asarray(mid, dtype=np.int64)
         | 
| 107 | 
            -
                    if len(instruments) > 0:
         | 
| 108 | 
            -
                        disable_patch_change = True
         | 
| 109 | 
            -
                        disable_channels = [i for i in range(16) if i not in patches]
         | 
| 110 | 
             
                elif mid is not None:
         | 
| 111 | 
            -
                     | 
| 112 | 
            -
             | 
| 113 | 
            -
                    mid = mid[:int(midi_events)]
         | 
| 114 | 
            -
                    max_len += len(mid)
         | 
| 115 | 
            -
                    for token_seq in mid:
         | 
| 116 | 
            -
                        mid_seq.append(token_seq.tolist())
         | 
| 117 | 
             
                init_msgs = [create_msg("visualizer_clear", None)]
         | 
| 118 | 
             
                for tokens in mid_seq:
         | 
| 119 | 
            -
                    init_msgs.append(create_msg("visualizer_append",  | 
| 120 | 
             
                yield mid_seq, None, None, init_msgs
         | 
| 121 | 
            -
             | 
| 122 | 
            -
                 | 
| 123 | 
            -
             | 
| 124 | 
            -
             | 
| 125 | 
            -
                for i, token_seq in enumerate(generator):
         | 
| 126 | 
            -
                    token_seq = token_seq.tolist()
         | 
| 127 | 
            -
                    mid_seq.append(token_seq)
         | 
| 128 | 
            -
                    event = tokenizer.tokens2event(token_seq)
         | 
| 129 | 
            -
                    yield mid_seq, None, None, [create_msg("visualizer_append", event), create_msg("progress", [i + 1, gen_events])]
         | 
| 130 | 
            -
                mid = tokenizer.detokenize(mid_seq)
         | 
| 131 | 
             
                with open(f"output.mid", 'wb') as f:
         | 
| 132 | 
            -
                    f.write(MIDI.score2midi( | 
| 133 | 
            -
                audio = synthesis(MIDI.score2opus( | 
| 134 | 
             
                yield mid_seq, "output.mid", (44100, audio), [create_msg("visualizer_end", None)]
         | 
| 135 |  | 
| 136 |  | 
| 137 | 
             
            def cancel_run(mid_seq):
         | 
| 138 | 
             
                if mid_seq is None:
         | 
| 139 | 
             
                    return None, None
         | 
| 140 | 
            -
             | 
| 141 | 
             
                with open(f"output.mid", 'wb') as f:
         | 
| 142 | 
            -
                    f.write(MIDI.score2midi( | 
| 143 | 
            -
                audio = synthesis(MIDI.score2opus( | 
| 144 | 
             
                return "output.mid", (44100, audio), [create_msg("visualizer_end", None)]
         | 
| 145 |  | 
| 146 |  | 
| @@ -174,11 +151,6 @@ class JSMsgReceiver(gr.HTML): | |
| 174 | 
             
                def get_block_name(self) -> str:
         | 
| 175 | 
             
                    return "html"
         | 
| 176 |  | 
| 177 | 
            -
            number2drum_kits = {-1: "None", 0: "Standard", 8: "Room", 16: "Power", 24: "Electric", 25: "TR-808", 32: "Jazz",
         | 
| 178 | 
            -
                                40: "Blush", 48: "Orchestra"}
         | 
| 179 | 
            -
            patch2number = {v: k for k, v in MIDI.Number2patch.items()}
         | 
| 180 | 
            -
            drum_kits2number = {v: k for k, v in number2drum_kits.items()}
         | 
| 181 | 
            -
             | 
| 182 | 
             
            if __name__ == "__main__":
         | 
| 183 | 
             
                parser = argparse.ArgumentParser()
         | 
| 184 | 
             
                parser.add_argument("--share", action="store_true", default=False, help="share gradio app")
         | 
| @@ -233,9 +205,7 @@ if __name__ == "__main__": | |
| 233 | 
             
                    output_midi_visualizer = gr.HTML(elem_id="midi_visualizer_container")
         | 
| 234 | 
             
                    output_audio = gr.Audio(label="output audio", format="mp3", elem_id="midi_audio")
         | 
| 235 | 
             
                    output_midi = gr.File(label="output midi", file_types=[".mid"])
         | 
| 236 | 
            -
                    run_event =  | 
| 237 | 
            -
                                                    input_midi_events, input_gen_events, input_temp, input_top_p, input_top_k,
         | 
| 238 | 
            -
                                                    input_allow_cc],
         | 
| 239 | 
             
                                              [output_midi_seq, output_midi, output_audio, js_msg])
         | 
| 240 | 
             
                    stop_btn.click(cancel_run, output_midi_seq, [output_midi, output_audio, js_msg], cancels=run_event, queue=False)
         | 
| 241 | 
             
                app.queue(1).launch(server_port=opt.port, share=opt.share, inbrowser=True)
         | 
|  | |
| 84 | 
             
                return {"name": name, "data": data}
         | 
| 85 |  | 
| 86 |  | 
| 87 | 
            +
            def run(search_prompt):
         | 
| 88 | 
             
                mid_seq = []
         | 
| 89 | 
             
                gen_events = int(gen_events)
         | 
| 90 | 
             
                max_len = gen_events
         | 
|  | |
| 92 | 
             
                disable_patch_change = False
         | 
| 93 | 
             
                disable_channels = None
         | 
| 94 | 
             
                if tab == 0:
         | 
| 95 | 
            +
                    mid_seq = []
         | 
| 96 | 
            +
                   
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 97 | 
             
                elif mid is not None:
         | 
| 98 | 
            +
                    mid_seq = MIDI.midi2score(mid)
         | 
| 99 | 
            +
             | 
|  | |
|  | |
|  | |
|  | |
| 100 | 
             
                init_msgs = [create_msg("visualizer_clear", None)]
         | 
| 101 | 
             
                for tokens in mid_seq:
         | 
| 102 | 
            +
                    init_msgs.append(create_msg("visualizer_append", tokens))
         | 
| 103 | 
             
                yield mid_seq, None, None, init_msgs
         | 
| 104 | 
            +
             | 
| 105 | 
            +
                for i in range(len(mid_seq)):
         | 
| 106 | 
            +
                    yield mid_seq, None, None, [create_msg("visualizer_append", mid_seq[i]), create_msg("progress", [i + 1, mid_seq[i+1]])]
         | 
| 107 | 
            +
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 108 | 
             
                with open(f"output.mid", 'wb') as f:
         | 
| 109 | 
            +
                    f.write(MIDI.score2midi(mid_seq))
         | 
| 110 | 
            +
                audio = synthesis(MIDI.score2opus(mid_seq), soundfont_path)
         | 
| 111 | 
             
                yield mid_seq, "output.mid", (44100, audio), [create_msg("visualizer_end", None)]
         | 
| 112 |  | 
| 113 |  | 
| 114 | 
             
            def cancel_run(mid_seq):
         | 
| 115 | 
             
                if mid_seq is None:
         | 
| 116 | 
             
                    return None, None
         | 
| 117 | 
            +
             | 
| 118 | 
             
                with open(f"output.mid", 'wb') as f:
         | 
| 119 | 
            +
                    f.write(MIDI.score2midi(mid_seq))
         | 
| 120 | 
            +
                audio = synthesis(MIDI.score2opus(mid_seq), soundfont_path)
         | 
| 121 | 
             
                return "output.mid", (44100, audio), [create_msg("visualizer_end", None)]
         | 
| 122 |  | 
| 123 |  | 
|  | |
| 151 | 
             
                def get_block_name(self) -> str:
         | 
| 152 | 
             
                    return "html"
         | 
| 153 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
| 154 | 
             
            if __name__ == "__main__":
         | 
| 155 | 
             
                parser = argparse.ArgumentParser()
         | 
| 156 | 
             
                parser.add_argument("--share", action="store_true", default=False, help="share gradio app")
         | 
|  | |
| 205 | 
             
                    output_midi_visualizer = gr.HTML(elem_id="midi_visualizer_container")
         | 
| 206 | 
             
                    output_audio = gr.Audio(label="output audio", format="mp3", elem_id="midi_audio")
         | 
| 207 | 
             
                    output_midi = gr.File(label="output midi", file_types=[".mid"])
         | 
| 208 | 
            +
                    run_event = search_btn.click(run, [search_prompt],
         | 
|  | |
|  | |
| 209 | 
             
                                              [output_midi_seq, output_midi, output_audio, js_msg])
         | 
| 210 | 
             
                    stop_btn.click(cancel_run, output_midi_seq, [output_midi, output_audio, js_msg], cancels=run_event, queue=False)
         | 
| 211 | 
             
                app.queue(1).launch(server_port=opt.port, share=opt.share, inbrowser=True)
         | 
