asigalov61 commited on
Commit
0de0265
·
1 Parent(s): 173ed3a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -3
app.py CHANGED
@@ -45,8 +45,14 @@ def GenerateMIDI():
45
  torch_in = x.tolist()[0]
46
 
47
  logits = torch.FloatTensor(session.run(None, {'input': [torch_in]})[0])[:, -1]
 
 
 
 
 
 
48
 
49
- probs = F.softmax(logits / temperature, dim=-1)
50
 
51
  sample = torch.multinomial(probs, 1)
52
 
@@ -133,7 +139,7 @@ def GenerateMIDI():
133
 
134
  audio = synthesis(TMIDIX.score2opus(output), 'SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2')
135
 
136
- yield output, "Allegro-Music-Transformer-Music-Composition.mid", (44100, audio)
137
 
138
  #=================================================================================================
139
 
@@ -143,7 +149,7 @@ def cancel_run(output_midi_seq):
143
  with open(f"Allegro-Music-Transformer-Music-Composition.mid", 'wb') as f:
144
  f.write(TMIDIX.score2midi(output_midi_seq))
145
  audio = synthesis(TMIDIX.score2opus(output_midi_seq), 'SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2')
146
- return "Allegro-Music-Transformer-Music-Composition.mid", (44100, audio)
147
 
148
  #=================================================================================================
149
 
 
45
  torch_in = x.tolist()[0]
46
 
47
  logits = torch.FloatTensor(session.run(None, {'input': [torch_in]})[0])[:, -1]
48
+
49
+ thres = 0.9
50
+ k = ceil((1 - thres) * logits.shape[-1])
51
+ val, ind = torch.topk(logits, k)
52
+ probs = torch.full_like(logits, float('-inf'))
53
+ probs.scatter_(1, ind, val)
54
 
55
+ probs = F.softmax(probs / temperature, dim=-1)
56
 
57
  sample = torch.multinomial(probs, 1)
58
 
 
139
 
140
  audio = synthesis(TMIDIX.score2opus(output), 'SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2')
141
 
142
+ yield output, "Allegro-Music-Transformer-Music-Composition.mid", (16000, audio)
143
 
144
  #=================================================================================================
145
 
 
149
  with open(f"Allegro-Music-Transformer-Music-Composition.mid", 'wb') as f:
150
  f.write(TMIDIX.score2midi(output_midi_seq))
151
  audio = synthesis(TMIDIX.score2opus(output_midi_seq), 'SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2')
152
+ return "Allegro-Music-Transformer-Music-Composition.mid", (16000, audio)
153
 
154
  #=================================================================================================
155