awacke1 commited on
Commit
dbc6fc5
·
verified ·
1 Parent(s): 92acab8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +99 -85
app.py CHANGED
@@ -1,12 +1,12 @@
1
  import random
2
  import gradio as gr
3
  import numpy as np
4
- from midi_model import MIDIModel, MIDIModelConfig
5
- from midi_synthesizer import MidiSynthesizer
6
  import MIDI
7
  import base64
8
  import io
9
  from huggingface_hub import hf_hub_download
 
10
 
11
  MAX_SEED = np.iinfo(np.int32).max
12
 
@@ -14,39 +14,44 @@ class MIDIManager:
14
  def __init__(self):
15
  self.soundfont_path = hf_hub_download(repo_id="skytnt/midi-model", filename="soundfont.sf2")
16
  self.synthesizer = MidiSynthesizer(self.soundfont_path)
17
- self.loaded_midi = {} # Store multiple MIDI files
18
- self.modified_files = [] # Track generated files for download
19
  self.is_playing = False
 
 
 
 
 
 
 
 
 
 
 
20
 
21
  def load_midi(self, file_path):
22
  midi = MIDI.load(file_path)
23
  midi_id = f"midi_{len(self.loaded_midi)}"
24
  self.loaded_midi[midi_id] = midi
25
- return midi_id, self.extract_notes_and_instruments(midi)
26
 
27
- def extract_notes_and_instruments(self, midi):
28
  notes = []
29
- instruments = set()
30
  for track in midi.tracks:
31
  for event in track.events:
32
  if event.type == 'note_on' and event.velocity > 0:
33
  notes.append((event.note, event.velocity, event.time))
34
- if hasattr(event, 'program'):
35
- instruments.add(event.program)
36
- return notes, list(instruments)
37
 
38
- def generate_variation(self, midi_id, length_factor=2, variation_level=0.3):
39
  if midi_id not in self.loaded_midi:
40
  return None
41
- original = self.loaded_midi[midi_id]
42
- notes, instruments = self.extract_notes_and_instruments(original)
43
-
44
- # Generate longer sequence
45
  new_notes = []
46
  for _ in range(int(length_factor)):
47
  for note, vel, time in notes:
48
- if random.random() < variation_level:
49
- new_note = note + random.randint(-2, 2)
50
  new_vel = min(127, max(0, vel + random.randint(-10, 10)))
51
  new_notes.append((new_note, new_vel, time))
52
  else:
@@ -63,21 +68,13 @@ class MIDIManager:
63
  self.modified_files.append(midi_data)
64
  return midi_data
65
 
66
- def apply_synth_effect(self, midi_id, effect_type, intensity):
67
- if midi_id not in self.loaded_midi:
68
- return None
69
- midi = self.loaded_midi[midi_id].copy() # Work on a copy
70
- if effect_type == "tempo":
71
- factor = 1 + (intensity - 0.5) * 0.4 # -20% to +20%
72
  for track in midi.tracks:
73
  for event in track.events:
74
  event.time = int(event.time * factor)
75
- elif effect_type == "pitch":
76
- shift = int((intensity - 0.5) * 12) # -6 to +6 semitones
77
- for track in midi.tracks:
78
- for event in track.events:
79
- if hasattr(event, 'note'):
80
- event.note = min(127, max(0, event.note + shift))
81
  output = io.BytesIO()
82
  midi.writeFile(output)
83
  midi_data = base64.b64encode(output.getvalue()).decode('utf-8')
@@ -89,99 +86,116 @@ class MIDIManager:
89
  midi_file = MIDI.load(io.BytesIO(base64.b64decode(midi_data)))
90
  while self.is_playing:
91
  self.synthesizer.play_midi(midi_file)
92
- return "Playback stopped"
93
 
94
  def stop_playback(self):
95
  self.is_playing = False
96
- return "Stopping playback..."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
 
98
  midi_manager = MIDIManager()
99
 
100
- def create_download_list(modified_files):
101
- html = "<h3>Generated MIDI Files</h3><ul>"
102
- for i, data in enumerate(modified_files):
103
- html += f'<li><a href="data:audio/midi;base64,{data}" download="generated_midi_{i}.mid">Download MIDI {i}</a></li>'
104
  html += "</ul>"
105
  return html
106
 
107
  with gr.Blocks(theme=gr.themes.Soft()) as app:
108
- gr.Markdown("<h1>🎵 MIDI Sequence Generator & Performer 🎵</h1>")
109
 
