Spaces:
Sleeping
Sleeping
| import random | |
| import argparse | |
| import os | |
| import glob | |
| import rtmidi | |
| import gradio as gr | |
| import numpy as np | |
| import MIDI | |
| import base64 | |
| import io | |
| import soundfile as sf # Placeholder for audio rendering | |
| from huggingface_hub import hf_hub_download | |
| from midi_synthesizer import MidiSynthesizer | |
# Largest value representable as a signed 32-bit integer.
MAX_SEED = np.iinfo(np.int32).max

# True when the SYSTEM environment variable is exactly "spaces".
in_space = os.environ.get("SYSTEM") == "spaces"

# Static song description: title, chord progression and the first lyric lines.
SONG_DATA = {
    "title": "Do You Believe in Love",
    "progression": ["G", "D", "Em", "C"],
    "lyrics": [
        "I was walking down a one-way street",
        "Just a-looking for someone to meet",
    ],
}
class MIDIDeviceManager:
    """Hold one rtmidi output handle and one rtmidi input handle.

    The ``get_*_devices`` helpers return the port names reported by rtmidi,
    or a single placeholder entry when no port is available, so the result
    can always be used as a non-empty choice list.
    """

    def __init__(self):
        self.midiout = rtmidi.MidiOut()
        self.midiin = rtmidi.MidiIn()

    def get_output_devices(self):
        """Return the output port names, or a one-item placeholder list."""
        return self.midiout.get_ports() or ["No MIDI output devices"]

    def get_input_devices(self):
        """Return the input port names, or a one-item placeholder list."""
        return self.midiin.get_ports() or ["No MIDI input devices"]

    @staticmethod
    def _describe_ports(ports, label, empty_message):
        """Format *ports* as one ``"<label> Port <i>: <name>"`` line each."""
        if not ports:
            return empty_message
        return "\n".join(f"{label} Port {i}: {name}" for i, name in enumerate(ports))

    def get_device_info(self):
        """Return a multi-line, human-readable summary of all MIDI ports.

        The raw port lists are queried directly: going through
        ``get_output_devices`` / ``get_input_devices`` would list their
        placeholder entry as if it were port 0 and make the
        "No ... detected" text unreachable.
        """
        out_info = self._describe_ports(
            self.midiout.get_ports(), "Out", "No MIDI output devices detected"
        )
        in_info = self._describe_ports(
            self.midiin.get_ports(), "In", "No MIDI input devices detected"
        )
        return f"Output Devices:\n{out_info}\n\nInput Devices:\n{in_info}"

    def close(self):
        """Close any open port and drop both handles.

        Safe to call more than once: handles that were already released are
        skipped instead of raising ``AttributeError``.
        """
        for attr in ("midiout", "midiin"):
            handle = getattr(self, attr, None)
            if handle is None:
                continue
            if handle.is_port_open():
                handle.close_port()
            delattr(self, attr)
