admin committed
Commit 1aa8b04 · 1 Parent(s): f2b9c94
Files changed (4)
  1. app.py +193 -439
  2. generate.py +290 -0
  3. requirements.txt +0 -1
  4. utils.py +8 -3
app.py CHANGED
@@ -1,296 +1,16 @@
1
- import re
2
  import os
3
  import json
4
- import time
5
- import torch
6
- import random
7
  import shutil
8
  import argparse
9
  import warnings
10
  import gradio as gr
11
- import soundfile as sf
12
- from transformers import GPT2Config
13
- from model import Patchilizer, TunesFormer
14
- from convert import abc2xml, xml2img, xml2, transpose_octaves_abc
15
- from utils import (
16
- PATCH_NUM_LAYERS,
17
- PATCH_LENGTH,
18
- CHAR_NUM_LAYERS,
19
- PATCH_SIZE,
20
- SHARE_WEIGHTS,
21
- TEMP_DIR,
22
- WEIGHTS_DIR,
23
- DEVICE,
24
- )
25
-
26
-
27
- def get_args(parser: argparse.ArgumentParser):
28
- parser.add_argument(
29
- "-num_tunes",
30
- type=int,
31
- default=1,
32
- help="the number of independently computed returned tunes",
33
- )
34
- parser.add_argument(
35
- "-max_patch",
36
- type=int,
37
- default=128,
38
- help="integer to define the maximum length in tokens of each tune",
39
- )
40
- parser.add_argument(
41
- "-top_p",
42
- type=float,
43
- default=0.8,
44
- help="float to define the tokens that are within the sample operation of text generation",
45
- )
46
- parser.add_argument(
47
- "-top_k",
48
- type=int,
49
- default=8,
50
- help="integer to define the tokens that are within the sample operation of text generation",
51
- )
52
- parser.add_argument(
53
- "-temperature",
54
- type=float,
55
- default=1.2,
56
- help="the temperature of the sampling operation",
57
- )
58
- parser.add_argument("-seed", type=int, default=None, help="seed for randomstate")
59
- parser.add_argument(
60
- "-show_control_code",
61
- type=bool,
62
- default=False,
63
- help="whether to show control code",
64
- )
65
- parser.add_argument(
66
- "-template",
67
- type=bool,
68
- default=True,
69
- help="whether to generate by template",
70
- )
71
- return parser.parse_args()
72
-
73
-
74
- def get_abc_key_val(text: str, key="K"):
75
- pattern = re.escape(key) + r":(.*?)\n"
76
- match = re.search(pattern, text)
77
- if match:
78
- return match.group(1).strip()
79
- else:
80
- return None
81
-
82
-
83
- def adjust_volume(in_audio: str, dB_change: int):
84
- y, sr = sf.read(in_audio)
85
- sf.write(in_audio, y * 10 ** (dB_change / 20), sr)
86
-
87
-
88
- def generate_music(
89
- args,
90
- emo: str,
91
- weights: str,
92
- outdir=TEMP_DIR,
93
- fix_tempo=None,
94
- fix_pitch=None,
95
- fix_volume=None,
96
- ):
97
- patchilizer = Patchilizer()
98
- patch_config = GPT2Config(
99
- num_hidden_layers=PATCH_NUM_LAYERS,
100
- max_length=PATCH_LENGTH,
101
- max_position_embeddings=PATCH_LENGTH,
102
- vocab_size=1,
103
- )
104
- char_config = GPT2Config(
105
- num_hidden_layers=CHAR_NUM_LAYERS,
106
- max_length=PATCH_SIZE,
107
- max_position_embeddings=PATCH_SIZE,
108
- vocab_size=128,
109
- )
110
- model = TunesFormer(patch_config, char_config, share_weights=SHARE_WEIGHTS)
111
- checkpoint = torch.load(weights, map_location=DEVICE)
112
- model.load_state_dict(checkpoint["model"])
113
- model = model.to(DEVICE)
114
- model.eval()
115
- prompt = f"A:{emo}\n"
116
- tunes = ""
117
- num_tunes = args.num_tunes
118
- max_patch = args.max_patch
119
- top_p = args.top_p
120
- top_k = args.top_k
121
- temperature = args.temperature
122
- seed = args.seed
123
- show_control_code = args.show_control_code
124
- fname_prefix = emo if args.template else "Melody"
125
- print(" Hyper parms ".center(60, "#"), "\n")
126
- args_dict: dict = vars(args)
127
- for arg in args_dict.keys():
128
- print(f"{arg}: {str(args_dict[arg])}")
129
-
130
- print("\n", " Output tunes ".center(60, "#"))
131
- start_time = time.time()
132
- for i in range(num_tunes):
133
- title = f"T:{fname_prefix} Fragment\n"
134
- artist = f"C:Generated by AI\n"
135
- tune = f"X:{str(i + 1)}\n{title}{artist}{prompt}"
136
- lines = re.split(r"(\n)", tune)
137
- tune = ""
138
- skip = False
139
- for line in lines:
140
- if show_control_code or line[:2] not in ["S:", "B:", "E:", "D:"]:
141
- if not skip:
142
- print(line, end="")
143
- tune += line
144
-
145
- skip = False
146
-
147
- else:
148
- skip = True
149
-
150
- input_patches = torch.tensor(
151
- [patchilizer.encode(prompt, add_special_patches=True)[:-1]],
152
- device=DEVICE,
153
- )
154
- if tune == "":
155
- tokens = None
156
-
157
- else:
158
- prefix = patchilizer.decode(input_patches[0])
159
- remaining_tokens = prompt[len(prefix) :]
160
- tokens = torch.tensor(
161
- [patchilizer.bos_token_id] + [ord(c) for c in remaining_tokens],
162
- device=DEVICE,
163
- )
164
-
165
- while input_patches.shape[1] < max_patch:
166
- predicted_patch, seed = model.generate(
167
- input_patches,
168
- tokens,
169
- top_p=top_p,
170
- top_k=top_k,
171
- temperature=temperature,
172
- seed=seed,
173
- )
174
- tokens = None
175
- if predicted_patch[0] != patchilizer.eos_token_id:
176
- next_bar = patchilizer.decode([predicted_patch])
177
- if show_control_code or next_bar[:2] not in ["S:", "B:", "E:", "D:"]:
178
- print(next_bar, end="")
179
- tune += next_bar
180
-
181
- if next_bar == "":
182
- break
183
-
184
- next_bar = remaining_tokens + next_bar
185
- remaining_tokens = ""
186
- predicted_patch = torch.tensor(
187
- patchilizer.bar2patch(next_bar),
188
- device=DEVICE,
189
- ).unsqueeze(0)
190
- input_patches = torch.cat(
191
- [input_patches, predicted_patch.unsqueeze(0)],
192
- dim=1,
193
- )
194
-
195
- else:
196
- break
197
-
198
- tunes += f"{tune}\n\n"
199
- print("\n")
200
-
201
- # fix tempo
202
- if fix_tempo != None:
203
- tempo = f"Q:{fix_tempo}\n"
204
-
205
- else:
206
- tempo = f"Q:{random.randint(88, 132)}\n"
207
- if emo == "Q1":
208
- tempo = f"Q:{random.randint(160, 184)}\n"
209
- elif emo == "Q2":
210
- tempo = f"Q:{random.randint(184, 228)}\n"
211
- elif emo == "Q3":
212
- tempo = f"Q:{random.randint(40, 69)}\n"
213
- elif emo == "Q4":
214
- tempo = f"Q:{random.randint(40, 69)}\n"
215
-
216
- Q_val = get_abc_key_val(tunes, "Q")
217
- if Q_val:
218
- tunes = tunes.replace(f"Q:{Q_val}\n", "")
219
-
220
- K_val = get_abc_key_val(tunes)
221
- if K_val == "none":
222
- K_val = "C"
223
- tunes = tunes.replace("K:none\n", f"K:{K_val}\n")
224
-
225
- tunes = tunes.replace(f"A:{emo}\n", tempo)
226
- # fix mode:major/minor
227
- mode = "major" if emo == "Q1" or emo == "Q4" else "minor"
228
- if (mode == "major") and ("m" in K_val):
229
- tunes = tunes.replace(f"\nK:{K_val}\n", f"\nK:{K_val.split('m')[0]}\n")
230
-
231
- elif (mode == "minor") and (not "m" in K_val):
232
- tunes = tunes.replace(f"\nK:{K_val}\n", f"\nK:{K_val.replace('dor', '')}min\n")
233
-
234
- print("Generation time: {:.2f} seconds".format(time.time() - start_time))
235
- timestamp = time.strftime("%a_%d_%b_%Y_%H_%M_%S", time.localtime())
236
- try:
237
- # fix avg_pitch (octave)
238
- if fix_pitch != None:
239
- if fix_pitch:
240
- tunes, xml = transpose_octaves_abc(
241
- tunes,
242
- f"{outdir}/{timestamp}.musicxml",
243
- fix_pitch,
244
- )
245
- tunes = tunes.replace(title + title, title)
246
- os.rename(xml, f"{outdir}/[{fname_prefix}]{timestamp}.musicxml")
247
- xml = f"{outdir}/[{fname_prefix}]{timestamp}.musicxml"
248
-
249
- else:
250
- if mode == "minor":
251
- offset = -12
252
- if emo == "Q2":
253
- offset -= 12
254
-
255
- tunes, xml = transpose_octaves_abc(
256
- tunes,
257
- f"{outdir}/{timestamp}.musicxml",
258
- offset,
259
- )
260
- tunes = tunes.replace(title + title, title)
261
- os.rename(xml, f"{outdir}/[{fname_prefix}]{timestamp}.musicxml")
262
- xml = f"{outdir}/[{fname_prefix}]{timestamp}.musicxml"
263
-
264
- else:
265
- xml = abc2xml(tunes, f"{outdir}/[{fname_prefix}]{timestamp}.musicxml")
266
-
267
- audio = xml2(xml, "wav")
268
- if fix_volume != None:
269
- if fix_volume:
270
- adjust_volume(audio, fix_volume)
271
-
272
- elif os.path.exists(audio):
273
- if emo == "Q1":
274
- adjust_volume(audio, 5)
275
-
276
- elif emo == "Q2":
277
- adjust_volume(audio, 10)
278
-
279
- mxl = xml2(xml, "mxl")
280
- midi = xml2(xml, "mid")
281
- pdf, jpg = xml2img(xml)
282
- return audio, midi, pdf, xml, mxl, tunes, jpg
283
-
284
- except Exception as e:
285
- print(f"{e}")
286
- return generate_music(args, emo, weights)
287
 
288
 
289
  def infer_by_template(dataset: str, v: str, a: str, add_chord: bool):
290
- if os.path.exists(TEMP_DIR):
291
- shutil.rmtree(TEMP_DIR)
292
-
293
- os.makedirs(TEMP_DIR, exist_ok=True)
294
  emotion = "Q1"
295
  if v == "Low" and a == "High":
296
  emotion = "Q2"
@@ -301,14 +21,20 @@ def infer_by_template(dataset: str, v: str, a: str, add_chord: bool):
301
  elif v == "High" and a == "Low":
302
  emotion = "Q4"
303
 
304
- parser = argparse.ArgumentParser()
305
- args = get_args(parser)
306
- args.template = True
307
- return generate_music(
308
- args,
309
- emo=emotion,
310
- weights=f"{WEIGHTS_DIR}/{dataset.lower()}/weights.pth",
311
- )
312
 
313
 
314
  def infer_by_features(
@@ -320,10 +46,8 @@ def infer_by_features(
320
  rms: int,
321
  add_chord: bool,
322
  ):
323
- if os.path.exists(TEMP_DIR):
324
- shutil.rmtree(TEMP_DIR)
325
-
326
- os.makedirs(TEMP_DIR, exist_ok=True)
327
  emotion = "Q1"
328
  if mode == "Minor" and pitch_std == "High":
329
  emotion = "Q2"
@@ -334,78 +58,99 @@ def infer_by_features(
334
  elif mode == "Major" and pitch_std == "Low":
335
  emotion = "Q4"
336
 
337
- parser = argparse.ArgumentParser()
338
- args = get_args(parser)
339
- args.template = False
340
- return generate_music(
341
- args,
342
- emo=emotion,
343
- weights=f"{WEIGHTS_DIR}/{dataset.lower()}/weights.pth",
344
- fix_tempo=tempo,
345
- fix_pitch=octave,
346
- fix_volume=rms,
347
- )
348
-
349
-
350
- def feedback(fixed_emo: str, source_dir="./flagged", target_dir="./feedbacks"):
351
- if not fixed_emo:
352
- return "Please select feedback before submitting! "
353
-
354
- os.makedirs(target_dir, exist_ok=True)
355
- for root, _, files in os.walk(source_dir):
356
- for file in files:
357
- if file.endswith(".mxl"):
358
- prompt_emo = file.split("]")[0][1:]
359
- if prompt_emo != fixed_emo:
360
- file_path = os.path.join(root, file)
361
- target_path = os.path.join(
362
- target_dir, file.replace(".mxl", f"_{fixed_emo}.mxl")
363
- )
364
- shutil.copy(file_path, target_path)
365
- return f"Copied {file_path} to {target_path}"
366
 
367
- else:
368
- return "Thanks for your feedback!"
369
 
370
- return "No .mxl files found in the source directory."
371
 
372
 
373
- def save_template(
374
- label: str,
375
- pitch_std: str,
376
- mode: str,
377
- tempo: int,
378
- octave: int,
379
- rms: int,
380
  ):
381
- if (
382
- label
383
- and pitch_std
384
- and mode
385
- and tempo != None
386
- and octave != None
387
- and rms != None
388
- ):
389
- json_str = json.dumps(
390
- {
391
- "label": label,
392
- "pitch_std": pitch_std == "High",
393
- "mode": mode == "Major",
394
- "tempo": tempo,
395
- "octave": octave,
396
- "volume": rms,
397
- }
398
- )
399
 
400
- with open("./feedbacks/templates.jsonl", "a", encoding="utf-8") as file:
401
- file.write(json_str + "\n")
402
 
403
 
404
  if __name__ == "__main__":
405
  warnings.filterwarnings("ignore")
406
- if os.path.exists("./flagged"):
407
- shutil.rmtree("./flagged")
408
-
409
  with gr.Blocks() as demo:
410
  gr.Markdown(
411
  "## The current CPU-based version on HuggingFace has slow inference, you can access the GPU-based mirror on [ModelScope](https://www.modelscope.cn/studios/monetjoe/EMelodyGen)"
@@ -423,79 +168,69 @@ if __name__ == "__main__":
423
  label="Dataset",
424
  value="Rough4Q",
425
  )
426
- gr.Markdown("# Generate by emotion condition")
427
- gr.Image(
428
- "https://www.modelscope.cn/studio/monetjoe/EMelodyGen/resolve/master/src/4q.jpg",
429
- show_label=False,
430
- show_download_button=False,
431
- show_fullscreen_button=False,
432
- show_share_button=False,
433
- )
434
- valence_radio = gr.Radio(
435
- ["Low", "High"],
436
- label="Valence (reflects negative-positive levels of emotion)",
437
- value="High",
438
- )
439
- arousal_radio = gr.Radio(
440
- ["Low", "High"],
441
- label="Arousal (reflects the calmness-intensity of the emotion)",
442
- value="High",
443
- )
444
- chord_check = gr.Checkbox(
445
- label="Generate chords (coming soon)",
446
- value=False,
447
- )
448
- gen_btn_1 = gr.Button("Generate")
449
- gr.Markdown("# Generate by feature control")
450
- std_option = gr.Radio(["Low", "High"], label="Pitch SD", value="High")
451
- mode_option = gr.Radio(["Minor", "Major"], label="Mode", value="Major")
452
- tempo_option = gr.Slider(
453
- minimum=40,
454
- maximum=228,
455
- step=1,
456
- value=120,
457
- label="Tempo (BPM)",
458
- )
459
- octave_option = gr.Slider(
460
- minimum=-24,
461
- maximum=24,
462
- step=12,
463
- value=0,
464
- label="Octave (Β±12)",
465
- )
466
- volume_option = gr.Slider(
467
- minimum=-5,
468
- maximum=10,
469
- step=5,
470
- value=0,
471
- label="Volume (dB)",
472
- )
473
- chord_check_2 = gr.Checkbox(
474
- label="Generate chords (coming soon)",
475
- value=False,
476
- )
477
- gen_btn_2 = gr.Button("Generate")
478
- template_radio = gr.Radio(
479
- ["Q1", "Q2", "Q3", "Q4"],
480
- label="The emotion to which the current template belongs",
481
- )
482
- save_btn = gr.Button("Save template")
483
- gr.Markdown(
484
- """
485
- ## Cite
486
- ```bibtex
487
- @inproceedings{Zhou2025EMelodyGen,
488
- title = {EMelodyGen: Emotion-Conditioned Melody Generation in ABC Notation with the Musical Feature Template},
489
- author = {Monan Zhou and Xiaobing Li and Feng Yu and Wei Li},
490
- month = {Mar},
491
- year = {2025},
492
- publisher = {GitHub},
493
- version = {0.1},
494
- url = {https://github.com/monetjoe/EMelodyGen}
495
- }
496
- ```
497
- """
498
- )
499
 
500
  with gr.Column():
501
  wav_audio = gr.Audio(label="Audio", type="filepath")
@@ -506,20 +241,35 @@ if __name__ == "__main__":
506
  abc_textbox = gr.Textbox(label="ABC notation", show_copy_button=True)
507
  staff_img = gr.Image(label="Staff", type="filepath")
508
 
509
- gr.Interface(
510
- fn=feedback,
511
- inputs=gr.Radio(
512
- ["Q1", "Q2", "Q3", "Q4"],
513
- label="Feedback: the emotion you believe the generated result should belong to",
514
- ),
515
- outputs=gr.Textbox(show_copy_button=False, show_label=False),
516
- allow_flagging="never",
517
  )
519
  gen_btn_1.click(
520
  fn=infer_by_template,
521
  inputs=[dataset_option, valence_radio, arousal_radio, chord_check],
522
  outputs=[
 
523
  wav_audio,
524
  midi_file,
525
  pdf_file,
@@ -542,6 +292,7 @@ if __name__ == "__main__":
542
  chord_check,
543
  ],
544
  outputs=[
 
545
  wav_audio,
546
  midi_file,
547
  pdf_file,
@@ -562,6 +313,9 @@ if __name__ == "__main__":
562
  octave_option,
563
  volume_option,
564
  ],
 
565
  )
566
 
 
 
567
  demo.launch()
 
 
1
  import os
2
  import json
 
 
 
3
  import shutil
4
  import argparse
5
  import warnings
6
  import gradio as gr
7
+ from generate import generate_music, get_args
8
+ from utils import WEIGHTS_DIR, TEMP_DIR
9
 
10
 
11
  def infer_by_template(dataset: str, v: str, a: str, add_chord: bool):
12
+ status = "Success"
13
+ audio = midi = pdf = xml = mxl = tunes = jpg = None
 
 
14
  emotion = "Q1"
15
  if v == "Low" and a == "High":
16
  emotion = "Q2"
 
21
  elif v == "High" and a == "Low":
22
  emotion = "Q4"
23
 
24
+ try:
25
+ parser = argparse.ArgumentParser()
26
+ args = get_args(parser)
27
+ args.template = True
28
+ audio, midi, pdf, xml, mxl, tunes, jpg = generate_music(
29
+ args,
30
+ emo=emotion,
31
+ weights=f"{WEIGHTS_DIR}/{dataset.lower()}/weights.pth",
32
+ )
33
+
34
+ except Exception as e:
35
+ status = f"{e}"
36
+
37
+ return status, audio, midi, pdf, xml, mxl, tunes, jpg
38
 
39
 
40
  def infer_by_features(
 
46
  rms: int,
47
  add_chord: bool,
48
  ):
49
+ status = "Success"
50
+ audio = midi = pdf = xml = mxl = tunes = jpg = None
 
 
51
  emotion = "Q1"
52
  if mode == "Minor" and pitch_std == "High":
53
  emotion = "Q2"
 
58
  elif mode == "Major" and pitch_std == "Low":
59
  emotion = "Q4"
60
 
61
+ try:
62
+ parser = argparse.ArgumentParser()
63
+ args = get_args(parser)
64
+ args.template = False
65
+ audio, midi, pdf, xml, mxl, tunes, jpg = generate_music(
66
+ args,
67
+ emo=emotion,
68
+ weights=f"{WEIGHTS_DIR}/{dataset.lower()}/weights.pth",
69
+ fix_tempo=tempo,
70
+ fix_pitch=octave,
71
+ fix_volume=rms,
72
+ )
73
 
74
+ except Exception as e:
75
+ status = f"{e}"
76
 
77
+ return status, audio, midi, pdf, xml, mxl, tunes, jpg
78
 
79
 
80
+ def feedback(
81
+ fixed_emo: str,
82
+ source_dir=f"./{TEMP_DIR}/output",
83
+ target_dir=f"./{TEMP_DIR}/feedback",
 
 
 
84
  ):
85
+ try:
86
+ if not fixed_emo:
87
+ raise ValueError("Please select feedback before submitting! ")
88
+
89
+ os.makedirs(target_dir, exist_ok=True)
90
+ for root, _, files in os.walk(source_dir):
91
+ for file in files:
92
+ if file.endswith(".mxl"):
93
+ prompt_emo = file.split("]")[0][1:]
94
+ if prompt_emo != fixed_emo:
95
+ file_path = os.path.join(root, file)
96
+ target_path = os.path.join(
97
+ target_dir, file.replace(".mxl", f"_{fixed_emo}.mxl")
98
+ )
99
+ shutil.copy(file_path, target_path)
100
+ return f"Copied {file_path} to {target_path}"
101
+
102
+ else:
103
+ return "Thanks for your feedback!"
104
+
105
+ return "No .mxl files found in the source directory."
106
+
107
+ except Exception as e:
108
+ return f"{e}"
109
+
110
+
111
+ def save_template(label: str, pitch_std: str, mode: str, tempo: int, octave: int, rms):
112
+ status = "Success"
113
+ template = None
114
+ try:
115
+ if (
116
+ label
117
+ and pitch_std
118
+ and mode
119
+ and tempo != None
120
+ and octave != None
121
+ and rms != None
122
+ ):
123
+ json_str = json.dumps(
124
+ {
125
+ "label": label,
126
+ "pitch_std": pitch_std == "High",
127
+ "mode": mode == "Major",
128
+ "tempo": tempo,
129
+ "octave": octave,
130
+ "volume": rms,
131
+ }
132
+ )
133
+
134
+ with open(
135
+ f"./{TEMP_DIR}/feedback/templates.jsonl",
136
+ "a",
137
+ encoding="utf-8",
138
+ ) as file:
139
+ file.write(json_str + "\n")
140
 
141
+ template = f"./{TEMP_DIR}/feedback/templates.jsonl"
142
+
143
+ else:
144
+ raise ValueError("Please check features")
145
+
146
+ except Exception as e:
147
+ status = f"{e}"
148
+
149
+ return status, template
150
 
151
 
152
  if __name__ == "__main__":
153
  warnings.filterwarnings("ignore")
154
  with gr.Blocks() as demo:
155
  gr.Markdown(
156
  "## The current CPU-based version on HuggingFace has slow inference, you can access the GPU-based mirror on [ModelScope](https://www.modelscope.cn/studios/monetjoe/EMelodyGen)"
 
168
  label="Dataset",
169
  value="Rough4Q",
170
  )
171
+ with gr.Tab("By template"):
172
+ gr.Image(
173
+ "https://www.modelscope.cn/studio/monetjoe/EMelodyGen/resolve/master/src/4q.jpg",
174
+ show_label=False,
175
+ show_download_button=False,
176
+ show_fullscreen_button=False,
177
+ show_share_button=False,
178
+ )
179
+ valence_radio = gr.Radio(
180
+ ["Low", "High"],
181
+ label="Valence (reflects negative-positive levels of emotion)",
182
+ value="High",
183
+ )
184
+ arousal_radio = gr.Radio(
185
+ ["Low", "High"],
186
+ label="Arousal (reflects the calmness-intensity of the emotion)",
187
+ value="High",
188
+ )
189
+ chord_check = gr.Checkbox(
190
+ label="Generate chords (coming soon)",
191
+ value=False,
192
+ )
193
+ gen_btn_1 = gr.Button("Generate")
194
+
195
+ with gr.Tab("By feature control"):
196
+ std_option = gr.Radio(
197
+ ["Low", "High"], label="Pitch SD", value="High"
198
+ )
199
+ mode_option = gr.Radio(
200
+ ["Minor", "Major"], label="Mode", value="Major"
201
+ )
202
+ tempo_option = gr.Slider(
203
+ minimum=40,
204
+ maximum=228,
205
+ step=1,
206
+ value=120,
207
+ label="Tempo (BPM)",
208
+ )
209
+ octave_option = gr.Slider(
210
+ minimum=-24,
211
+ maximum=24,
212
+ step=12,
213
+ value=0,
214
+ label="Octave (Β±12)",
215
+ )
216
+ volume_option = gr.Slider(
217
+ minimum=-5,
218
+ maximum=10,
219
+ step=5,
220
+ value=0,
221
+ label="Volume (dB)",
222
+ )
223
+ chord_check_2 = gr.Checkbox(
224
+ label="Generate chords (coming soon)",
225
+ value=False,
226
+ )
227
+ gen_btn_2 = gr.Button("Generate")
228
+ template_radio = gr.Radio(
229
+ ["Q1", "Q2", "Q3", "Q4"],
230
+ label="The emotion to which the current template belongs",
231
+ )
232
+ save_btn = gr.Button("Save template")
233
+ dld_template = gr.File(label="Download template")
234
 
235
  with gr.Column():
236
  wav_audio = gr.Audio(label="Audio", type="filepath")
 
241
  abc_textbox = gr.Textbox(label="ABC notation", show_copy_button=True)
242
  staff_img = gr.Image(label="Staff", type="filepath")
243
 
244
+ with gr.Column():
245
+ status_bar = gr.Textbox(label="Status", show_copy_button=True)
246
+ fdb_radio = gr.Radio(
247
+ ["Q1", "Q2", "Q3", "Q4"],
248
+ label="Feedback: the emotion you believe the generated result should belong to",
 
 
 
249
  )
250
+ fdb_btn = gr.Button("Submit")
251
 
252
+ gr.Markdown(
253
+ """## Cite
254
+ ```bibtex
255
+ @inproceedings{Zhou2025EMelodyGen,
256
+ title = {EMelodyGen: Emotion-Conditioned Melody Generation in ABC Notation with the Musical Feature Template},
257
+ author = {Monan Zhou and Xiaobing Li and Feng Yu and Wei Li},
258
+ month = {Mar},
259
+ year = {2025},
260
+ publisher = {GitHub},
261
+ version = {0.1},
262
+ url = {https://github.com/monetjoe/EMelodyGen}
263
+ }
264
+ ```"""
265
+ )
266
+
267
+ # actions
268
  gen_btn_1.click(
269
  fn=infer_by_template,
270
  inputs=[dataset_option, valence_radio, arousal_radio, chord_check],
271
  outputs=[
272
+ status_bar,
273
  wav_audio,
274
  midi_file,
275
  pdf_file,
 
292
  chord_check,
293
  ],
294
  outputs=[
295
+ status_bar,
296
  wav_audio,
297
  midi_file,
298
  pdf_file,
 
313
  octave_option,
314
  volume_option,
315
  ],
316
+ outputs=[status_bar, dld_template],
317
  )
318
 
319
+ fdb_btn.click(fn=feedback, inputs=fdb_radio, outputs=status_bar)
320
+
321
  demo.launch()
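For context, the refactored handlers in the new app.py return a status string as their first output so errors surface in the new `status_bar` textbox instead of breaking the app. Below is a minimal sketch of that wiring; the handler and component names here are illustrative and not part of the commit:

```python
import gradio as gr

def safe_handler(text: str):
    # Return the status message first, then the payload; on failure, report the
    # error text in the status field and leave the payload empty.
    try:
        return "Success", text.upper()
    except Exception as e:
        return f"{e}", None

with gr.Blocks() as demo:
    inp = gr.Textbox(label="Input")
    status_bar = gr.Textbox(label="Status")
    out = gr.Textbox(label="Output")
    gr.Button("Run").click(fn=safe_handler, inputs=inp, outputs=[status_bar, out])

if __name__ == "__main__":
    demo.launch()
```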
generate.py ADDED
@@ -0,0 +1,290 @@
1
+ import re
2
+ import os
3
+ import shutil
4
+ import time
5
+ import torch
6
+ import random
7
+ import argparse
8
+ import soundfile as sf
9
+ from transformers import GPT2Config
10
+ from model import Patchilizer, TunesFormer
11
+ from convert import abc2xml, xml2img, xml2, transpose_octaves_abc
12
+ from utils import (
13
+ PATCH_NUM_LAYERS,
14
+ PATCH_LENGTH,
15
+ CHAR_NUM_LAYERS,
16
+ PATCH_SIZE,
17
+ SHARE_WEIGHTS,
18
+ TEMP_DIR,
19
+ DEVICE,
20
+ )
21
+
22
+
23
+ def get_args(parser: argparse.ArgumentParser):
24
+ parser.add_argument(
25
+ "-num_tunes",
26
+ type=int,
27
+ default=1,
28
+ help="the number of independently computed returned tunes",
29
+ )
30
+ parser.add_argument(
31
+ "-max_patch",
32
+ type=int,
33
+ default=128,
34
+ help="integer to define the maximum length in tokens of each tune",
35
+ )
36
+ parser.add_argument(
37
+ "-top_p",
38
+ type=float,
39
+ default=0.8,
40
+ help="float to define the tokens that are within the sample operation of text generation",
41
+ )
42
+ parser.add_argument(
43
+ "-top_k",
44
+ type=int,
45
+ default=8,
46
+ help="integer to define the tokens that are within the sample operation of text generation",
47
+ )
48
+ parser.add_argument(
49
+ "-temperature",
50
+ type=float,
51
+ default=1.2,
52
+ help="the temperature of the sampling operation",
53
+ )
54
+ parser.add_argument("-seed", type=int, default=None, help="seed for randomstate")
55
+ parser.add_argument(
56
+ "-show_control_code",
57
+ type=bool,
58
+ default=False,
59
+ help="whether to show control code",
60
+ )
61
+ parser.add_argument(
62
+ "-template",
63
+ type=bool,
64
+ default=True,
65
+ help="whether to generate by template",
66
+ )
67
+ return parser.parse_args()
68
+
69
+
70
+ def get_abc_key_val(text: str, key="K"):
71
+ pattern = re.escape(key) + r":(.*?)\n"
72
+ match = re.search(pattern, text)
73
+ if match:
74
+ return match.group(1).strip()
75
+ else:
76
+ return None
77
+
78
+
79
+ def adjust_volume(in_audio: str, dB_change: int):
80
+ y, sr = sf.read(in_audio)
81
+ sf.write(in_audio, y * 10 ** (dB_change / 20), sr)
82
+
83
+
84
+ def clean_dir(dir_path: str):
85
+ if os.path.exists(dir_path):
86
+ shutil.rmtree(dir_path)
87
+
88
+ os.makedirs(dir_path)
89
+
90
+
91
+ def generate_music(
92
+ args,
93
+ emo: str,
94
+ weights: str,
95
+ outdir=f"{TEMP_DIR}/output",
96
+ fix_tempo=None,
97
+ fix_pitch=None,
98
+ fix_volume=None,
99
+ ):
100
+ clean_dir(outdir)
101
+ patchilizer = Patchilizer()
102
+ patch_config = GPT2Config(
103
+ num_hidden_layers=PATCH_NUM_LAYERS,
104
+ max_length=PATCH_LENGTH,
105
+ max_position_embeddings=PATCH_LENGTH,
106
+ vocab_size=1,
107
+ )
108
+ char_config = GPT2Config(
109
+ num_hidden_layers=CHAR_NUM_LAYERS,
110
+ max_length=PATCH_SIZE,
111
+ max_position_embeddings=PATCH_SIZE,
112
+ vocab_size=128,
113
+ )
114
+ model = TunesFormer(patch_config, char_config, share_weights=SHARE_WEIGHTS)
115
+ checkpoint = torch.load(weights, map_location=DEVICE)
116
+ model.load_state_dict(checkpoint["model"])
117
+ model = model.to(DEVICE)
118
+ model.eval()
119
+ prompt = f"A:{emo}\n"
120
+ tunes = ""
121
+ num_tunes = args.num_tunes
122
+ max_patch = args.max_patch
123
+ top_p = args.top_p
124
+ top_k = args.top_k
125
+ temperature = args.temperature
126
+ seed = args.seed
127
+ show_control_code = args.show_control_code
128
+ fname_prefix = emo if args.template else "Melody"
129
+ print(" Hyper parms ".center(60, "#"), "\n")
130
+ args_dict: dict = vars(args)
131
+ for arg in args_dict.keys():
132
+ print(f"{arg}: {str(args_dict[arg])}")
133
+
134
+ print("\n", " Output tunes ".center(60, "#"))
135
+ start_time = time.time()
136
+ for i in range(num_tunes):
137
+ title = f"T:{fname_prefix} Fragment\n"
138
+ artist = f"C:Generated by AI\n"
139
+ tune = f"X:{str(i + 1)}\n{title}{artist}{prompt}"
140
+ lines = re.split(r"(\n)", tune)
141
+ tune = ""
142
+ skip = False
143
+ for line in lines:
144
+ if show_control_code or line[:2] not in ["S:", "B:", "E:", "D:"]:
145
+ if not skip:
146
+ print(line, end="")
147
+ tune += line
148
+
149
+ skip = False
150
+
151
+ else:
152
+ skip = True
153
+
154
+ input_patches = torch.tensor(
155
+ [patchilizer.encode(prompt, add_special_patches=True)[:-1]],
156
+ device=DEVICE,
157
+ )
158
+ if tune == "":
159
+ tokens = None
160
+
161
+ else:
162
+ prefix = patchilizer.decode(input_patches[0])
163
+ remaining_tokens = prompt[len(prefix) :]
164
+ tokens = torch.tensor(
165
+ [patchilizer.bos_token_id] + [ord(c) for c in remaining_tokens],
166
+ device=DEVICE,
167
+ )
168
+
169
+ while input_patches.shape[1] < max_patch:
170
+ predicted_patch, seed = model.generate(
171
+ input_patches,
172
+ tokens,
173
+ top_p=top_p,
174
+ top_k=top_k,
175
+ temperature=temperature,
176
+ seed=seed,
177
+ )
178
+ tokens = None
179
+ if predicted_patch[0] != patchilizer.eos_token_id:
180
+ next_bar = patchilizer.decode([predicted_patch])
181
+ if show_control_code or next_bar[:2] not in ["S:", "B:", "E:", "D:"]:
182
+ print(next_bar, end="")
183
+ tune += next_bar
184
+
185
+ if next_bar == "":
186
+ break
187
+
188
+ next_bar = remaining_tokens + next_bar
189
+ remaining_tokens = ""
190
+ predicted_patch = torch.tensor(
191
+ patchilizer.bar2patch(next_bar),
192
+ device=DEVICE,
193
+ ).unsqueeze(0)
194
+ input_patches = torch.cat(
195
+ [input_patches, predicted_patch.unsqueeze(0)],
196
+ dim=1,
197
+ )
198
+
199
+ else:
200
+ break
201
+
202
+ tunes += f"{tune}\n\n"
203
+ print("\n")
204
+
205
+ # fix tempo
206
+ if fix_tempo != None:
207
+ tempo = f"Q:{fix_tempo}\n"
208
+
209
+ else:
210
+ tempo = f"Q:{random.randint(88, 132)}\n"
211
+ if emo == "Q1":
212
+ tempo = f"Q:{random.randint(160, 184)}\n"
213
+ elif emo == "Q2":
214
+ tempo = f"Q:{random.randint(184, 228)}\n"
215
+ elif emo == "Q3":
216
+ tempo = f"Q:{random.randint(40, 69)}\n"
217
+ elif emo == "Q4":
218
+ tempo = f"Q:{random.randint(40, 69)}\n"
219
+
220
+ Q_val = get_abc_key_val(tunes, "Q")
221
+ if Q_val:
222
+ tunes = tunes.replace(f"Q:{Q_val}\n", "")
223
+
224
+ K_val = get_abc_key_val(tunes)
225
+ if K_val == "none":
226
+ K_val = "C"
227
+ tunes = tunes.replace("K:none\n", f"K:{K_val}\n")
228
+
229
+ tunes = tunes.replace(f"A:{emo}\n", tempo)
230
+ # fix mode:major/minor
231
+ mode = "major" if emo == "Q1" or emo == "Q4" else "minor"
232
+ if (mode == "major") and ("m" in K_val):
233
+ tunes = tunes.replace(f"\nK:{K_val}\n", f"\nK:{K_val.split('m')[0]}\n")
234
+
235
+ elif (mode == "minor") and (not "m" in K_val):
236
+ tunes = tunes.replace(f"\nK:{K_val}\n", f"\nK:{K_val.replace('dor', '')}min\n")
237
+
238
+ print("Generation time: {:.2f} seconds".format(time.time() - start_time))
239
+ timestamp = time.strftime("%a_%d_%b_%Y_%H_%M_%S", time.localtime())
240
+ try:
241
+ # fix avg_pitch (octave)
242
+ if fix_pitch != None:
243
+ if fix_pitch:
244
+ tunes, xml = transpose_octaves_abc(
245
+ tunes,
246
+ f"{outdir}/{timestamp}.musicxml",
247
+ fix_pitch,
248
+ )
249
+ tunes = tunes.replace(title + title, title)
250
+ os.rename(xml, f"{outdir}/[{fname_prefix}]{timestamp}.musicxml")
251
+ xml = f"{outdir}/[{fname_prefix}]{timestamp}.musicxml"
252
+
253
+ else:
254
+ if mode == "minor":
255
+ offset = -12
256
+ if emo == "Q2":
257
+ offset -= 12
258
+
259
+ tunes, xml = transpose_octaves_abc(
260
+ tunes,
261
+ f"{outdir}/{timestamp}.musicxml",
262
+ offset,
263
+ )
264
+ tunes = tunes.replace(title + title, title)
265
+ os.rename(xml, f"{outdir}/[{fname_prefix}]{timestamp}.musicxml")
266
+ xml = f"{outdir}/[{fname_prefix}]{timestamp}.musicxml"
267
+
268
+ else:
269
+ xml = abc2xml(tunes, f"{outdir}/[{fname_prefix}]{timestamp}.musicxml")
270
+
271
+ audio = xml2(xml, "wav")
272
+ if fix_volume != None:
273
+ if fix_volume:
274
+ adjust_volume(audio, fix_volume)
275
+
276
+ elif os.path.exists(audio):
277
+ if emo == "Q1":
278
+ adjust_volume(audio, 5)
279
+
280
+ elif emo == "Q2":
281
+ adjust_volume(audio, 10)
282
+
283
+ mxl = xml2(xml, "mxl")
284
+ midi = xml2(xml, "mid")
285
+ pdf, jpg = xml2img(xml)
286
+ return audio, midi, pdf, xml, mxl, tunes, jpg
287
+
288
+ except Exception as e:
289
+ print(f"{e}")
290
+ return generate_music(args, emo, weights)
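Since the generation pipeline now lives in `generate.py`, here is a minimal usage sketch of the exported helpers, matching how app.py calls them; the weights path below is illustrative and not taken from the commit:

```python
import argparse
from generate import generate_music, get_args

# Build the default hyperparameters (num_tunes, max_patch, top_p, top_k, temperature, ...).
args = get_args(argparse.ArgumentParser())
args.template = True  # generate from the emotion template

# Generate one melody conditioned on emotion quadrant Q1 (high valence, high arousal)
# and collect the rendered artifacts in the order app.py unpacks them.
audio, midi, pdf, xml, mxl, tunes, jpg = generate_music(
    args,
    emo="Q1",
    weights="./weights/rough4q/weights.pth",  # illustrative; app.py builds f"{WEIGHTS_DIR}/{dataset.lower()}/weights.pth"
)
```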
requirements.txt CHANGED
@@ -1,7 +1,6 @@
 torch
 music21
 pymupdf
-autopep8
 soundfile
 unidecode
 pillow==9.4.0
utils.py CHANGED
@@ -4,11 +4,16 @@ import time
 import torch
 import requests
 import subprocess
+import modelscope
+import huggingface_hub
 from tqdm import tqdm
-from huggingface_hub import snapshot_download
 
-TEMP_DIR = "./flagged"
-WEIGHTS_DIR = snapshot_download("monetjoe/EMelodyGen", cache_dir="./__pycache__")
+TEMP_DIR = "./__pycache__"
+WEIGHTS_DIR = (
+    huggingface_hub.snapshot_download("monetjoe/EMelodyGen", cache_dir=TEMP_DIR)
+    if os.getenv("language")
+    else modelscope.snapshot_download("monetjoe/EMelodyGen", cache_dir=TEMP_DIR)
+)
 DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 PATCH_LENGTH = 128  # Patch Length
 PATCH_SIZE = 32  # Patch Size
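The new utils.py chooses where to download the checkpoint from based on the `language` environment variable. A hedged, standalone sketch of that selection logic, written as a plain if/else rather than the conditional expression used in the commit:

```python
import os
import modelscope
import huggingface_hub

TEMP_DIR = "./__pycache__"

# When the "language" env var is set (as on the HuggingFace Space), pull the weights
# from the HuggingFace Hub; otherwise fall back to the ModelScope mirror.
if os.getenv("language"):
    WEIGHTS_DIR = huggingface_hub.snapshot_download("monetjoe/EMelodyGen", cache_dir=TEMP_DIR)
else:
    WEIGHTS_DIR = modelscope.snapshot_download("monetjoe/EMelodyGen", cache_dir=TEMP_DIR)
```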