"""Gradio app: upload MIDI files, generate longer variations, apply tempo/pitch
effects, preview them as audio and download the results.

NOTE(review): MIDI data is handled with the list-based score/opus API of the
``MIDI.py`` module bundled with skytnt/midi-model (Peter Billam's MIDI.py:
``midi2score`` / ``score2midi`` / ``midi2opus``), and audio is rendered with
``MidiSynthesizer.synthesis``. Confirm both against the local copies.
"""
import base64
import copy
import io
import os
import random
from html import escape

import gradio as gr
import numpy as np
from huggingface_hub import hf_hub_download

import MIDI
from midi_model import MIDIModel, MIDIModelConfig
from midi_synthesizer import MidiSynthesizer

MAX_SEED = np.iinfo(np.int32).max
# General MIDI percussion channel (channel 10, 0-based index 9). Transposing it
# would swap drum sounds rather than change pitch.
DRUM_CHANNEL = 9


def _clamp(value, low, high):
    """Return *value* limited to the inclusive range [low, high]."""
    return max(low, min(high, value))


class MIDIManager:
    """Load, vary, transform and render MIDI files for the UI.

    Loaded files are kept as MIDI.py "scores": ``[ticks_per_beat, track, ...]``
    where each track is a list of events and a note event is
    ``['note', start_tick, duration, channel, pitch, velocity]``.
    NOTE(review): score layout assumed from MIDI.py's documented format — confirm.
    """

    def __init__(self):
        self.soundfont_path = hf_hub_download(repo_id="skytnt/midi-model", filename="soundfont.sf2")
        self.synthesizer = MidiSynthesizer(self.soundfont_path)
        # presumably the synthesizer exposes its output rate; 44.1 kHz fallback — confirm
        self.sample_rate = getattr(self.synthesizer, "sample_rate", 44100)
        self.loaded_midi = {}  # midi_id -> score of each successfully loaded file
        self.modified_files = []  # base64-encoded MIDI of every generated file, for download
        self.is_playing = False

    def load_midi(self, file_path):
        """Parse the MIDI file at *file_path* and register it.

        Returns ``(midi_id, (notes, instruments))`` where the second element
        comes from :meth:`extract_notes_and_instruments`.
        """
        with open(file_path, "rb") as fh:
            score = MIDI.midi2score(fh.read())
        midi_id = f"midi_{len(self.loaded_midi)}"
        self.loaded_midi[midi_id] = score
        return midi_id, self.extract_notes_and_instruments(score)

    def extract_notes_and_instruments(self, midi):
        """Summarise a score.

        Returns ``(notes, instruments)``: ``notes`` is a list of
        ``(pitch, velocity, start_tick)`` tuples for every sounding note, and
        ``instruments`` is the sorted list of program numbers set by
        ``patch_change`` events.
        """
        notes = []
        instruments = set()
        for track in midi[1:]:
            for event in track:
                if event[0] == "note" and event[5] > 0:
                    notes.append((event[4], event[5], event[1]))
                elif event[0] == "patch_change":
                    instruments.add(event[3])
        return notes, sorted(instruments)

    def _store(self, midi_bytes):
        """Base64-encode *midi_bytes*, record it for download and return the string."""
        midi_data = base64.b64encode(midi_bytes).decode("utf-8")
        self.modified_files.append(midi_data)
        return midi_data

    def generate_variation(self, midi_id, length_factor=2, variation_level=0.3):
        """Build a longer, randomly varied version of a loaded file.

        The notes are repeated ``int(length_factor)`` times (at least once),
        each repetition starting where the previous one ends. Each note is
        varied with probability ``variation_level``: pitch moves by up to ±2
        semitones (not on the drum channel) and velocity by up to ±10.
        Non-note events (tempo, program changes, ...) are kept once, at their
        original time.

        Returns the new file as base64 MIDI, or None when *midi_id* is unknown
        or the file has no notes.
        """
        if midi_id not in self.loaded_midi:
            return None
        score = self.loaded_midi[midi_id]
        ticks_per_beat, tracks = score[0], score[1:]
        note_events = [ev for track in tracks for ev in track if ev[0] == "note"]
        if not note_events:
            return None

        # Offsetting each copy by the length of the piece makes the result longer
        # instead of stacking every repetition at the same time positions.
        span = max(ev[1] + ev[2] for ev in note_events)
        repeats = max(1, int(length_factor))

        new_tracks = []
        for track in tracks:
            new_track = [list(ev) for ev in track if ev[0] != "note"]
            for rep in range(repeats):
                offset = rep * span
                for event in track:
                    if event[0] != "note":
                        continue
                    _, start, duration, channel, pitch, velocity = event
                    if random.random() < variation_level:
                        if channel != DRUM_CHANNEL:
                            pitch = _clamp(pitch + random.randint(-2, 2), 0, 127)
                        # Floor of 1: a note-on with velocity 0 means note-off.
                        velocity = _clamp(velocity + random.randint(-10, 10), 1, 127)
                    new_track.append(["note", start + offset, duration, channel, pitch, velocity])
            new_tracks.append(new_track)

        return self._store(MIDI.score2midi([ticks_per_beat, *new_tracks]))

    def apply_synth_effect(self, midi_id, effect_type, intensity):
        """Apply a tempo or pitch effect to a copy of a loaded file.

        * ``"tempo"``: multiply every event time (and note duration) by
          ``1 + (intensity - 0.5) * 0.4``. Intensity 0 gives 0.8x (faster) and
          intensity 1 gives 1.2x (slower).
        * ``"pitch"``: transpose non-drum notes by ``int((intensity - 0.5) * 12)``
          semitones (-6..+6), clamped to 0..127.
        * Any other value returns the file unchanged.

        Returns base64 MIDI, or None when *midi_id* is unknown.
        """
        if midi_id not in self.loaded_midi:
            return None
        # Score events are nested lists, so a shallow copy would let the effect
        # mutate the stored original.
        score = copy.deepcopy(self.loaded_midi[midi_id])
        tracks = score[1:]
        if effect_type == "tempo":
            factor = 1 + (intensity - 0.5) * 0.4
            for track in tracks:
                for event in track:
                    event[1] = int(round(event[1] * factor))
                    if event[0] == "note":
                        # Durations must scale with start times or notes overlap / leave gaps.
                        event[2] = max(1, int(round(event[2] * factor)))
        elif effect_type == "pitch":
            shift = int((intensity - 0.5) * 12)
            for track in tracks:
                for event in track:
                    if event[0] == "note" and event[3] != DRUM_CHANNEL:
                        event[4] = _clamp(event[4] + shift, 0, 127)
        return self._store(MIDI.score2midi(score))

    def render_audio(self, midi_data):
        """Synthesize base64 MIDI into a ``(sample_rate, samples)`` tuple for ``gr.Audio``.

        NOTE(review): assumes ``MidiSynthesizer.synthesis`` takes a MIDI.py opus
        and returns a numpy array of samples — confirm against midi_synthesizer.py.
        """
        opus = MIDI.midi2opus(base64.b64decode(midi_data))
        return self.sample_rate, self.synthesizer.synthesis(opus)

    def play_with_loop(self, midi_data, repeats=4):
        """Render *midi_data* repeated *repeats* times, for an autoplaying player.

        Audio plays in the browser, not on the server. Returning pre-looped
        audio replaces the old blocking ``while`` loop, which tied up a worker
        indefinitely. Returns None when there is nothing to play.
        """
        if not midi_data:
            return None
        sample_rate, audio = self.render_audio(midi_data)
        self.is_playing = True
        return sample_rate, np.concatenate([audio] * max(1, int(repeats)), axis=0)

    def stop_playback(self):
        """Mark playback as stopped and return a status message."""
        self.is_playing = False
        return "Playback stopped"


