Hjgugugjhuhjggg commited on
Commit
5a91b88
·
verified ·
1 Parent(s): be553d9

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +11 -6
main.py CHANGED
@@ -9,6 +9,7 @@ from numpy import ndarray
9
  from note_seq.protobuf.music_pb2 import NoteSequence
10
  from note_seq.constants import STANDARD_PPQ
11
  import logging
 
12
 
13
  logging.basicConfig(level=logging.INFO)
14
 
@@ -145,6 +146,8 @@ GM_INSTRUMENTS = [
145
  ]
146
  tokenizer = None
147
  model = None
 
 
148
  def get_model_and_tokenizer() -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
149
  logging.info("get_model_and_tokenizer: Starting to load model and tokenizer...")
150
  global model, tokenizer
@@ -311,9 +314,9 @@ def get_outputs_from_string(
311
  instruments_str = "\n".join(f"- {instrument}" for instrument in instruments)
312
  note_sequence = token_sequence_to_note_sequence(generated_sequence, qpm=qpm)
313
 
314
- if not note_sequence.notes: # Check if note_sequence is empty
315
  logging.warning("get_outputs_from_string: Note sequence is empty, skipping plot.")
316
- fig = None # Handle case where fig is None
317
  else:
318
  fig = note_seq.plot_sequence(note_sequence, show_figure=False)
319
 
@@ -379,16 +382,18 @@ def generate_song(
379
  text_sequence: str = "",
380
  qpm: int = 120,
381
  prompt: str = "",
382
- duration: int = 1
383
  ) -> Tuple[ndarray, str, Figure, str, str, str]:
384
- logging.info(f"generate_song: Starting song generation. Genre: {genre}, Temperature: {temp}, QPM: {qpm}, Duration: {duration}, Prompt: '{prompt}'")
385
  if text_sequence == "":
386
  seed_string = create_seed_string(genre, prompt)
387
  else:
388
  seed_string = text_sequence
389
 
 
 
390
  generated_sequence = seed_string
391
- for _ in range(duration):
392
  instrument_sequence = generate_new_instrument(seed=generated_sequence, temp=temp)
393
  if instrument_sequence:
394
  generated_sequence = instrument_sequence
@@ -410,7 +415,7 @@ def run():
410
  with gr.Row():
411
  with gr.Column():
412
  prompt_text = gr.Textbox(lines=2, placeholder="Enter text prompt here...", label="Text Prompt (Optional)")
413
- duration_slider = gr.Slider(minimum=1, maximum=5, step=1, value=1, label="Duration (Tracks)")
414
  temp = gr.Slider(
415
  minimum=0, maximum=1, step=0.05, value=0.85, label="Temperature"
416
  )
 
9
  from note_seq.protobuf.music_pb2 import NoteSequence
10
  from note_seq.constants import STANDARD_PPQ
11
  import logging
12
+ import math
13
 
14
  logging.basicConfig(level=logging.INFO)
15
 
 
146
  ]
147
  tokenizer = None
148
  model = None
149
+
150
+
151
  def get_model_and_tokenizer() -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
152
  logging.info("get_model_and_tokenizer: Starting to load model and tokenizer...")
153
  global model, tokenizer
 
314
  instruments_str = "\n".join(f"- {instrument}" for instrument in instruments)
315
  note_sequence = token_sequence_to_note_sequence(generated_sequence, qpm=qpm)
316
 
317
+ if not note_sequence.notes:
318
  logging.warning("get_outputs_from_string: Note sequence is empty, skipping plot.")
319
+ fig = None
320
  else:
321
  fig = note_seq.plot_sequence(note_sequence, show_figure=False)
322
 
 
382
  text_sequence: str = "",
383
  qpm: int = 120,
384
  prompt: str = "",
385
+ duration: int = 30
386
  ) -> Tuple[ndarray, str, Figure, str, str, str]:
387
+ logging.info(f"generate_song: Starting song generation. Genre: {genre}, Temperature: {temp}, QPM: {qpm}, Duration: {duration} seconds, Prompt: '{prompt}'")
388
  if text_sequence == "":
389
  seed_string = create_seed_string(genre, prompt)
390
  else:
391
  seed_string = text_sequence
392
 
393
+ num_tracks = max(1, int(math.ceil(duration / 17)))
394
+
395
  generated_sequence = seed_string
396
+ for _ in range(num_tracks):
397
  instrument_sequence = generate_new_instrument(seed=generated_sequence, temp=temp)
398
  if instrument_sequence:
399
  generated_sequence = instrument_sequence
 
415
  with gr.Row():
416
  with gr.Column():
417
  prompt_text = gr.Textbox(lines=2, placeholder="Enter text prompt here...", label="Text Prompt (Optional)")
418
+ duration_slider = gr.Slider(minimum=1, maximum=1000, step=1, value=30, label="Duration (Seconds)")
419
  temp = gr.Slider(
420
  minimum=0, maximum=1, step=0.05, value=0.85, label="Temperature"
421
  )