asigalov61 commited on
Commit
e2d614e
·
verified ·
1 Parent(s): 50c0119

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -6
app.py CHANGED
@@ -28,8 +28,9 @@ def GenerateAccompaniment(input_midi, input_num_tokens, input_conditioning_type,
28
  print('=' * 70)
29
  print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
30
  start_time = reqtime.time()
31
-
32
- print('Loading model...')
 
33
 
34
  SEQ_LEN = 8192
35
  PAD_IDX = 2239
@@ -51,6 +52,7 @@ def GenerateAccompaniment(input_midi, input_num_tokens, input_conditioning_type,
51
  model = AutoregressiveWrapper(model, ignore_index = PAD_IDX, pad_value=PAD_IDX)
52
 
53
  model.to(DEVICE)
 
54
  print('=' * 70)
55
 
56
  print('Loading model checkpoint...')
@@ -58,6 +60,7 @@ def GenerateAccompaniment(input_midi, input_num_tokens, input_conditioning_type,
58
  model.load_state_dict(
59
  torch.load('Ultimate_Chords_Progressions_Transformer_Trained_Model_LAX_5858_steps_0.4506_loss_0.8724_acc.pth',
60
  map_location=DEVICE))
 
61
  print('=' * 70)
62
 
63
  model.eval()
@@ -65,9 +68,9 @@ def GenerateAccompaniment(input_midi, input_num_tokens, input_conditioning_type,
65
  if DEVICE == 'cpu':
66
  dtype = torch.bfloat16
67
  else:
68
- dtype = torch.float16
69
 
70
- ctx = torch.amp.autocast(device_type='cuda', dtype=dtype)
71
 
72
  print('Done!')
73
  print('=' * 70)
@@ -77,12 +80,12 @@ def GenerateAccompaniment(input_midi, input_num_tokens, input_conditioning_type,
77
 
78
  input_num_tokens = max(4, min(128, input_num_tokens))
79
 
80
- print('-' * 70)
81
  print('Input file name:', fn)
82
  print('Req num toks:', input_num_tokens)
83
  print('Conditioning type:', input_conditioning_type)
84
  print('Strip notes:', input_strip_notes)
85
- print('-' * 70)
86
 
87
  #===============================================================================
88
  raw_score = TMIDIX.midi2single_track_ms_score(input_midi.name)
 
28
  print('=' * 70)
29
  print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
30
  start_time = reqtime.time()
31
+
32
+ print('=' * 70)
33
+ print('Instantiating the model...')
34
 
35
  SEQ_LEN = 8192
36
  PAD_IDX = 2239
 
52
  model = AutoregressiveWrapper(model, ignore_index = PAD_IDX, pad_value=PAD_IDX)
53
 
54
  model.to(DEVICE)
55
+ print('Done!')
56
  print('=' * 70)
57
 
58
  print('Loading model checkpoint...')
 
60
  model.load_state_dict(
61
  torch.load('Ultimate_Chords_Progressions_Transformer_Trained_Model_LAX_5858_steps_0.4506_loss_0.8724_acc.pth',
62
  map_location=DEVICE))
63
+ print('Done!')
64
  print('=' * 70)
65
 
66
  model.eval()
 
68
  if DEVICE == 'cpu':
69
  dtype = torch.bfloat16
70
  else:
71
+ dtype = torch.bfloat16
72
 
73
+ ctx = torch.amp.autocast(device_type=DEVICE, dtype=dtype)
74
 
75
  print('Done!')
76
  print('=' * 70)
 
80
 
81
  input_num_tokens = max(4, min(128, input_num_tokens))
82
 
83
+ print('=' * 70)
84
  print('Input file name:', fn)
85
  print('Req num toks:', input_num_tokens)
86
  print('Conditioning type:', input_conditioning_type)
87
  print('Strip notes:', input_strip_notes)
88
+ print('=' * 70)
89
 
90
  #===============================================================================
91
  raw_score = TMIDIX.midi2single_track_ms_score(input_midi.name)