midi_manager = MIDIManager()


def create_download_list(modified_files):
    """Return HTML listing each base64 MIDI in *modified_files* as a download link."""
    if not modified_files:
        return "No generated files yet"
    # Base64 output contains only [A-Za-z0-9+/=], so it is safe inside an attribute.
    items = "".join(
        f'<li><a href="data:audio/midi;base64,{data}" download="midi_variation_{idx}.mid">'
        f"midi_variation_{idx}.mid</a></li>"
        for idx, data in enumerate(modified_files, start=1)
    )
    return f"<div><h3>Generated MIDI Files</h3><ul>{items}</ul></div>"


with gr.Blocks(theme=gr.themes.Soft()) as app:
    gr.Markdown("# 🎵 MIDI Sequence Generator & Performer 🎵")

    with gr.Tabs():
        # Tab 1: MIDI Upload
        with gr.Tab("Upload MIDI"):
            midi_files = gr.File(label="Upload MIDI Files", file_count="multiple")
            midi_list = gr.State({})  # midi_id -> display name of the uploaded file
            file_display = gr.HTML(value="No files loaded")

            def handle_upload(files):
                """Load every uploaded file; return the id->name map and an HTML summary."""
                if not files:  # the change event also fires (with None) when the upload is cleared
                    return {}, "No files loaded"
                loaded = {}
                rows = []
                for item in files:
                    # Gradio 4 passes str paths; older versions pass tempfile wrappers with .name.
                    path = getattr(item, "name", item)
                    display_name = os.path.basename(path)
                    try:
                        midi_id, _ = midi_manager.load_midi(path)
                    except Exception as err:  # UI boundary: report the bad file, keep the rest
                        rows.append(f"<li>{escape(display_name)}: failed to load ({escape(str(err))})</li>")
                        continue
                    loaded[midi_id] = display_name
                    rows.append(f"<li>{escape(display_name)}</li>")
                return loaded, f"<div><h3>Loaded MIDI Files</h3><ul>{''.join(rows)}</ul></div>"

            upload_event = midi_files.change(
                handle_upload, inputs=[midi_files], outputs=[midi_list, file_display]
            )

        # Tab 2: Generate Variations
        with gr.Tab("Generate"):
            midi_select = gr.Dropdown(label="Select MIDI", choices=[])
            length_factor = gr.Slider(1, 10, value=2, step=1, label="Length Multiplier")
            variation_level = gr.Slider(0, 1, value=0.3, label="Variation Level")
            generate_btn = gr.Button("Generate Variation")
            # gr.Audio only accepts type "numpy" or "filepath"; previews are (rate, samples).
            generated_output = gr.Audio(label="Generated Preview", type="numpy")
            last_midi = gr.State(None)  # base64 MIDI of the latest generation / effect

            def generate(midi_id, length, variation):
                """Generate a variation; return (preview audio, base64 MIDI)."""
                if not midi_id:
                    return None, None
                midi_data = midi_manager.generate_variation(midi_id, length, variation)
                if midi_data is None:
                    return None, None
                return midi_manager.render_audio(midi_data), midi_data

            generate_event = generate_btn.click(
                generate,
                inputs=[midi_select, length_factor, variation_level],
                outputs=[generated_output, last_midi],
            )

        # Tab 3: Synthesizer Controls
        with gr.Tab("Perform"):
            midi_play_select = gr.Dropdown(label="Select MIDI to Play", choices=[])
            synth_effects = gr.Radio(["tempo", "pitch"], label="Synth Effect", value="tempo")
            effect_intensity = gr.Slider(0, 1, value=0.5, label="Effect Intensity")
            apply_effect_btn = gr.Button("Apply Effect")
            play_btn = gr.Button("Play with Auto-Loop")
            stop_btn = gr.Button("Stop")
            player = gr.Audio(label="Now Playing", type="numpy", autoplay=True)
            playback_status = gr.Textbox(label="Playback Status", value="Stopped")

            def apply_and_preview(midi_id, effect, intensity):
                """Apply an effect; return (preview audio, base64 MIDI, status)."""
                if not midi_id:
                    return None, None, "No MIDI selected"
                midi_data = midi_manager.apply_synth_effect(midi_id, effect, intensity)
                if midi_data is None:
                    return None, None, "Selected MIDI not found"
                return midi_manager.render_audio(midi_data), midi_data, "Effect applied"

            def play(midi_data):
                """Send the latest result, looped, to the autoplaying player."""
                audio = midi_manager.play_with_loop(midi_data)
                if audio is None:
                    return None, "Nothing to play: generate a variation or apply an effect first"
                return audio, "Playing (auto-loop)"

            def stop():
                """Clear the player and report that playback stopped."""
                return None, midi_manager.stop_playback()

            effect_event = apply_effect_btn.click(
                apply_and_preview,
                inputs=[midi_play_select, synth_effects, effect_intensity],
                outputs=[generated_output, last_midi, playback_status],
            )
            play_btn.click(play, inputs=[last_midi], outputs=[player, playback_status])
            stop_btn.click(stop, inputs=None, outputs=[player, playback_status])

        # Tab 4: Downloads
        with gr.Tab("Downloads"):
            download_list = gr.HTML(value="No generated files yet")

    def update_dropdowns(loaded):
        """Refresh both MIDI selectors from the id->name map of loaded files."""
        choices = [(name, midi_id) for midi_id, name in loaded.items()]
        return gr.update(choices=choices, value=None), gr.update(choices=choices, value=None)

    def update_downloads(*_):
        """Re-render the download list (accepts and ignores any event inputs)."""
        return create_download_list(midi_manager.modified_files)

    upload_event.then(update_dropdowns, inputs=[midi_list], outputs=[midi_select, midi_play_select])
    # Chain with .then so the list is rebuilt only after the new file was stored;
    # independent listeners on the same click have no guaranteed ordering.
    generate_event.then(update_downloads, inputs=None, outputs=[download_list])
    effect_event.then(update_downloads, inputs=None, outputs=[download_list])

    # Hugging Face Branding
    gr.Markdown(
        "**Hugging Face** · "
        "[Models](https://huggingface.co/models) | "
        "[Datasets](https://huggingface.co/datasets) | "
        "[Spaces](https://huggingface.co/spaces) | "
        "[Posts](https://huggingface.co/posts) | "
        "[Docs](https://huggingface.co/docs) | "
        "[Enterprise](https://huggingface.co/enterprise) | "
        "[Pricing](https://huggingface.co/pricing)"
    )


if __name__ == "__main__":
    app.queue().launch(inbrowser=True)