asigalov61 commited on
Commit
12586a9
·
verified ·
1 Parent(s): a174779

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -7
app.py CHANGED
@@ -31,19 +31,24 @@ def GenerateAccompaniment(input_midi, input_num_tokens, input_conditioning_type,
31
 
32
  print('Loading model...')
33
 
34
- SEQ_LEN = 8192 # Models seq len
35
- PAD_IDX = 707 # Models pad index
36
- DEVICE = 'cuda' # 'cuda'
37
 
38
  # instantiate the model
39
 
40
  model = TransformerWrapper(
41
  num_tokens = PAD_IDX+1,
42
  max_seq_len = SEQ_LEN,
43
- attn_layers = Decoder(dim = 2048, depth = 4, heads = 16, attn_flash = True)
 
 
 
 
 
44
  )
45
 
46
- model = AutoregressiveWrapper(model, ignore_index = PAD_IDX)
47
 
48
  model.to(DEVICE)
49
  print('=' * 70)
@@ -51,7 +56,7 @@ def GenerateAccompaniment(input_midi, input_num_tokens, input_conditioning_type,
51
  print('Loading model checkpoint...')
52
 
53
  model.load_state_dict(
54
- torch.load('Chords_Progressions_Transformer_Small_2048_Trained_Model_12947_steps_0.9316_loss_0.7386_acc.pth',
55
  map_location=DEVICE))
56
  print('=' * 70)
57
 
@@ -62,7 +67,7 @@ def GenerateAccompaniment(input_midi, input_num_tokens, input_conditioning_type,
62
  else:
63
  dtype = torch.float16
64
 
65
- ctx = torch.amp.autocast(device_type=DEVICE, dtype=dtype)
66
 
67
  print('Done!')
68
  print('=' * 70)
 
31
 
32
  print('Loading model...')
33
 
34
+ SEQ_LEN = 8192
35
+ PAD_IDX = 2239
36
+ DEVICE = 'cuda' # 'cpu'
37
 
38
  # instantiate the model
39
 
40
  model = TransformerWrapper(
41
  num_tokens = PAD_IDX+1,
42
  max_seq_len = SEQ_LEN,
43
+ attn_layers = Decoder(dim = 2048,
44
+ depth = 8,
45
+ heads = 32,
46
+ rotary_pos_emb = True,
47
+ attn_flash = True
48
+ )
49
  )
50
 
51
+ model = AutoregressiveWrapper(model, ignore_index = PAD_IDX, pad_value=PAD_IDX)
52
 
53
  model.to(DEVICE)
54
  print('=' * 70)
 
56
  print('Loading model checkpoint...')
57
 
58
  model.load_state_dict(
59
+ torch.load('Ultimate_Chords_Progressions_Transformer_Trained_Model_LAX_5858_steps_0.4506_loss_0.8724_acc.pth',
60
  map_location=DEVICE))
61
  print('=' * 70)
62
 
 
67
  else:
68
  dtype = torch.float16
69
 
70
+ ctx = torch.amp.autocast(device_type='cuda', dtype=dtype)
71
 
72
  print('Done!')
73
  print('=' * 70)