midi-composer / app.py
awacke1's picture
Update app.py
b50d4ec verified
history blame
9.86 kB
import random
import gradio as gr
import numpy as np
import rtmidi
import MIDI
import base64
import io
import os
from huggingface_hub import hf_hub_download
from midi_synthesizer import MidiSynthesizer
# Upper bound for random seeds: the largest signed 32-bit integer.
MAX_SEED = np.iinfo(np.int32).max

# Example song data (simplified from original).
# NOTE(review): the assignment line and closing brace of this dict were missing
# from the file (a syntax error); the name EXAMPLE_SONG is a reconstruction and
# nothing visible in this file reads it — confirm the intended name.
EXAMPLE_SONG = {
    "title": "Do You Believe in Love",
    "progression": ["G", "D", "Em", "C"],
    "lyrics": [
        "I was walking down a one-way street",
        "Just a-looking for someone to meet",
    ],
}
class MIDIDeviceManager:
    """Holds python-rtmidi output/input handles and reports available output ports."""

    def __init__(self):
        self.midiout = rtmidi.MidiOut()
        # Created for completeness; nothing in this file reads from it.
        self.midiin = rtmidi.MidiIn()

    def get_available_devices(self):
        """Return the output port names, or ["No MIDI devices"] when there are none.

        The placeholder entry keeps UI dropdowns from being empty.
        """
        return self.midiout.get_ports() or ["No MIDI devices"]

    def get_device_info(self):
        """Return one "Port <i>: <name>" line per output port.

        Returns "No MIDI devices detected" when no output ports exist. The real
        port list is queried here (not get_available_devices()), because the
        placeholder entry would otherwise be reported as "Port 0".
        """
        ports = self.midiout.get_ports()
        if not ports:
            return "No MIDI devices detected"
        return "\n".join(f"Port {index}: {name}" for index, name in enumerate(ports))