class MIDIManager:
    """Keep loaded MIDI objects, derive variations of them and drive playback.

    Attributes set by ``__init__``:
        soundfont_path: local path returned by ``hf_hub_download``.
        synthesizer: ``MidiSynthesizer`` built from that soundfont.
        loaded_midi: ``midi_id -> (file_path, midi_obj)``.
        modified_files: ``(midi_base64, audio_base64)`` tuples, one per
            generated variation or applied effect.
        is_playing: flag polled by ``play_with_loop``.
        instruments: program numbers used for the instrument tracks.
        drum_beat: ``(note, velocity, time)`` tuples for the drum track.
        starter_midi: the randomly generated starter file.
        example_files: ``midi_id -> (file_path, midi_obj)`` for bundled files.
    """

    def __init__(self):
        self.soundfont_path = hf_hub_download(repo_id="skytnt/midi-model", filename="soundfont.sf2")
        self.synthesizer = MidiSynthesizer(self.soundfont_path)
        self.loaded_midi = {}  # midi_id: (file_path, midi_obj)
        self.modified_files = []  # Stores (midi_base64, audio_base64) tuples
        self.is_playing = False
        self.instruments = self.random_instrument_set()
        self.drum_beat = self.create_drum_beat()
        self.starter_midi = self.create_starter_midi()
        self.example_files = self.load_example_midis()
        self.loaded_midi["starter"] = ("Starter MIDI", self.starter_midi)
        self.preload_default_midi()

    def random_instrument_set(self):
        """Return the four pool programs in random order."""
        instrument_pool = [0, 24, 32, 48]  # Piano, Guitar, Bass, Strings
        return random.sample(instrument_pool, 4)

    def create_drum_beat(self):
        """Return the fixed drum pattern as ``(note, velocity, time)`` tuples."""
        return [(36, 100, 0), (42, 80, 50), (38, 90, 100), (42, 80, 150)]  # Kick, hi-hat, snare, hi-hat

    def _add_drum_track(self, midi):
        """Append the drum track (track 4, channel 9) built from ``self.drum_beat``."""
        midi.addTrack()
        for note, vel, time in self.drum_beat:
            midi.addNote(4, 9, note, time, 100, vel)

    def create_starter_midi(self):
        """Build a 5-track file: four random instrument tracks plus drums."""
        midi = MIDI.MIDIFile(5)  # 4 instruments + 1 drum track
        for i, inst in enumerate(self.instruments):
            midi.addTrack()
            midi.addProgramChange(i, 0, 0, inst)
            for t in range(0, 400, 100):
                note = random.randint(60, 84)  # C4 to C6
                midi.addNote(i, 0, note, t, 100, 100)
        self._add_drum_track(midi)
        return midi

    def preload_default_midi(self):
        """If ``default.mid`` exists, load it, generate a variation and loop it.

        ``play_with_loop`` only returns once ``stop_playback`` clears
        ``is_playing``. Calling it inline from ``__init__`` would therefore
        block construction forever, so the loop runs on a daemon thread.
        """
        default_path = "default.mid"
        if not os.path.exists(default_path):
            return
        import threading  # local import: the file's import block is left untouched

        midi_id = "default"
        midi = MIDI.load(default_path)
        self.loaded_midi[midi_id] = (default_path, midi)
        midi_data, _audio_data = self.generate_variation(midi_id)
        threading.Thread(
            target=self.play_with_loop, args=(midi_data,), daemon=True
        ).start()

    def load_example_midis(self):
        """Re-arrange every ``*.mid`` / ``*.midi`` file in the working directory.

        Each file's notes are copied onto all four instrument tracks and the
        drum track is appended. ``default.mid`` is skipped. The examples are
        also merged into ``self.loaded_midi``.

        Returns:
            dict mapping ``example_<n>`` to ``(file_path, midi_obj)``.
        """
        examples = {}
        for file_path in glob.glob("*.mid") + glob.glob("*.midi"):
            if file_path == "default.mid":
                continue
            midi_id = f"example_{len(examples)}"
            midi = MIDI.load(file_path)
            new_midi = MIDI.MIDIFile(5)
            notes, _ = self.extract_notes_and_instruments(midi)
            for i, inst in enumerate(self.instruments):
                new_midi.addTrack()
                new_midi.addProgramChange(i, 0, 0, inst)
                for note, vel, time in notes:
                    new_midi.addNote(i, 0, note, time, 100, vel)
            self._add_drum_track(new_midi)
            examples[midi_id] = (file_path, new_midi)
        self.loaded_midi.update(examples)
        return examples

    def load_midi(self, file_path):
        """Load *file_path*, register it in ``loaded_midi`` and return its id."""
        midi = MIDI.load(file_path)
        # Subtract the examples and the starter entry so user files count from 0.
        midi_id = f"midi_{len(self.loaded_midi) - len(self.example_files) - 1}"
        self.loaded_midi[midi_id] = (file_path, midi)
        return midi_id

    def extract_notes_and_instruments(self, midi):
        """Collect sounding notes and program numbers from every track.

        Returns:
            ``(notes, instruments)`` where *notes* is a list of
            ``(note, velocity, time)`` for ``note_on`` events with a positive
            velocity, and *instruments* is a list of the distinct ``program``
            values found on any event.
        """
        notes = []
        instruments = set()
        for track in midi.tracks:
            for event in track.events:
                if event.type == 'note_on' and event.velocity > 0:
                    notes.append((event.note, event.velocity, event.time))
                if hasattr(event, 'program'):
                    instruments.add(event.program)
        return notes, list(instruments)

    def _publish(self, midi):
        """Serialize *midi*, play it once and record the result.

        Returns:
            ``(midi_base64, audio_base64)``; the audio part is always ``None``
            because audio rendering is not implemented yet.
        """
        buffer = io.BytesIO()
        midi.writeFile(buffer)
        midi_data = base64.b64encode(buffer.getvalue()).decode('utf-8')
        # No temp file is written here: nothing ever read the old fixed-name
        # 'temp.mid', it leaked whenever play_midi raised, and concurrent
        # requests raced on it.
        self.synthesizer.play_midi(midi)
        audio_data = None  # Placeholder until audio rendering exists
        self.modified_files.append((midi_data, audio_data))
        return midi_data, audio_data

    def generate_variation(self, midi_id, length_factor=2, variation=0.3):
        """Create a randomized variation of a loaded MIDI file.

        Args:
            midi_id: key into ``self.loaded_midi``.
            length_factor: number of passes over the source notes
                (truncated with ``int``).
            variation: per-note probability of nudging pitch by up to 2 and
                velocity by up to 10, both clamped to 0..127.

        Returns:
            ``(midi_base64, audio_base64)``, or ``(None, None)`` when
            *midi_id* is unknown.
        """
        if midi_id not in self.loaded_midi:
            return None, None
        _, midi = self.loaded_midi[midi_id]
        notes, instruments = self.extract_notes_and_instruments(midi)
        new_notes = []
        # NOTE(review): every pass reuses the original time values, so the
        # repeats are not shifted later in time; confirm this is intended.
        for _pass in range(int(length_factor)):
            for note, vel, time in notes:
                if random.random() < variation:
                    new_note = min(127, max(0, note + random.randint(-2, 2)))
                    new_vel = min(127, max(0, vel + random.randint(-10, 10)))
                    new_notes.append((new_note, new_vel, time))
                else:
                    new_notes.append((note, vel, time))
        new_midi = MIDI.MIDIFile(len(instruments) or 1)
        # Fall back to program 0 when the source carries no program events.
        for i, inst in enumerate(instruments or [0]):
            new_midi.addTrack()
            new_midi.addProgramChange(i, 0, 0, inst)
            for note, vel, time in new_notes:
                new_midi.addNote(i, 0, note, time, 100, vel)
        return self._publish(new_midi)

    def apply_synth_effect(self, midi_data, effect, intensity):
        """Apply *effect* to base64-encoded MIDI and publish the result.

        Only ``"tempo"`` is implemented: every event time is scaled by
        ``1 + (intensity - 0.5) * 0.4``. Any other effect name leaves the
        events unchanged.

        Returns:
            ``(midi_base64, audio_base64)`` of the processed file.
        """
        midi = MIDI.load(io.BytesIO(base64.b64decode(midi_data)))
        if effect == "tempo":
            factor = 1 + (intensity - 0.5) * 0.4
            for track in midi.tracks:
                for event in track.events:
                    event.time = int(event.time * factor)
        return self._publish(midi)

    def play_with_loop(self, midi_data):
        """Play base64-encoded MIDI repeatedly until ``stop_playback`` is called.

        Blocks the calling thread for as long as ``is_playing`` stays true.
        """
        self.is_playing = True
        midi_file = MIDI.load(io.BytesIO(base64.b64decode(midi_data)))
        while self.is_playing:
            self.synthesizer.play_midi(midi_file)
        return "Stopped"

    def stop_playback(self):
        """Ask the playback loop to exit after its current pass."""
        self.is_playing = False
        return "Stopping..."
