Hjgugugjhuhjggg commited on
Commit
be553d9
·
verified ·
1 Parent(s): 3017618

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +12 -5
main.py CHANGED
@@ -263,7 +263,7 @@ model, tokenizer = get_model_and_tokenizer()
263
  def create_seed_string(genre: str = "OTHER", prompt: str = "") -> str:
264
  logging.info(f"create_seed_string: Creating seed string. Genre: {genre}, Prompt: '{prompt}'")
265
  if prompt:
266
- seed_string = f"PIECE_START PROMPT={prompt} GENRE={genre} TRACK_START" # Incorporate prompt more directly
267
  elif genre == "RANDOM":
268
  seed_string = "PIECE_START"
269
  else:
@@ -281,7 +281,7 @@ def get_instruments(text_sequence: str) -> List[str]:
281
  index = int(part[5:])
282
  instruments.append(GM_INSTRUMENTS[index])
283
  return instruments
284
- def generate_new_instrument(seed: str, temp: float = 0.75, max_tokens=204) -> str: # Increased max_tokens
285
  logging.info(f"generate_new_instrument: Starting instrument generation. Seed: '{seed}', Temperature: {temp}, Max Tokens: {max_tokens}")
286
  seed_length = len(tokenizer.encode(seed))
287
  input_ids = tokenizer.encode(seed, return_tensors="pt").to(model.device)
@@ -295,7 +295,7 @@ def generate_new_instrument(seed: str, temp: float = 0.75, max_tokens=204) -> st
295
  )
296
  generated_sequence = tokenizer.decode(generated_ids[0])
297
  new_generated_sequence = tokenizer.decode(generated_ids[0][seed_length:])
298
- logging.info(f"generate_new_instrument: Generated sequence: '{new_generated_sequence}'") # Log generated sequence
299
  if "NOTE_ON" in new_generated_sequence:
300
  logging.info("generate_new_instrument: New instrument generated successfully.")
301
  return generated_sequence
@@ -310,15 +310,22 @@ def get_outputs_from_string(
310
  instruments = get_instruments(generated_sequence)
311
  instruments_str = "\n".join(f"- {instrument}" for instrument in instruments)
312
  note_sequence = token_sequence_to_note_sequence(generated_sequence, qpm=qpm)
 
 
 
 
 
 
 
313
  synth = note_seq.fluidsynth
314
  array_of_floats = synth(note_sequence, sample_rate=SAMPLE_RATE)
315
  int16_data = note_seq.audio_io.float_samples_to_int16(array_of_floats)
316
- fig = note_seq.plot_sequence(note_sequence, show_figure=False)
317
  num_tokens = str(len(generated_sequence.split()))
318
  audio = gr.make_waveform((SAMPLE_RATE, int16_data))
319
  note_seq.note_sequence_to_midi_file(note_sequence, "midi_ouput.mid")
320
  logging.info("get_outputs_from_string: Output generation complete.")
321
  return audio, "midi_ouput.mid", fig, instruments_str, num_tokens
 
322
  def remove_last_instrument(
323
  text_sequence: str, qpm: int = 120
324
  ) -> Tuple[ndarray, str, Figure, str, str, str]:
@@ -368,7 +375,7 @@ def change_tempo(
368
  return audio, midi_file, fig, instruments_str, text_sequence, num_tokens
369
  def generate_song(
370
  genre: str = "OTHER",
371
- temp: float = 0.85, # Default temperature to 0.85 as in UI
372
  text_sequence: str = "",
373
  qpm: int = 120,
374
  prompt: str = "",
 
263
  def create_seed_string(genre: str = "OTHER", prompt: str = "") -> str:
264
  logging.info(f"create_seed_string: Creating seed string. Genre: {genre}, Prompt: '{prompt}'")
265
  if prompt:
266
+ seed_string = f"PIECE_START PROMPT={prompt} GENRE={genre} TRACK_START"
267
  elif genre == "RANDOM":
268
  seed_string = "PIECE_START"
269
  else:
 
281
  index = int(part[5:])
282
  instruments.append(GM_INSTRUMENTS[index])
283
  return instruments
284
+ def generate_new_instrument(seed: str, temp: float = 0.85, max_tokens=512) -> str:
285
  logging.info(f"generate_new_instrument: Starting instrument generation. Seed: '{seed}', Temperature: {temp}, Max Tokens: {max_tokens}")
286
  seed_length = len(tokenizer.encode(seed))
287
  input_ids = tokenizer.encode(seed, return_tensors="pt").to(model.device)
 
295
  )
296
  generated_sequence = tokenizer.decode(generated_ids[0])
297
  new_generated_sequence = tokenizer.decode(generated_ids[0][seed_length:])
298
+ logging.info(f"generate_new_instrument: Generated sequence: '{new_generated_sequence}'")
299
  if "NOTE_ON" in new_generated_sequence:
300
  logging.info("generate_new_instrument: New instrument generated successfully.")
301
  return generated_sequence
 
310
  instruments = get_instruments(generated_sequence)
311
  instruments_str = "\n".join(f"- {instrument}" for instrument in instruments)
312
  note_sequence = token_sequence_to_note_sequence(generated_sequence, qpm=qpm)
313
+
314
+ if not note_sequence.notes: # Check if note_sequence is empty
315
+ logging.warning("get_outputs_from_string: Note sequence is empty, skipping plot.")
316
+ fig = None # Handle case where fig is None
317
+ else:
318
+ fig = note_seq.plot_sequence(note_sequence, show_figure=False)
319
+
320
  synth = note_seq.fluidsynth
321
  array_of_floats = synth(note_sequence, sample_rate=SAMPLE_RATE)
322
  int16_data = note_seq.audio_io.float_samples_to_int16(array_of_floats)
 
323
  num_tokens = str(len(generated_sequence.split()))
324
  audio = gr.make_waveform((SAMPLE_RATE, int16_data))
325
  note_seq.note_sequence_to_midi_file(note_sequence, "midi_ouput.mid")
326
  logging.info("get_outputs_from_string: Output generation complete.")
327
  return audio, "midi_ouput.mid", fig, instruments_str, num_tokens
328
+
329
  def remove_last_instrument(
330
  text_sequence: str, qpm: int = 120
331
  ) -> Tuple[ndarray, str, Figure, str, str, str]:
 
375
  return audio, midi_file, fig, instruments_str, text_sequence, num_tokens
376
  def generate_song(
377
  genre: str = "OTHER",
378
+ temp: float = 0.85,
379
  text_sequence: str = "",
380
  qpm: int = 120,
381
  prompt: str = "",