class MIDIManager:
    """Load MIDI files, derive variations/effects from them, and track the outputs.

    Generated MIDI is exchanged as base64-encoded text. Every generated or
    effected file is appended to ``modified_files`` for the download list.

    NOTE(review): this class uses two MIDI object styles — ``MIDI.MIDIFile`` with
    ``addNote``/``addProgramChange``/``writeFile`` for building files, and
    ``.tracks``/``.events``/``event.type`` for reading what ``MIDI.load``
    returns. Confirm the ``MIDI`` module in use really exposes both.
    """

    def __init__(self):
        self.soundfont_path = hf_hub_download(repo_id="skytnt/midi-model", filename="soundfont.sf2")
        self.synthesizer = MidiSynthesizer(self.soundfont_path)
        self.loaded_midi = {}  # midi_id: (file_path, midi_obj) for uploaded files
        self.modified_files = []  # base64 MIDI strings produced by this manager
        self.is_playing = False
        self.example_files = self.load_example_midis()

    def load_example_midis(self):
        """Load every .mid/.midi file in ./examples, keyed "example_<n>".

        Files are visited in sorted order so the ids are stable between runs.
        When no example is found, a single generated one-note file is used instead.
        """
        examples = {}
        example_dir = "examples"
        if os.path.isdir(example_dir):
            for name in sorted(os.listdir(example_dir)):
                if name.lower().endswith((".mid", ".midi")):
                    file_path = os.path.join(example_dir, name)
                    examples[f"example_{len(examples)}"] = (file_path, MIDI.load(file_path))
        if not examples:
            midi = MIDI.MIDIFile(1)
            midi.addNote(0, 0, 60, 0, 100, 100)  # Default C4
            examples["example_0"] = ("Simple C4.mid", midi)
        return examples

    def load_midi(self, file_path):
        """Load ``file_path`` and register it under a new "midi_<n>" id, which is returned."""
        midi = MIDI.load(file_path)
        # Examples live in self.example_files, not in self.loaded_midi, so they
        # must not offset the count (doing so produced ids such as "midi_-1").
        midi_id = f"midi_{len(self.loaded_midi)}"
        self.loaded_midi[midi_id] = (file_path, midi)
        return midi_id

    def extract_notes_and_instruments(self, midi):
        """Collect sounding notes and program numbers from ``midi``.

        Returns:
            (notes, instruments): ``notes`` is a list of (pitch, velocity, time)
            for every note_on event with velocity > 0. ``instruments`` is the
            sorted list of distinct ``program`` values seen on any event.
        """
        notes = []
        instruments = set()
        for track in midi.tracks:
            for event in track.events:
                if event.type == 'note_on' and event.velocity > 0:
                    notes.append((event.note, event.velocity, event.time))
                if hasattr(event, 'program'):
                    instruments.add(event.program)
        return notes, sorted(instruments)

    @staticmethod
    def _encode_midi(midi):
        """Serialize a MIDI object into base64 text."""
        buffer = io.BytesIO()
        if hasattr(midi, "writeFile"):
            midi.writeFile(buffer)
        else:
            # NOTE(review): assumes objects lacking writeFile (e.g. whatever
            # MIDI.load returns) provide save(file=...) — confirm.
            midi.save(file=buffer)
        return base64.b64encode(buffer.getvalue()).decode('utf-8')

    def generate_variation(self, midi_id, length_factor=2, variation=0.3):
        """Build a repeated, randomly varied copy of a loaded MIDI file.

        Args:
            midi_id: key into ``self.loaded_midi``.
            length_factor: number of back-to-back repetitions (truncated to int).
            variation: per-note probability of adding a jittered extra note
                (pitch ±2, velocity ±10) alongside the original one.

        Returns:
            Base64 MIDI text (also appended to ``modified_files``), or None when
            ``midi_id`` is unknown.
        """
        if midi_id not in self.loaded_midi:
            return None
        _, midi = self.loaded_midi[midi_id]
        notes, instruments = self.extract_notes_and_instruments(midi)
        # Start each repetition one time unit after the previous one's last onset,
        # so repeats extend the piece instead of stacking on the same times.
        # Times are treated as absolute onsets, as the addNote call below does.
        span = (max(onset for _, _, onset in notes) + 1) if notes else 0
        new_notes = []
        for repeat in range(int(length_factor)):
            offset = repeat * span
            for note, vel, onset in notes:
                if random.random() < variation:
                    new_note = min(127, max(0, note + random.randint(-2, 2)))
                    # Velocity 0 would turn the note_on into a note-off; keep it audible.
                    new_vel = min(127, max(1, vel + random.randint(-10, 10)))
                    new_notes.append((new_note, new_vel, onset + offset))
                new_notes.append((note, vel, onset + offset))
        new_midi = MIDI.MIDIFile(len(instruments) or 1)
        for track_index, program in enumerate(instruments or [0]):
            new_midi.addProgramChange(track_index, 0, 0, program)
            for note, vel, onset in new_notes:
                # NOTE(review): if this is MIDIUtil's signature, the 5th argument
                # (100) is the note duration in beats — confirm it is intended.
                new_midi.addNote(track_index, 0, note, onset, 100, vel)
        midi_data = self._encode_midi(new_midi)
        self.modified_files.append(midi_data)
        return midi_data

    def apply_synth_effect(self, midi_data, effect, intensity):
        """Apply ``effect`` to base64 MIDI text and return the new base64 text.

        Only "tempo" changes anything. It scales every event time by
        1 + (intensity - 0.5) * 0.4, i.e. a factor from 0.8 to 1.2 for
        intensity in 0..1, where 0.5 leaves timing unchanged. The result is
        also appended to ``modified_files``.
        """
        midi = MIDI.load(io.BytesIO(base64.b64decode(midi_data)))
        if effect == "tempo":
            factor = 1 + (intensity - 0.5) * 0.4
            for track in midi.tracks:
                for event in track.events:
                    # round() rather than int() to avoid a systematic downward drift.
                    event.time = round(event.time * factor)
        new_data = self._encode_midi(midi)
        self.modified_files.append(new_data)
        return new_data

    def play_with_loop(self, midi_data, send=None):
        """Cycle through the events of ``midi_data`` until stop_playback() is called.

        Args:
            midi_data: base64-encoded MIDI text.
            send: optional callable invoked with each event, e.g. to forward it to
                an output port. NOTE(review): the original loop body is missing
                from this file, so how events should be emitted is unknown; this
                reconstruction only walks them.

        Returns:
            "Stopped" once ``self.is_playing`` has been cleared.
        """
        import time

        self.is_playing = True
        midi_file = MIDI.load(io.BytesIO(base64.b64decode(midi_data)))
        events = [event for track in midi_file.tracks for event in track.events]
        if not events:
            # Nothing to loop over; don't spin forever.
            self.is_playing = False
            return "Stopped"
        while self.is_playing:
            for event in events:
                if not self.is_playing:
                    break
                if send is not None:
                    send(event)
            # Yield the CPU between passes; stop_playback() presumably runs from
            # another request/thread.
            time.sleep(0.01)
        return "Stopped"

    def stop_playback(self):
        """Clear the playing flag so play_with_loop() exits; return a status message."""
        self.is_playing = False
        return "Stopping..."