110
  with gr.Tabs():
111
- # Tab 1: MIDI Upload
112
- with gr.Tab("Upload MIDI"):
113
  midi_files = gr.File(label="Upload MIDI Files", file_count="multiple")
114
  midi_list = gr.State({})
115
  file_display = gr.HTML(value="No files loaded")
116
 
117
- def handle_upload(files):
118
  midi_list_val = {}
119
- html = "<h3>Loaded MIDI Files</h3>"
120
- for file in files:
121
- midi_id, (notes, instruments) = midi_manager.load_midi(file.name)
122
  midi_list_val[midi_id] = file.name
123
- html += f'<div>{file.name} <button onclick="remove_midi(\'{midi_id}\')">X</button></div>'
124
  return midi_list_val, html
125
 
126
- midi_files.change(handle_upload, inputs=[midi_files], outputs=[midi_list, file_display])
127
 
128
- # Tab 2: Generate Variations
129
- with gr.Tab("Generate"):
130
  midi_select = gr.Dropdown(label="Select MIDI", choices=[])
131
- length_factor = gr.Slider(1, 10, value=2, step=1, label="Length Multiplier")
132
- variation_level = gr.Slider(0, 1, value=0.3, label="Variation Level")
133
- generate_btn = gr.Button("Generate Variation")
134
- generated_output = gr.Audio(label="Generated Preview", type="bytes")
 
 
 
 
 
 
135
 
136
  def update_dropdown(midi_list):
137
  return gr.update(choices=list(midi_list.keys()))
138
 
139
  midi_list.change(update_dropdown, inputs=[midi_list], outputs=[midi_select])
140
 
141
- def generate(midi_id, length, variation):
142
  if not midi_id:
143
- return None
144
- midi_data = midi_manager.generate_variation(midi_id, length, variation)
145
- return io.BytesIO(base64.b64decode(midi_data))
146
-
147
- generate_btn.click(generate, inputs=[midi_select, length_factor, variation_level],
148
- outputs=[generated_output])
149
-
150
- # Tab 3: Synthesizer Controls
151
- with gr.Tab("Perform"):
152
- midi_play_select = gr.Dropdown(label="Select MIDI to Play", choices=[])
153
- synth_effects = gr.Radio(["tempo", "pitch"], label="Synth Effect", value="tempo")
154
- effect_intensity = gr.Slider(0, 1, value=0.5, label="Effect Intensity")
155
- apply_effect_btn = gr.Button("Apply Effect")
156
- play_btn = gr.Button("Play with Auto-Loop")
157
- stop_btn = gr.Button("Stop")
158
- playback_status = gr.Textbox(label="Playback Status", value="Stopped")
159
 
160
- midi_list.change(update_dropdown, inputs=[midi_list], outputs=[midi_play_select])
 
 
 
 
161
 
162
- def apply_and_preview(midi_id, effect, intensity):
163
- if not midi_id:
164
- return None, "No MIDI selected"
165
- midi_data = midi_manager.apply_synth_effect(midi_id, effect, intensity)
166
- return io.BytesIO(base64.b64decode(midi_data)), "Effect applied"
 
 
 
 
 
 
 
167
 
168
- apply_effect_btn.click(apply_and_preview,
169
- inputs=[midi_play_select, synth_effects, effect_intensity],
170
- outputs=[generated_output, playback_status])
171
 
172
- play_btn.click(midi_manager.play_with_loop, inputs=[generated_output],
173
- outputs=[playback_status])
174
- stop_btn.click(midi_manager.stop_playback, inputs=None, outputs=[playback_status])
175
 
176
  # Tab 4: Downloads
177
  with gr.Tab("Downloads"):
178
- download_list = gr.HTML(value="No generated files yet")
179
- def update_downloads(_):
180
- return create_download_list(midi_manager.modified_files)
181
- gr.on(triggers=[generate_btn.click, apply_effect_btn.click],
182
- fn=update_downloads, inputs=None, outputs=[download_list])
183
 
