# midi-composer / app.py
# Author: awacke1 — "Update app.py" (commit dbc6fc5, verified)
# (Hugging Face page residue: raw | history blame | 8.95 kB)
import base64
import io
import random
import tempfile

import gradio as gr
import numpy as np
import rtmidi
from huggingface_hub import hf_hub_download

import MIDI
from midi_synthesizer import MidiSynthesizer
# Largest value a signed 32-bit integer can hold (2**31 - 1).
MAX_SEED = int(np.iinfo(np.int32).max)
class MIDIManager:
    """Load, vary, play, and record MIDI data for the Gradio UI.

    Generated MIDI is passed around as base64-encoded strings. Every string
    produced by a generate/effect/save call is also appended to
    ``modified_files`` so the Downloads tab can list it.
    """

    # Duration given to every written note, and the fixed spacing between
    # recorded live notes (same time units as the note times).
    NOTE_DURATION = 100

    def __init__(self):
        self.soundfont_path = hf_hub_download(repo_id="skytnt/midi-model", filename="soundfont.sf2")
        self.synthesizer = MidiSynthesizer(self.soundfont_path)
        self.loaded_midi = {}  # midi_id -> loaded MIDI object
        self.modified_files = []  # base64 strings of every generated file
        self.is_playing = False
        # Must exist before the callback is registered: midi_callback appends
        # to it as soon as the first input message arrives.
        self.live_notes = []
        try:
            self.midi_in = rtmidi.MidiIn()
        except rtmidi.RtMidiError:
            # No MIDI backend available (e.g. a headless server container):
            # keep the rest of the app usable without live keyboard input.
            self.midi_in = None
        else:
            if self.midi_in.get_ports():
                self.midi_in.open_port(0)
            self.midi_in.set_callback(self.midi_callback)

    def midi_callback(self, event, data=None):
        """rtmidi input callback: record Note On messages with velocity > 0.

        Args:
            event: ``(message, delta_time)`` tuple; ``message`` is a list of
                MIDI bytes.
            data: User data passed through by ``set_callback`` (unused).
        """
        message, _ = event
        # Status 0x90-0x9F is Note On on any channel; velocity 0 means Note Off.
        if len(message) >= 3 and message[0] & 0xF0 == 0x90:
            note, velocity = message[1], message[2]
            if velocity > 0:
                self.live_notes.append((note, velocity, 0))  # time unused by save_live_midi

    def load_midi(self, file_path):
        """Load a MIDI file, store it, and return its new id ("midi_<n>")."""
        midi = MIDI.load(file_path)
        midi_id = f"midi_{len(self.loaded_midi)}"
        self.loaded_midi[midi_id] = midi
        return midi_id

    def extract_notes(self, midi):
        """Return ``(note, velocity, time)`` for each note_on with velocity > 0.

        NOTE(review): assumes the MIDI object exposes ``tracks[*].events[*]``
        with ``type``/``note``/``velocity``/``time`` attributes — confirm
        against the MIDI module. ``generate_variation`` treats ``time`` as an
        absolute position.
        """
        notes = []
        for track in midi.tracks:
            for event in track.events:
                if event.type == 'note_on' and event.velocity > 0:
                    notes.append((event.note, event.velocity, event.time))
        return notes

    @staticmethod
    def _build_variation_notes(notes, repeats, variation, note_duration):
        """Repeat ``notes`` ``repeats`` times back to back, randomly jittering some.

        Each repetition is shifted by the length of the source material (last
        note time + ``note_duration``) so copies play one after another
        instead of on top of each other. With probability ``variation`` a note's
        pitch moves by up to +/-2 semitones (clamped to 0..127). Its velocity
        moves by up to +/-10, clamped to 1..127 because a Note On with velocity
        0 is a Note Off.
        """
        if not notes:
            return []
        span = max(time for _, _, time in notes) + note_duration
        result = []
        for rep in range(repeats):
            offset = rep * span
            for note, vel, time in notes:
                if random.random() < variation:
                    note = min(127, max(0, note + random.randint(-2, 2)))
                    vel = min(127, max(1, vel + random.randint(-10, 10)))
                result.append((note, vel, time + offset))
        return result

    def generate_variation(self, midi_id, length_factor=2, variation=0.3):
        """Create a lengthened, randomly varied copy of a loaded MIDI file.

        Args:
            midi_id: Id returned by ``load_midi``.
            length_factor: Number of back-to-back repetitions (truncated to int).
            variation: Probability (0..1) that an individual note is jittered.

        Returns:
            The new file as a base64 string (also stored in ``modified_files``),
            or None if ``midi_id`` is unknown.
        """
        if midi_id not in self.loaded_midi:
            return None
        notes = self.extract_notes(self.loaded_midi[midi_id])
        new_notes = self._build_variation_notes(
            notes, int(length_factor), variation, self.NOTE_DURATION
        )
        new_midi = MIDI.MIDIFile(1)
        new_midi.addTrack()
        for note, vel, time in new_notes:
            # Argument order (track, channel, pitch, time, duration, volume)
            # assumed from the original call sites — confirm against MIDI module.
            new_midi.addNote(0, 0, note, time, self.NOTE_DURATION, vel)
        return self._encode_and_store(new_midi)

    def apply_synth_effect(self, midi_data, effect, intensity):
        """Apply ``effect`` to base64 MIDI and return the result as base64.

        Only "tempo" is implemented: event times are scaled by a factor from
        0.8 (intensity 0) to 1.2 (intensity 1). Other effect names pass the
        data through unchanged, but it is still stored as a new file.
        """
        midi = self._decode(midi_data)
        if effect == "tempo":
            factor = 1 + (intensity - 0.5) * 0.4
            for track in midi.tracks:
                for event in track.events:
                    event.time = int(event.time * factor)
        return self._encode_and_store(midi)

    def play_with_loop(self, midi_data):
        """Play base64 MIDI repeatedly until ``stop_playback`` is called.

        Blocks the calling thread. The stop flag is only checked between full
        playbacks. Returns a status string for the UI.
        """
        if not midi_data:
            return "Nothing to play"
        midi_file = self._decode(midi_data)
        self.is_playing = True
        try:
            while self.is_playing:
                self.synthesizer.play_midi(midi_file)
        finally:
            # Reset even if playback raises, so state doesn't report "playing".
            self.is_playing = False
        return "Stopped"

    def stop_playback(self):
        """Ask the loop in ``play_with_loop`` to exit after the current pass."""
        self.is_playing = False
        return "Stopping..."

    def save_live_midi(self):
        """Write the recorded live notes to a new MIDI file and clear the buffer.

        Notes are spaced ``NOTE_DURATION`` apart (recorded timing is not kept).
        Returns the base64 string (also stored in ``modified_files``), or None
        when nothing was recorded.
        """
        # Swap the buffer out first: notes the input callback appends while the
        # file is being written go into the next recording instead of being lost.
        notes, self.live_notes = self.live_notes, []
        if not notes:
            return None
        midi = MIDI.MIDIFile(1)
        midi.addTrack()
        for index, (note, vel, _) in enumerate(notes):
            midi.addNote(0, 0, note, index * self.NOTE_DURATION, self.NOTE_DURATION, vel)
        return self._encode_and_store(midi)

    @staticmethod
    def _decode(midi_data):
        """Parse a base64-encoded MIDI string into a MIDI object."""
        return MIDI.load(io.BytesIO(base64.b64decode(midi_data)))

    def _encode_and_store(self, midi):
        """Serialize ``midi`` to base64, record it in ``modified_files``, return it."""
        output = io.BytesIO()
        midi.writeFile(output)
        midi_data = base64.b64encode(output.getvalue()).decode('utf-8')
        self.modified_files.append(midi_data)
        return midi_data