# Module-level singletons used by the UI callbacks.
# NOTE: despite the names, `midi_manager` is the rtmidi port wrapper
# (MIDIDeviceManager) and `midi_processor` is the file/variation engine (MIDIManager).
midi_manager = MIDIDeviceManager()
# Constructing MIDIManager calls hf_hub_download for the soundfont and loads examples.
midi_processor = MIDIManager()
def create_download_list():
    """Render midi_processor.modified_files as an HTML list of data-URI download links."""
    items = [
        f'<li><a href="data:audio/midi;base64,{data}" download="midi_{idx}.mid">MIDI {idx}</a></li>'
        for idx, data in enumerate(midi_processor.modified_files)
    ]
    return "<h3>Downloads</h3><ul>" + "".join(items) + "</ul>"
def get_midi_choices():
    """Return (file basename, midi_id) dropdown pairs for every loaded MIDI file."""
    choices = []
    for midi_id, entry in midi_processor.loaded_midi.items():
        path, _midi = entry
        choices.append((os.path.basename(path), midi_id))
    return choices
def _midi_to_file(midi_data):
    """Write base64 MIDI text to a temporary .mid file and return its path.

    Returns None when there is nothing to write (None or empty text), so the
    result can be passed straight to a gr.File output.
    """
    if not midi_data:
        return None
    import tempfile

    fd, path = tempfile.mkstemp(suffix=".mid")
    with os.fdopen(fd, "wb") as handle:
        handle.write(base64.b64decode(midi_data))
    return path


def _upload_path(file):
    """Return the filesystem path of one uploaded file.

    Gradio 4's gr.File passes plain path strings; older versions passed tempfile
    wrappers exposing ``.name``. Both are accepted.
    """
    return getattr(file, "name", file)


