# midi-composer / app.py
import random
import argparse
import os
import glob
import rtmidi
import gradio as gr
import numpy as np
import MIDI
import base64
import io
import soundfile as sf # Placeholder for audio rendering
from huggingface_hub import hf_hub_download
from midi_synthesizer import MidiSynthesizer
MAX_SEED = np.iinfo(np.int32).max
in_space = os.getenv("SYSTEM") == "spaces"
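# Static example song metadata (not referenced elsewhere in this file yet).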
SONG_DATA = {
"title": "Do You Believe in Love",
"progression": ["G", "D", "Em", "C"],
"lyrics": ["I was walking down a one-way street", "Just a-looking for someone to meet"]
}
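# Thin wrapper around python-rtmidi for listing and closing MIDI input/output ports.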
class MIDIDeviceManager:
def __init__(self):
self.midiout = rtmidi.MidiOut()
self.midiin = rtmidi.MidiIn()
def get_output_devices(self):
return self.midiout.get_ports() or ["No MIDI output devices"]
def get_input_devices(self):
return self.midiin.get_ports() or ["No MIDI input devices"]
def get_device_info(self):
out_devices = self.get_output_devices()
in_devices = self.get_input_devices()
out_info = "\n".join([f"Out Port {i}: {name}" for i, name in enumerate(out_devices)]) if out_devices else "No MIDI output devices detected"
in_info = "\n".join([f"In Port {i}: {name}" for i, name in enumerate(in_devices)]) if in_devices else "No MIDI input devices detected"
return f"Output Devices:\n{out_info}\n\nInput Devices:\n{in_info}"
def close(self):
if self.midiout.is_port_open():
self.midiout.close_port()
if self.midiin.is_port_open():
self.midiin.close_port()
del self.midiout
del self.midiin
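# Loads MIDI files, generates randomized variations, and plays them through the bundled soundfont synthesizer.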
class MIDIManager:
def __init__(self):
self.soundfont_path = hf_hub_download(repo_id="skytnt/midi-model", filename="soundfont.sf2")
self.synthesizer = MidiSynthesizer(self.soundfont_path)
self.loaded_midi = {} # midi_id: (file_path, midi_obj)
self.modified_files = [] # Stores (midi_base64, audio_base64) tuples
self.is_playing = False
self.instruments = self.random_instrument_set()
self.drum_beat = self.create_drum_beat()
self.starter_midi = self.create_starter_midi()
self.example_files = self.load_example_midis()
self.loaded_midi["starter"] = ("Starter MIDI", self.starter_midi)
self.preload_default_midi()
def random_instrument_set(self):
instrument_pool = [0, 24, 32, 48] # Piano, Guitar, Bass, Strings
return random.sample(instrument_pool, 4)
def create_drum_beat(self):
return [(36, 100, 0), (42, 80, 50), (38, 90, 100), (42, 80, 150)] # Kick, hi-hat, snare, hi-hat
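    # Build a five-track seed file: four melodic tracks using the random instrument set plus a channel-10 drum track.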
def create_starter_midi(self):
midi = MIDI.MIDIFile(5) # 4 instruments + 1 drum track
for i, inst in enumerate(self.instruments):
midi.addTrack()
midi.addProgramChange(i, 0, 0, inst)
for t in range(0, 400, 100):
note = random.randint(60, 84) # C4 to C6
midi.addNote(i, 0, note, t, 100, 100)
midi.addTrack()
for note, vel, time in self.drum_beat:
midi.addNote(4, 9, note, time, 100, vel)
return midi
def preload_default_midi(self):
default_path = "default.mid"
if os.path.exists(default_path):
midi_id = "default"
midi = MIDI.load(default_path)
self.loaded_midi[midi_id] = (default_path, midi)
            # Generate an initial variation so the download list is populated at startup;
            # looping playback is started from the UI rather than here, so __init__ does not block.
            midi_data, audio_data = self.generate_variation(midi_id)
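    # Rebuild every bundled example .mid/.midi file with this session's instrument set and drum pattern.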
def load_example_midis(self):
examples = {}
for file_path in glob.glob("*.mid") + glob.glob("*.midi"):
if file_path == "default.mid":
continue
midi_id = f"example_{len(examples)}"
midi = MIDI.load(file_path)
new_midi = MIDI.MIDIFile(5)
notes, _ = self.extract_notes_and_instruments(midi)
for i, inst in enumerate(self.instruments):
new_midi.addTrack()
new_midi.addProgramChange(i, 0, 0, inst)
for note, vel, time in notes:
new_midi.addNote(i, 0, note, time, 100, vel)
new_midi.addTrack()
for note, vel, time in self.drum_beat:
new_midi.addNote(4, 9, note, time, 100, vel)
examples[midi_id] = (file_path, new_midi)
self.loaded_midi.update(examples)
return examples
def load_midi(self, file_path):
midi = MIDI.load(file_path)
midi_id = f"midi_{len(self.loaded_midi) - len(self.example_files) - 1}"
self.loaded_midi[midi_id] = (file_path, midi)
return midi_id
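    # Flatten note_on events into (pitch, velocity, time) tuples and collect any program numbers encountered.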
def extract_notes_and_instruments(self, midi):
notes = []
instruments = set()
for track in midi.tracks:
for event in track.events:
if event.type == 'note_on' and event.velocity > 0:
notes.append((event.note, event.velocity, event.time))
if hasattr(event, 'program'):
instruments.add(event.program)
return notes, list(instruments)
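    # Create a longer, lightly randomized variation of a loaded MIDI; returns (midi_base64, audio_base64),
    # where the audio half is still a placeholder.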
def generate_variation(self, midi_id, length_factor=2, variation=0.3):
if midi_id not in self.loaded_midi:
return None, None
_, midi = self.loaded_midi[midi_id]
notes, instruments = self.extract_notes_and_instruments(midi)
        new_notes = []
        # Offset each repetition in time so a larger length_factor really produces a longer piece.
        loop_length = max((time for _, _, time in notes), default=0) + 100
        for rep in range(int(length_factor)):
            offset = rep * loop_length
            for note, vel, time in notes:
                if random.random() < variation:
                    new_note = min(127, max(0, note + random.randint(-2, 2)))
                    new_vel = min(127, max(0, vel + random.randint(-10, 10)))
                    new_notes.append((new_note, new_vel, time + offset))
                else:
                    new_notes.append((note, vel, time + offset))
new_midi = MIDI.MIDIFile(len(instruments) or 1)
for i, inst in enumerate(instruments or [0]):
new_midi.addTrack()
new_midi.addProgramChange(i, 0, 0, inst)
for note, vel, time in new_notes:
new_midi.addNote(i, 0, note, time, 100, vel)
midi_output = io.BytesIO()
new_midi.writeFile(midi_output)
midi_data = base64.b64encode(midi_output.getvalue()).decode('utf-8')
temp_midi = 'temp.mid'
with open(temp_midi, 'wb') as f:
f.write(midi_output.getvalue())
audio_output = io.BytesIO()
# Placeholder for audio rendering; needs fluidsynth or similar
self.synthesizer.play_midi(new_midi)
        audio_data = None  # audio rendering is not implemented yet; see the commented sketch below
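        # A minimal sketch of how audio export might work, assuming a renderer (the bundled
        # synthesizer or e.g. pyfluidsynth) can return the rendered samples as a NumPy array;
        # render_to_samples is a hypothetical helper, not part of this project:
        #     samples = render_to_samples(new_midi, self.soundfont_path)
        #     sf.write(audio_output, samples, 44100, format='WAV')
        #     audio_data = base64.b64encode(audio_output.getvalue()).decode('utf-8')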
if os.path.exists(temp_midi):
os.remove(temp_midi)
self.modified_files.append((midi_data, audio_data))
return midi_data, audio_data
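    # Apply a simple transformation to a base64-encoded MIDI blob; only global tempo scaling is implemented.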
def apply_synth_effect(self, midi_data, effect, intensity):
midi = MIDI.load(io.BytesIO(base64.b64decode(midi_data)))
if effect == "tempo":
factor = 1 + (intensity - 0.5) * 0.4
for track in midi.tracks:
for event in track.events:
event.time = int(event.time * factor)
midi_output = io.BytesIO()
midi.writeFile(midi_output)
midi_data = base64.b64encode(midi_output.getvalue()).decode('utf-8')
temp_midi = 'temp.mid'
with open(temp_midi, 'wb') as f:
f.write(midi_output.getvalue())
audio_output = io.BytesIO()
self.synthesizer.play_midi(midi)
audio_data = None # Placeholder
if os.path.exists(temp_midi):
os.remove(temp_midi)
self.modified_files.append((midi_data, audio_data))
return midi_data, audio_data
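    # Replay the MIDI on the server-side synthesizer until stop_playback() clears the flag; blocks the caller.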
def play_with_loop(self, midi_data):
self.is_playing = True
midi_file = MIDI.load(io.BytesIO(base64.b64decode(midi_data)))
while self.is_playing:
self.synthesizer.play_midi(midi_file)
return "Stopped"
def stop_playback(self):
self.is_playing = False
return "Stopping..."
def create_download_list():
html = "<h3>Downloads</h3><ul>"
for i, (midi_data, audio_data) in enumerate(midi_processor.modified_files):
html += f'<li><a href="data:audio/midi;base64,{midi_data}" download="midi_{i}.mid">MIDI {i}</a>'
if audio_data:
html += f' | <a href="data:audio/wav;base64,{audio_data}" download="audio_{i}.wav">Audio {i}</a>'
html += '</li>'
html += "</ul>"
return html
def get_midi_choices():
return [(os.path.basename(path), midi_id) for midi_id, (path, _) in midi_processor.loaded_midi.items()]
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--port", type=int, default=7860)
parser.add_argument("--share", action="store_true")
parser.add_argument("--batch", type=int, default=1)
opt = parser.parse_args()
midi_manager = MIDIDeviceManager()
midi_processor = MIDIManager()
with gr.Blocks(theme=gr.themes.Soft()) as app:
gr.Markdown("<h1>🎵 MIDI Composer 🎵</h1>")
with gr.Tabs():
# Tab 1: MIDI Prompt (Main Tab)
with gr.Tab("MIDI Prompt"):
midi_files = gr.File(label="Upload MIDI Files", file_count="multiple")
loaded_display = gr.HTML(value="No files loaded")
output = gr.Audio(label="Generated Preview", type="bytes", autoplay=True)
def load_and_generate(files):
html = "<h3>Loaded Files</h3>"
midi_data = None
for file in files or []:
midi_id = midi_processor.load_midi(file.name)
html += f"<div>{file.name} <button onclick=\"remove_midi('{midi_id}')\">X</button></div>"
midi_data, _ = midi_processor.generate_variation(midi_id)
return html, (io.BytesIO(base64.b64decode(midi_data)) if midi_data else None), get_midi_choices(), create_download_list()
                midi_files.change(load_and_generate, inputs=[midi_files],
                                  outputs=[loaded_display, output, gr.State(get_midi_choices()), downloads])
# Tab 2: Examples
with gr.Tab("Examples"):
example_select = gr.Dropdown(label="Select Example", choices=get_midi_choices(), value=None)
example_output = gr.Audio(label="Example Preview", type="bytes", autoplay=True)
                def load_example(midi_id):
                    if not midi_id:
                        return None, create_download_list()
midi_data, audio_data = midi_processor.generate_variation(midi_id)
midi_processor.play_with_loop(midi_data)
return io.BytesIO(base64.b64decode(midi_data)), create_download_list()
                example_select.change(load_example, inputs=[example_select],
                                      outputs=[example_output, downloads])
gr.State(get_midi_choices()).change(lambda choices: gr.update(choices=choices),
inputs=[gr.State()], outputs=[example_select])
# Tab 3: Generate & Perform
with gr.Tab("Generate & Perform"):
midi_select = gr.Dropdown(label="Select MIDI", choices=get_midi_choices(), value="starter")
length_factor = gr.Slider(1, 5, value=2, step=1, label="Length Factor")
variation = gr.Slider(0, 1, value=0.3, label="Variation")
generate_btn = gr.Button("Generate")
effect = gr.Radio(["tempo"], label="Synth Effect", value="tempo")
intensity = gr.Slider(0, 1, value=0.5, label="Effect Intensity")
apply_btn = gr.Button("Apply Effect")
stop_btn = gr.Button("Stop Playback")
output = gr.Audio(label="Preview", type="bytes", autoplay=True)
status = gr.Textbox(label="Status", value="Ready")
midi_device = gr.Dropdown(label="MIDI Output Device", choices=midi_manager.get_output_devices(), type="index")
tempo = gr.Slider(label="Tempo (BPM)", minimum=40, maximum=200, value=120, step=1)
                device_info = gr.Textbox(label="Connected MIDI Devices", value=midi_manager.get_device_info(), interactive=False)
refresh_btn = gr.Button("🔄 Refresh MIDI Devices")
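                # Placeholder wiring meant to refresh the MIDI dropdown when the set of loaded files changes.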
def update_dropdown(choices):
return gr.update(choices=choices)
gr.State(get_midi_choices()).change(update_dropdown, inputs=[gr.State()], outputs=[midi_select])
def generate(midi_id, length, var):
if not midi_id:
return None, "Select a MIDI file", create_download_list()
midi_data, audio_data = midi_processor.generate_variation(midi_id, length, var)
midi_processor.play_with_loop(midi_data)
return io.BytesIO(base64.b64decode(midi_data)), "Playing", create_download_list()
def apply_effect(midi_data, fx, inten):
if not midi_data:
return None, "Generate a MIDI first", create_download_list()
new_midi_data, audio_data = midi_processor.apply_synth_effect(midi_data.decode('utf-8'), fx, inten)
midi_processor.play_with_loop(new_midi_data)
return io.BytesIO(base64.b64decode(new_midi_data)), "Playing", create_download_list()
                def refresh_devices():
                    # Return an update for the dropdown's choices rather than a bare list, plus fresh device info.
                    return gr.update(choices=midi_manager.get_output_devices()), midi_manager.get_device_info()
                generate_btn.click(generate, inputs=[midi_select, length_factor, variation],
                                   outputs=[output, status, downloads])
                apply_btn.click(apply_effect, inputs=[output, effect, intensity],
                                outputs=[output, status, downloads])
stop_btn.click(midi_processor.stop_playback, inputs=None, outputs=[status])
refresh_btn.click(refresh_devices, inputs=None, outputs=[midi_device, device_info])
# Tab 4: Downloads
with gr.Tab("Downloads", elem_id="downloads"):
                downloads.render()
gr.Markdown("""
<div style='text-align: center; margin-top: 20px;'>
<img src='https://huggingface.co/front/assets/huggingface_logo-noborder.svg' alt='Hugging Face Logo' style='width: 50px;'><br>
<strong>Hugging Face</strong><br>
<a href='https://huggingface.co/models'>Models</a> |
<a href='https://huggingface.co/datasets'>Datasets</a> |
<a href='https://huggingface.co/spaces'>Spaces</a> |
<a href='https://huggingface.co/posts'>Posts</a> |
<a href='https://huggingface.co/docs'>Docs</a> |
<a href='https://huggingface.co/enterprise'>Enterprise</a> |
<a href='https://huggingface.co/pricing'>Pricing</a>
</div>
""")
app.queue().launch(server_port=opt.port, share=opt.share, inbrowser=True)
midi_manager.close()