184
- # Hugging Face Branding
185
  gr.Markdown("""
186
  <div style='text-align: center; margin-top: 20px;'>
187
  <img src='https://huggingface.co/front/assets/huggingface_logo-noborder.svg' alt='Hugging Face Logo' style='width: 50px;'><br>
 
1
  import random
2
  import gradio as gr
3
  import numpy as np
4
+ import rtmidi
 
5
  import MIDI
6
  import base64
7
  import io
8
  from huggingface_hub import hf_hub_download
9
+ from midi_synthesizer import MidiSynthesizer
10
 
11
  MAX_SEED = np.iinfo(np.int32).max
12
 
 
14
  def __init__(self):
15
  self.soundfont_path = hf_hub_download(repo_id="skytnt/midi-model", filename="soundfont.sf2")
16
  self.synthesizer = MidiSynthesizer(self.soundfont_path)
17
+ self.loaded_midi = {} # Store uploaded MIDI files
18
+ self.modified_files = [] # Track generated files
19
  self.is_playing = False
20
+ self.midi_in = rtmidi.MidiIn()
21
+ self.midi_in.open_port(0) if self.midi_in.get_ports() else None
22
+ self.midi_in.set_callback(self.midi_callback)
23
+ self.live_notes = []
24
+
25
+ def midi_callback(self, event, data=None):
26
+ message, _ = event
27
+ if len(message) >= 3 and message[0] & 0xF0 == 0x90: # Note On
28
+ note, velocity = message[1], message[2]
29
+ if velocity > 0:
30
+ self.live_notes.append((note, velocity, 0)) # Time placeholder
31
 
32
  def load_midi(self, file_path):
33
  midi = MIDI.load(file_path)
34
  midi_id = f"midi_{len(self.loaded_midi)}"
35
  self.loaded_midi[midi_id] = midi
36
+ return midi_id
37
 
38
+ def extract_notes(self, midi):
39
  notes = []
 
40
  for track in midi.tracks:
41
  for event in track.events:
42
  if event.type == 'note_on' and event.velocity > 0:
43
  notes.append((event.note, event.velocity, event.time))
44
+ return notes
 
 
45
 
46
+ def generate_variation(self, midi_id, length_factor=2, variation=0.3):
47
  if midi_id not in self.loaded_midi:
48
  return None
49
+ notes = self.extract_notes(self.loaded_midi[midi_id])
 
 
 
50
  new_notes = []
51
  for _ in range(int(length_factor)):
52
  for note, vel, time in notes:
53
+ if random.random() < variation:
54
+ new_note = min(127, max(0, note + random.randint(-2, 2)))
55
  new_vel = min(127, max(0, vel + random.randint(-10, 10)))
56
  new_notes.append((new_note, new_vel, time))
57
  else:
 
68
  self.modified_files.append(midi_data)
69
  return midi_data
70
 
71
+ def apply_synth_effect(self, midi_data, effect, intensity):
72
+ midi = MIDI.load(io.BytesIO(base64.b64decode(midi_data)))
73
+ if effect == "tempo":
74
+ factor = 1 + (intensity - 0.5) * 0.4
 
 
75
  for track in midi.tracks:
76
  for event in track.events:
77
  event.time = int(event.time * factor)
 
 
 
 
 
 
78
  output = io.BytesIO()
79
  midi.writeFile(output)
80
  midi_data = base64.b64encode(output.getvalue()).decode('utf-8')
 
86
  midi_file = MIDI.load(io.BytesIO(base64.b64decode(midi_data)))
87
  while self.is_playing:
88
  self.synthesizer.play_midi(midi_file)
89
+ return "Stopped"
90
 
91
  def stop_playback(self):
92
  self.is_playing = False
93
+ return "Stopping..."
94
+
95
+ def save_live_midi(self):
96
+ if not self.live_notes:
97
+ return None
98
+ midi = MIDI.MIDIFile(1)
99
+ midi.addTrack()
100
+ time_cum = 0
101
+ for note, vel, _ in self.live_notes:
102
+ midi.addNote(0, 0, note, time_cum, 100, vel)
103
+ time_cum += 100 # Simple timing
104
+ output = io.BytesIO()
105
+ midi.writeFile(output)
106
+ midi_data = base64.b64encode(output.getvalue()).decode('utf-8')
107
+ self.modified_files.append(midi_data)
108
+ self.live_notes = [] # Reset after saving
109
+ return midi_data
110
 
111
  midi_manager = MIDIManager()
112
 
113
+ def create_download_list():
114
+ html = "<h3>Downloads</h3><ul>"
115
+ for i, data in enumerate(midi_manager.modified_files):
116
+ html += f'<li><a href="data:audio/midi;base64,{data}" download="midi_{i}.mid">MIDI {i}</a></li>'
117
  html += "</ul>"
118
  return html
119
 
120
  with gr.Blocks(theme=gr.themes.Soft()) as app:
121
+ gr.Markdown("<h1>🎵 MIDI Composer 🎵</h1>")
122
 
123
  with gr.Tabs():
124
+ # Tab 1: Load MIDI Files
125
+ with gr.Tab("Load MIDI"):
126
  midi_files = gr.File(label="Upload MIDI Files", file_count="multiple")
127
  midi_list = gr.State({})
128
  file_display = gr.HTML(value="No files loaded")
129
 
130
+ def load_files(files):
131
  midi_list_val = {}
132
+ html = "<h3>Loaded Files</h3>"
133
+ for file in files or []:
134
+ midi_id = midi_manager.load_midi(file.name)
135
  midi_list_val[midi_id] = file.name
136
+ html += f"<div>{file.name}</div>"
137
  return midi_list_val, html
138
 
139
+ midi_files.change(load_files, inputs=[midi_files], outputs=[midi_list, file_display])
140
 
141
+ # Tab 2: Generate & Perform
142
+ with gr.Tab("Generate & Perform"):
143
  midi_select = gr.Dropdown(label="Select MIDI", choices=[])
144
+ length_factor = gr.Slider(1, 5, value=2, step=1, label="Length Factor")
145
+ variation = gr.Slider(0, 1, value=0.3, label="Variation")
146
+ generate_btn = gr.Button("Generate")
147
+ effect = gr.Radio(["tempo"], label="Effect", value="tempo")
148
+ intensity = gr.Slider(0, 1, value=0.5, label="Intensity")
149
+ apply_btn = gr.Button("Apply Effect")
150
+ play_btn = gr.Button("Play Loop")
151
+ stop_btn = gr.Button("Stop")
152
+ output = gr.Audio(label="Preview", type="bytes")
153
+ status = gr.Textbox(label="Status", value="Ready")
154
 
155
  def update_dropdown(midi_list):
156
  return gr.update(choices=list(midi_list.keys()))
157
 
158
  midi_list.change(update_dropdown, inputs=[midi_list], outputs=[midi_select])
159
 
160
+ def generate(midi_id, length, var):
161
  if not midi_id:
162
+ return None, "Select a MIDI file"
163
+ midi_data = midi_manager.generate_variation(midi_id, length, var)
164
+ return io.BytesIO(base64.b64decode(midi_data)), "Generated"
 
 
 
 
 
 
 
 
 
 
 
 
 
165
 
166
+ def apply_effect(midi_data, fx, inten):
167
+ if not midi_data:
168
+ return None, "Generate a MIDI first"
169
+ new_data = midi_manager.apply_synth_effect(midi_data.decode('utf-8'), fx, inten)
170
+ return io.BytesIO(base64.b64decode(new_data)), "Effect Applied"
171
 
172
+ generate_btn.click(generate, inputs=[midi_select, length_factor, variation],
173
+ outputs=[output, status])
174
+ apply_btn.click(apply_effect, inputs=[output, effect, intensity],
175
+ outputs=[output, status])
176
+ play_btn.click(midi_manager.play_with_loop, inputs=[output], outputs=[status])
177
+ stop_btn.click(midi_manager.stop_playback, inputs=None, outputs=[status])
178
+
179
+ # Tab 3: MIDI Input
180
+ with gr.Tab("MIDI Input"):
181
+ gr.Markdown("Play your MIDI keyboard to record notes")
182
+ save_btn = gr.Button("Save Live MIDI")
183
+ live_output = gr.Audio(label="Live MIDI", type="bytes")
184
 
185
+ def save_live():
186
+ midi_data = midi_manager.save_live_midi()
187
+ return io.BytesIO(base64.b64decode(midi_data)) if midi_data else None
188
 
189
+ save_btn.click(save_live, inputs=None, outputs=[live_output])
 
 
190
 
191
  # Tab 4: Downloads
192
  with gr.Tab("Downloads"):
193
+ downloads = gr.HTML(value="No files yet")
194
+ def update_downloads(*args):
195
+ return create_download_list()
196
+ gr.on(triggers=[generate_btn.click, apply_btn.click, save_btn.click],
197
+ fn=update_downloads, inputs=None, outputs=[downloads])
198
 
 
199
  gr.Markdown("""
200
  <div style='text-align: center; margin-top: 20px;'>
201
  <img src='https://huggingface.co/front/assets/huggingface_logo-noborder.svg' alt='Hugging Face Logo' style='width: 50px;'><br>