with gr.Blocks(theme=gr.themes.Soft()) as app:
    gr.Markdown("<h1>🎵 MIDI Composer 🎵</h1>")
    # Base64 text of the most recently generated/effected MIDI; Apply Effect reads it.
    last_midi = gr.State(None)
    with gr.Tabs():
        # Tab 1: Upload MIDI
        with gr.Tab("Upload MIDI"):
            midi_files = gr.File(label="Upload MIDI Files", file_count="multiple")
            loaded_display = gr.HTML(value="No files loaded")
            # gr.Audio only accepts type="numpy"/"filepath" and cannot play raw
            # MIDI, so generated files are offered as downloadable .mid files.
            upload_output = gr.File(label="Generated Preview")
        # Tab 2: Generate & Perform
        with gr.Tab("Generate & Perform"):
            midi_select = gr.Dropdown(label="Select MIDI", choices=get_midi_choices(), value=None)
            length_factor = gr.Slider(1, 5, value=2, step=1, label="Length Factor")
            variation = gr.Slider(0, 1, value=0.3, label="Variation")
            generate_btn = gr.Button("Generate")
            effect = gr.Radio(["tempo"], label="Synth Effect", value="tempo")
            intensity = gr.Slider(0, 1, value=0.5, label="Effect Intensity")
            apply_btn = gr.Button("Apply Effect")
            stop_btn = gr.Button("Stop Playback")
            output = gr.File(label="Preview")
            status = gr.Textbox(label="Status", value="Ready")
            # The two controls below are not wired to any event yet.
            midi_device = gr.Dropdown(label="MIDI Output Device", choices=midi_manager.get_available_devices(), type="index")
            tempo = gr.Slider(label="Tempo (BPM)", minimum=40, maximum=200, value=120, step=1)
        # Tab 3: Downloads
        with gr.Tab("Downloads"):
            downloads = gr.HTML(value="No files yet")

    def load_and_generate(files):
        """Load each upload, auto-generate a variation, and refresh the file list/dropdown.

        Returns (listing HTML, preview file path, latest base64 MIDI, dropdown update).
        """
        from html import escape

        listing = "<h3>Loaded Files</h3>"
        midi_data = None
        for file in files or []:
            path = _upload_path(file)
            midi_id = midi_processor.load_midi(path)
            # Filenames come from the user, so escape them before embedding in HTML.
            listing += f"<div>{escape(os.path.basename(path))}</div>"
            midi_data = midi_processor.generate_variation(midi_id)  # Auto-generate
        return listing, _midi_to_file(midi_data), midi_data, gr.update(choices=get_midi_choices())

    def generate(midi_id, length, var, current=None):
        """Generate a variation of the selected file; keep ``current`` on failure."""
        if not midi_id:
            return None, "Select a MIDI file", current
        midi_data = midi_processor.generate_variation(midi_id, length, var)
        if not midi_data:
            return None, "Could not generate MIDI for this selection", current
        return _midi_to_file(midi_data), "Generated", midi_data

    def apply_effect(midi_data, fx, inten):
        """Apply the chosen effect to the most recent MIDI (base64 text from state)."""
        if not midi_data:
            return None, "Generate a MIDI first", midi_data
        new_data = midi_processor.apply_synth_effect(midi_data, fx, inten)
        return _midi_to_file(new_data), "Effect applied", new_data

    def update_downloads(*args):
        return create_download_list()

    # Listeners are attached after all tabs exist so they can reference any
    # component. The download list is refreshed with .then() so it runs after
    # the handler has recorded the new file.
    midi_files.change(
        load_and_generate,
        inputs=[midi_files],
        outputs=[loaded_display, upload_output, last_midi, midi_select],
    ).then(update_downloads, inputs=None, outputs=[downloads])
    generate_btn.click(
        generate,
        inputs=[midi_select, length_factor, variation, last_midi],
        outputs=[output, status, last_midi],
    ).then(update_downloads, inputs=None, outputs=[downloads])
    apply_btn.click(
        apply_effect,
        inputs=[last_midi, effect, intensity],
        outputs=[output, status, last_midi],
    ).then(update_downloads, inputs=None, outputs=[downloads])
    stop_btn.click(midi_processor.stop_playback, inputs=None, outputs=[status])
<div style='text-align: center; margin-top: 20px;'>
<img src='https://huggingface.co/front/assets/huggingface_logo-noborder.svg' alt='Hugging Face Logo' style='width: 50px;'><br>
<strong>Hugging Face</strong><br>
<a href='https://huggingface.co/models'>Models</a> |
<a href='https://huggingface.co/datasets'>Datasets</a> |
<a href='https://huggingface.co/spaces'>Spaces</a> |
<a href='https://huggingface.co/posts'>Posts</a> |
<a href='https://huggingface.co/docs'>Docs</a> |
<a href='https://huggingface.co/enterprise'>Enterprise</a> |
<a href='https://huggingface.co/pricing'>Pricing</a>