awacke1 commited on
Commit
92acab8
·
verified ·
1 Parent(s): e1c0e3a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +199 -0
app.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import gradio as gr
3
+ import numpy as np
4
+ from midi_model import MIDIModel, MIDIModelConfig
5
+ from midi_synthesizer import MidiSynthesizer
6
+ import MIDI
7
+ import base64
8
+ import io
9
+ from huggingface_hub import hf_hub_download
10
+
11
+ MAX_SEED = np.iinfo(np.int32).max
12
+
13
+ class MIDIManager:
14
+ def __init__(self):
15
+ self.soundfont_path = hf_hub_download(repo_id="skytnt/midi-model", filename="soundfont.sf2")
16
+ self.synthesizer = MidiSynthesizer(self.soundfont_path)
17
+ self.loaded_midi = {} # Store multiple MIDI files
18
+ self.modified_files = [] # Track generated files for download
19
+ self.is_playing = False
20
+
21
+ def load_midi(self, file_path):
22
+ midi = MIDI.load(file_path)
23
+ midi_id = f"midi_{len(self.loaded_midi)}"
24
+ self.loaded_midi[midi_id] = midi
25
+ return midi_id, self.extract_notes_and_instruments(midi)
26
+
27
+ def extract_notes_and_instruments(self, midi):
28
+ notes = []
29
+ instruments = set()
30
+ for track in midi.tracks:
31
+ for event in track.events:
32
+ if event.type == 'note_on' and event.velocity > 0:
33
+ notes.append((event.note, event.velocity, event.time))
34
+ if hasattr(event, 'program'):
35
+ instruments.add(event.program)
36
+ return notes, list(instruments)
37
+
38
+ def generate_variation(self, midi_id, length_factor=2, variation_level=0.3):
39
+ if midi_id not in self.loaded_midi:
40
+ return None
41
+ original = self.loaded_midi[midi_id]
42
+ notes, instruments = self.extract_notes_and_instruments(original)
43
+
44
+ # Generate longer sequence
45
+ new_notes = []
46
+ for _ in range(int(length_factor)):
47
+ for note, vel, time in notes:
48
+ if random.random() < variation_level:
49
+ new_note = note + random.randint(-2, 2)
50
+ new_vel = min(127, max(0, vel + random.randint(-10, 10)))
51
+ new_notes.append((new_note, new_vel, time))
52
+ else:
53
+ new_notes.append((note, vel, time))
54
+
55
+ new_midi = MIDI.MIDIFile(1)
56
+ new_midi.addTrack()
57
+ for note, vel, time in new_notes:
58
+ new_midi.addNote(0, 0, note, time, 100, vel)
59
+
60
+ output = io.BytesIO()
61
+ new_midi.writeFile(output)
62
+ midi_data = base64.b64encode(output.getvalue()).decode('utf-8')
63
+ self.modified_files.append(midi_data)
64
+ return midi_data
65
+
66
+ def apply_synth_effect(self, midi_id, effect_type, intensity):
67
+ if midi_id not in self.loaded_midi:
68
+ return None
69
+ midi = self.loaded_midi[midi_id].copy() # Work on a copy
70
+ if effect_type == "tempo":
71
+ factor = 1 + (intensity - 0.5) * 0.4 # -20% to +20%
72
+ for track in midi.tracks:
73
+ for event in track.events:
74
+ event.time = int(event.time * factor)
75
+ elif effect_type == "pitch":
76
+ shift = int((intensity - 0.5) * 12) # -6 to +6 semitones
77
+ for track in midi.tracks:
78
+ for event in track.events:
79
+ if hasattr(event, 'note'):
80
+ event.note = min(127, max(0, event.note + shift))
81
+ output = io.BytesIO()
82
+ midi.writeFile(output)
83
+ midi_data = base64.b64encode(output.getvalue()).decode('utf-8')
84
+ self.modified_files.append(midi_data)
85
+ return midi_data
86
+
87
+ def play_with_loop(self, midi_data):
88
+ self.is_playing = True
89
+ midi_file = MIDI.load(io.BytesIO(base64.b64decode(midi_data)))
90
+ while self.is_playing:
91
+ self.synthesizer.play_midi(midi_file)
92
+ return "Playback stopped"
93
+
94
+ def stop_playback(self):
95
+ self.is_playing = False
96
+ return "Stopping playback..."
97
+
98
+ midi_manager = MIDIManager()
99
+
100
+ def create_download_list(modified_files):
101
+ html = "<h3>Generated MIDI Files</h3><ul>"
102
+ for i, data in enumerate(modified_files):
103
+ html += f'<li><a href="data:audio/midi;base64,{data}" download="generated_midi_{i}.mid">Download MIDI {i}</a></li>'
104
+ html += "</ul>"
105
+ return html
106
+
107
+ with gr.Blocks(theme=gr.themes.Soft()) as app:
108
+ gr.Markdown("<h1>🎵 MIDI Sequence Generator & Performer 🎵</h1>")
109
+
110
+ with gr.Tabs():
111
+ # Tab 1: MIDI Upload
112
+ with gr.Tab("Upload MIDI"):
113
+ midi_files = gr.File(label="Upload MIDI Files", file_count="multiple")
114
+ midi_list = gr.State({})
115
+ file_display = gr.HTML(value="No files loaded")
116
+
117
+ def handle_upload(files):
118
+ midi_list_val = {}
119
+ html = "<h3>Loaded MIDI Files</h3>"
120
+ for file in files:
121
+ midi_id, (notes, instruments) = midi_manager.load_midi(file.name)
122
+ midi_list_val[midi_id] = file.name
123
+ html += f'<div>{file.name} <button onclick="remove_midi(\'{midi_id}\')">X</button></div>'
124
+ return midi_list_val, html
125
+
126
+ midi_files.change(handle_upload, inputs=[midi_files], outputs=[midi_list, file_display])
127
+
128
+ # Tab 2: Generate Variations
129
+ with gr.Tab("Generate"):
130
+ midi_select = gr.Dropdown(label="Select MIDI", choices=[])
131
+ length_factor = gr.Slider(1, 10, value=2, step=1, label="Length Multiplier")
132
+ variation_level = gr.Slider(0, 1, value=0.3, label="Variation Level")
133
+ generate_btn = gr.Button("Generate Variation")
134
+ generated_output = gr.Audio(label="Generated Preview", type="bytes")
135
+
136
+ def update_dropdown(midi_list):
137
+ return gr.update(choices=list(midi_list.keys()))
138
+
139
+ midi_list.change(update_dropdown, inputs=[midi_list], outputs=[midi_select])
140
+
141
+ def generate(midi_id, length, variation):
142
+ if not midi_id:
143
+ return None
144
+ midi_data = midi_manager.generate_variation(midi_id, length, variation)
145
+ return io.BytesIO(base64.b64decode(midi_data))
146
+
147
+ generate_btn.click(generate, inputs=[midi_select, length_factor, variation_level],
148
+ outputs=[generated_output])
149
+
150
+ # Tab 3: Synthesizer Controls
151
+ with gr.Tab("Perform"):
152
+ midi_play_select = gr.Dropdown(label="Select MIDI to Play", choices=[])
153
+ synth_effects = gr.Radio(["tempo", "pitch"], label="Synth Effect", value="tempo")
154
+ effect_intensity = gr.Slider(0, 1, value=0.5, label="Effect Intensity")
155
+ apply_effect_btn = gr.Button("Apply Effect")
156
+ play_btn = gr.Button("Play with Auto-Loop")
157
+ stop_btn = gr.Button("Stop")
158
+ playback_status = gr.Textbox(label="Playback Status", value="Stopped")
159
+
160
+ midi_list.change(update_dropdown, inputs=[midi_list], outputs=[midi_play_select])
161
+
162
+ def apply_and_preview(midi_id, effect, intensity):
163
+ if not midi_id:
164
+ return None, "No MIDI selected"
165
+ midi_data = midi_manager.apply_synth_effect(midi_id, effect, intensity)
166
+ return io.BytesIO(base64.b64decode(midi_data)), "Effect applied"
167
+
168
+ apply_effect_btn.click(apply_and_preview,
169
+ inputs=[midi_play_select, synth_effects, effect_intensity],
170
+ outputs=[generated_output, playback_status])
171
+
172
+ play_btn.click(midi_manager.play_with_loop, inputs=[generated_output],
173
+ outputs=[playback_status])
174
+ stop_btn.click(midi_manager.stop_playback, inputs=None, outputs=[playback_status])
175
+
176
+ # Tab 4: Downloads
177
+ with gr.Tab("Downloads"):
178
+ download_list = gr.HTML(value="No generated files yet")
179
+ def update_downloads(_):
180
+ return create_download_list(midi_manager.modified_files)
181
+ gr.on(triggers=[generate_btn.click, apply_effect_btn.click],
182
+ fn=update_downloads, inputs=None, outputs=[download_list])
183
+
184
+ # Hugging Face Branding
185
+ gr.Markdown("""
186
+ <div style='text-align: center; margin-top: 20px;'>
187
+ <img src='https://huggingface.co/front/assets/huggingface_logo-noborder.svg' alt='Hugging Face Logo' style='width: 50px;'><br>
188
+ <strong>Hugging Face</strong><br>
189
+ <a href='https://huggingface.co/models'>Models</a> |
190
+ <a href='https://huggingface.co/datasets'>Datasets</a> |
191
+ <a href='https://huggingface.co/spaces'>Spaces</a> |
192
+ <a href='https://huggingface.co/posts'>Posts</a> |
193
+ <a href='https://huggingface.co/docs'>Docs</a> |
194
+ <a href='https://huggingface.co/enterprise'>Enterprise</a> |
195
+ <a href='https://huggingface.co/pricing'>Pricing</a>
196
+ </div>
197
+ """)
198
+
199
+ app.queue().launch(inbrowser=True)