def create_download_list():
    """Render every entry of ``midi_processor.modified_files`` as HTML links.

    Each list item carries a data-URI link to the MIDI file and, when audio
    data is present, a second link to the WAV audio.
    """
    fragments = ["<h3>Downloads</h3><ul>"]
    for index, entry in enumerate(midi_processor.modified_files):
        midi_b64, audio_b64 = entry
        item = f'<li><a href="data:audio/midi;base64,{midi_b64}" download="midi_{index}.mid">MIDI {index}</a>'
        if audio_b64:
            item += f' | <a href="data:audio/wav;base64,{audio_b64}" download="audio_{index}.wav">Audio {index}</a>'
        fragments.append(item + '</li>')
    fragments.append("</ul>")
    return "".join(fragments)
def get_midi_choices():
    """Return ``(file base name, midi_id)`` pairs for every loaded MIDI entry."""
    choices = []
    for midi_id, entry in midi_processor.loaded_midi.items():
        path, _midi_obj = entry
        choices.append((os.path.basename(path), midi_id))
    return choices
if __name__ == "__main__":
    # Script entry point: parse CLI options, build the Gradio UI and serve it.
    # Imported locally so the file's import block stays untouched.
    import threading
    from html import escape

    parser = argparse.ArgumentParser()
    parser.add_argument("--port", type=int, default=7860)
    parser.add_argument("--share", action="store_true")
    parser.add_argument("--batch", type=int, default=1)  # accepted but currently unused
    opt = parser.parse_args()

    midi_manager = MIDIDeviceManager()
    midi_processor = MIDIManager()

    playback_lock = threading.Lock()
    playback_state = {"thread": None}

    def start_playback(midi_data):
        """Restart looped playback on a background thread.

        ``play_with_loop`` blocks until playback is stopped, so calling it
        inside an event handler would keep the handler from ever returning.
        """
        with playback_lock:
            midi_processor.stop_playback()
            previous = playback_state["thread"]
            if previous is not None and previous.is_alive():
                # Wait for the previous loop to finish its current pass.
                previous.join()
            worker = threading.Thread(
                target=midi_processor.play_with_loop, args=(midi_data,), daemon=True
            )
            playback_state["thread"] = worker
            worker.start()

    def preview_audio(audio_data):
        """Decode base64 audio for a gr.Audio output; None when nothing was rendered."""
        return base64.b64decode(audio_data) if audio_data else None

    with gr.Blocks(theme=gr.themes.Soft()) as app:
        gr.Markdown("<h1>🎵 MIDI Composer 🎵</h1>")
        # Base64 MIDI of the most recent generation; this is what "Apply Effect"
        # processes (the Audio preview never holds base64 MIDI).
        current_midi = gr.State(None)

        with gr.Tabs():
            # Tab 1: MIDI Prompt (Main Tab)
            with gr.Tab("MIDI Prompt"):
                midi_files = gr.File(label="Upload MIDI Files", file_count="multiple")
                loaded_display = gr.HTML(value="No files loaded")
                prompt_output = gr.Audio(label="Generated Preview", autoplay=True)

            # Tab 2: Examples
            with gr.Tab("Examples"):
                example_select = gr.Dropdown(label="Select Example", choices=get_midi_choices(), value=None)
                example_output = gr.Audio(label="Example Preview", autoplay=True)

            # Tab 3: Generate & Perform
            with gr.Tab("Generate & Perform"):
                midi_select = gr.Dropdown(label="Select MIDI", choices=get_midi_choices(), value="starter")
                length_factor = gr.Slider(1, 5, value=2, step=1, label="Length Factor")
                variation = gr.Slider(0, 1, value=0.3, label="Variation")
                generate_btn = gr.Button("Generate")
                effect = gr.Radio(["tempo"], label="Synth Effect", value="tempo")
                intensity = gr.Slider(0, 1, value=0.5, label="Effect Intensity")
                apply_btn = gr.Button("Apply Effect")
                stop_btn = gr.Button("Stop Playback")
                output = gr.Audio(label="Preview", autoplay=True)
                status = gr.Textbox(label="Status", value="Ready")
                midi_device = gr.Dropdown(label="MIDI Output Device", choices=midi_manager.get_output_devices(), type="index")
                tempo = gr.Slider(label="Tempo (BPM)", minimum=40, maximum=200, value=120, step=1)
                # `interactive=False` is Gradio's read-only switch; `readonly` is not a Textbox parameter.
                device_info = gr.Textbox(label="Connected MIDI Devices", value=midi_manager.get_device_info(), interactive=False)
                refresh_btn = gr.Button("🔄 Refresh MIDI Devices")

            # Tab 4: Downloads
            with gr.Tab("Downloads", elem_id="downloads"):
                downloads = gr.HTML(value=create_download_list())

        gr.Markdown("""
        <div style='text-align: center; margin-top: 20px;'>
        <img src='https://huggingface.co/front/assets/huggingface_logo-noborder.svg' alt='Hugging Face Logo' style='width: 50px;'><br>
        <strong>Hugging Face</strong><br>
        <a href='https://huggingface.co/models'>Models</a> |
        <a href='https://huggingface.co/datasets'>Datasets</a> |
        <a href='https://huggingface.co/spaces'>Spaces</a> |
        <a href='https://huggingface.co/posts'>Posts</a> |
        <a href='https://huggingface.co/docs'>Docs</a> |
        <a href='https://huggingface.co/enterprise'>Enterprise</a> |
        <a href='https://huggingface.co/pricing'>Pricing</a>
        </div>
        """)

        def load_and_generate(files):
            """Register uploaded files, generate a variation of each and refresh the UI."""
            parts = ["<h3>Loaded Files</h3>"]
            audio_data = None
            for file in files or []:
                # Uploads arrive either as objects with `.name` or as plain path strings.
                path = getattr(file, "name", file)
                midi_id = midi_processor.load_midi(path)
                # Escape the user-controlled name before embedding it in HTML. The former
                # "X" button called an undefined `remove_midi` JS function and is dropped.
                parts.append(f"<div>{escape(str(path))}</div>")
                _, audio_data = midi_processor.generate_variation(midi_id)
            choices = get_midi_choices()
            return (
                "".join(parts),
                preview_audio(audio_data),
                gr.update(choices=choices),
                gr.update(choices=choices),
                create_download_list(),
            )

        def load_example(midi_id):
            """Generate and loop a variation of the selected example."""
            if not midi_id:
                return None, create_download_list()
            midi_data, audio_data = midi_processor.generate_variation(midi_id)
            if midi_data is None:
                return None, create_download_list()
            start_playback(midi_data)
            return preview_audio(audio_data), create_download_list()

        def generate(midi_id, length, var, previous_midi):
            """Generate a variation, start looping it and remember its MIDI data."""
            if not midi_id:
                return None, "Select a MIDI file", create_download_list(), previous_midi
            midi_data, audio_data = midi_processor.generate_variation(midi_id, length, var)
            if midi_data is None:
                return None, "Unknown MIDI file", create_download_list(), previous_midi
            start_playback(midi_data)
            return preview_audio(audio_data), "Playing", create_download_list(), midi_data

        def apply_effect(midi_data, fx, inten):
            """Apply the chosen effect to the most recently generated MIDI."""
            if not midi_data:
                return None, "Generate a MIDI first", create_download_list(), midi_data
            new_midi_data, audio_data = midi_processor.apply_synth_effect(midi_data, fx, inten)
            start_playback(new_midi_data)
            return preview_audio(audio_data), "Playing", create_download_list(), new_midi_data

        def refresh_devices():
            """Re-query the MIDI ports for the device dropdown and the info box."""
            return gr.update(choices=midi_manager.get_output_devices()), midi_manager.get_device_info()

        # Events are wired after every component exists, so each output is a real
        # component object (a string such as "downloads" is not a valid output).
        midi_files.change(load_and_generate, inputs=[midi_files],
                          outputs=[loaded_display, prompt_output, example_select, midi_select, downloads])
        example_select.change(load_example, inputs=[example_select],
                              outputs=[example_output, downloads])
        generate_btn.click(generate, inputs=[midi_select, length_factor, variation, current_midi],
                           outputs=[output, status, downloads, current_midi])
        apply_btn.click(apply_effect, inputs=[current_midi, effect, intensity],
                        outputs=[output, status, downloads, current_midi])
        stop_btn.click(midi_processor.stop_playback, inputs=None, outputs=[status])
        refresh_btn.click(refresh_devices, inputs=None, outputs=[midi_device, device_info])

    try:
        app.queue().launch(server_port=opt.port, share=opt.share, inbrowser=True)
    finally:
        # Release playback and MIDI ports even when launch() raises.
        midi_processor.stop_playback()
        midi_manager.close()