# Single module-level manager shared by every UI callback below. Constructing
# it downloads the soundfont from the Hugging Face Hub and sets up rtmidi
# MIDI input, so this runs at import time.
midi_manager = MIDIManager()
def create_download_list(files=None):
    """Render the Downloads tab as HTML with one download link per MIDI file.

    Args:
        files: Base64-encoded MIDI strings to list. Defaults to every file
            generated so far (``midi_manager.modified_files``).

    Returns:
        An HTML string whose links embed each file as a ``data:`` URI.
    """
    if files is None:
        files = midi_manager.modified_files
    # join() instead of repeated += keeps this linear in the number of files.
    items = "".join(
        f'<li><a href="data:audio/midi;base64,{data}" download="midi_{i}.mid">MIDI {i}</a></li>'
        for i, data in enumerate(files)
    )
    return f"<h3>Downloads</h3><ul>{items}</ul>"
with gr.Blocks(theme=gr.themes.Soft()) as app:
    gr.Markdown("<h1>🎵 MIDI Composer 🎵</h1>")
    # Base64 string of the most recently generated/modified MIDI. Kept in State
    # so Generate, Apply Effect and Play Loop all exchange the same encoding.
    current_midi = gr.State(None)
    # midi_id -> uploaded file path for the current upload batch.
    midi_list = gr.State({})

    with gr.Tabs():
        # Tab 1: Load MIDI Files
        with gr.Tab("Load MIDI"):
            midi_files = gr.File(label="Upload MIDI Files", file_count="multiple")
            file_display = gr.HTML(value="No files loaded")
        # Tab 2: Generate & Perform
        with gr.Tab("Generate & Perform"):
            midi_select = gr.Dropdown(label="Select MIDI", choices=[])
            length_factor = gr.Slider(1, 5, value=2, step=1, label="Length Factor")
            variation = gr.Slider(0, 1, value=0.3, label="Variation")
            generate_btn = gr.Button("Generate")
            effect = gr.Radio(["tempo"], label="Effect", value="tempo")
            intensity = gr.Slider(0, 1, value=0.5, label="Intensity")
            apply_btn = gr.Button("Apply Effect")
            play_btn = gr.Button("Play Loop")
            stop_btn = gr.Button("Stop")
            # gr.Audio only accepts type="numpy"/"filepath" and cannot render
            # MIDI, so results are offered as downloadable .mid files instead.
            output = gr.File(label="Preview")
            status = gr.Textbox(label="Status", value="Ready")
        # Tab 3: MIDI Input
        with gr.Tab("MIDI Input"):
            gr.Markdown("Play your MIDI keyboard to record notes")
            save_btn = gr.Button("Save Live MIDI")
            live_output = gr.File(label="Live MIDI")
        # Tab 4: Downloads
        with gr.Tab("Downloads"):
            downloads = gr.HTML(value="No files yet")

    def _to_midi_file(midi_data):
        """Write base64-encoded MIDI to a temporary .mid file and return its path.

        Returns None when there is no data to write.
        """
        if not midi_data:
            return None
        with tempfile.NamedTemporaryFile(suffix=".mid", delete=False) as handle:
            handle.write(base64.b64decode(midi_data))
        return handle.name

    def update_dropdown(midi_list):
        """Offer the loaded MIDI ids as dropdown choices."""
        return gr.update(choices=list(midi_list.keys()))

    def load_files(files):
        """Load each uploaded file; return the id map, an HTML listing, and dropdown choices."""
        midi_list_val = {}
        html = "<h3>Loaded Files</h3>"
        for file in files or []:
            # Upload values may be path strings or objects with a .name path.
            path = getattr(file, "name", file)
            midi_id = midi_manager.load_midi(path)
            midi_list_val[midi_id] = path
            html += f"<div>{path}</div>"
        return midi_list_val, html, update_dropdown(midi_list_val)

    def generate(midi_id, length, var):
        """Generate a variation; return (state, preview file, status)."""
        if not midi_id:
            return None, None, "Select a MIDI file"
        midi_data = midi_manager.generate_variation(midi_id, length, var)
        if midi_data is None:
            return None, None, "Selected MIDI is not loaded"
        return midi_data, _to_midi_file(midi_data), "Generated"

    def apply_effect(midi_data, fx, inten):
        """Apply an effect to the current MIDI; return (state, preview file, status)."""
        if not midi_data:
            return None, None, "Generate a MIDI first"
        new_data = midi_manager.apply_synth_effect(midi_data, fx, inten)
        return new_data, _to_midi_file(new_data), "Effect Applied"

    def play(midi_data):
        """Loop the current MIDI until stopped; returns a status string."""
        if not midi_data:
            return "Generate a MIDI first"
        return midi_manager.play_with_loop(midi_data)

    def save_live():
        """Save recorded keyboard notes and return a .mid path (or None)."""
        return _to_midi_file(midi_manager.save_live_midi())

    def update_downloads(*args):
        return create_download_list()

    midi_files.change(load_files, inputs=[midi_files],
                      outputs=[midi_list, file_display, midi_select])
    # .then() runs after the handler has finished, so the refreshed download
    # list already contains the file that handler just stored.
    generate_btn.click(generate, inputs=[midi_select, length_factor, variation],
                       outputs=[current_midi, output, status]).then(
        update_downloads, inputs=None, outputs=[downloads])
    apply_btn.click(apply_effect, inputs=[current_midi, effect, intensity],
                    outputs=[current_midi, output, status]).then(
        update_downloads, inputs=None, outputs=[downloads])
    play_btn.click(play, inputs=[current_midi], outputs=[status])
    stop_btn.click(midi_manager.stop_playback, inputs=None, outputs=[status])
    save_btn.click(save_live, inputs=None, outputs=[live_output]).then(
        update_downloads, inputs=None, outputs=[downloads])

    gr.Markdown("""
<div style='text-align: center; margin-top: 20px;'>
<img src='https://huggingface.co/front/assets/huggingface_logo-noborder.svg' alt='Hugging Face Logo' style='width: 50px;'><br>
<strong>Hugging Face</strong><br>
<a href='https://huggingface.co/models'>Models</a> |
<a href='https://huggingface.co/datasets'>Datasets</a> |
<a href='https://huggingface.co/spaces'>Spaces</a> |
<a href='https://huggingface.co/posts'>Posts</a> |
<a href='https://huggingface.co/docs'>Docs</a> |
<a href='https://huggingface.co/enterprise'>Enterprise</a> |
<a href='https://huggingface.co/pricing'>Pricing</a>
</div>
""")
app.queue().launch(inbrowser=True)