Update app.py
Browse files
    	
        app.py
    CHANGED
    
    | 
         @@ -207,55 +207,55 @@ def GenerateAccompaniment(input_midi, input_num_tokens, input_conditioning_type) 
     | 
|
| 207 | 
         | 
| 208 | 
         
             
                output = []
         
     | 
| 209 | 
         | 
| 210 | 
         
            -
                 
     | 
| 211 | 
         
            -
                num_toks_per_note = 32
         
     | 
| 212 | 
         
             
                temperature=0.9
         
     | 
| 213 | 
         
            -
                max_drums_limit=4
         
     | 
| 214 | 
         
             
                num_memory_tokens=4096
         
     | 
| 215 | 
         | 
| 216 | 
         
            -
                 
     | 
| 217 | 
         
            -
                output2 = []
         
     | 
| 218 | 
         | 
| 219 | 
         
            -
             
     | 
| 220 | 
         
            -
                for m in melody_chords[:input_num_tokens]:
         
     | 
| 221 | 
         | 
| 222 | 
         
            -
             
     | 
| 223 | 
         
            -
             
     | 
| 224 | 
         
            -
                    input_seq = output1
         
     | 
| 225 | 
         | 
| 226 | 
         
            -
             
     | 
| 227 | 
         
            -
                      x = torch.LongTensor([input_seq+[0]]).cuda()
         
     | 
| 228 | 
         
            -
                    else:
         
     | 
| 229 | 
         
            -
                      x = torch.LongTensor([input_seq]).cuda()
         
     | 
| 230 | 
         | 
| 231 | 
         
            -
                     
     | 
| 232 | 
         | 
| 233 | 
         
            -
                     
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 234 | 
         | 
| 235 | 
         
            -
                     
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 236 | 
         
             
                      with ctx:
         
     | 
| 237 | 
         
             
                        out = model.generate(x[-num_memory_tokens:],
         
     | 
| 238 | 
         
             
                                            1,
         
     | 
| 239 | 
         
             
                                            temperature=temperature,
         
     | 
| 240 | 
         
             
                                            return_prime=False,
         
     | 
| 241 | 
         
             
                                            verbose=False)
         
     | 
| 242 | 
         
            -
             
     | 
| 243 | 
         
             
                      o = out.tolist()[0][0]
         
     | 
| 244 | 
         
            -
             
     | 
| 245 | 
         
            -
                      if  
     | 
| 246 | 
         
            -
                         
     | 
| 247 | 
         
            -
             
     | 
| 248 | 
         
            -
                      if  
     | 
| 249 | 
         
            -
                
         
     | 
| 250 | 
         
            -
                        out = torch.LongTensor([[o]]).cuda()
         
     | 
| 251 | 
         
             
                        x = torch.cat((x, out), 1)
         
     | 
| 252 | 
         
            -
             
     | 
| 253 | 
         
            -
             
     | 
| 
         | 
|
| 
         | 
|
| 254 | 
         | 
| 255 | 
         
            -
                     
     | 
| 256 | 
         
            -
             
     | 
| 257 | 
         
            -
                    output1.extend(outy)
         
     | 
| 258 | 
         
            -
                    output2.append(outy)
         
     | 
| 259 | 
         | 
| 260 | 
         
             
                print('=' * 70)
         
     | 
| 261 | 
         
             
                print('Done!')
         
     | 
| 
         @@ -265,13 +265,10 @@ def GenerateAccompaniment(input_midi, input_num_tokens, input_conditioning_type) 
     | 
|
| 265 | 
         
             
                print('Rendering results...')
         
     | 
| 266 | 
         | 
| 267 | 
         
             
                print('=' * 70)
         
     | 
| 268 | 
         
            -
                print('Sample INTs',  
     | 
| 269 | 
         
             
                print('=' * 70)
         
     | 
| 270 | 
         | 
| 271 | 
         
            -
                out1 =  
     | 
| 272 | 
         
            -
             
     | 
| 273 | 
         
            -
                accompaniment_MIDI_patch_number = 0
         
     | 
| 274 | 
         
            -
                melody_MIDI_patch_number = 40
         
     | 
| 275 | 
         | 
| 276 | 
         
             
                if len(out1) != 0:
         
     | 
| 277 | 
         | 
| 
         | 
|
| 207 | 
         | 
| 208 | 
         
             
                output = []
         
     | 
| 209 | 
         | 
| 210 | 
         
            +
                max_chords_limit = 8
         
     | 
| 
         | 
|
| 211 | 
         
             
                temperature=0.9
         
     | 
| 
         | 
|
| 212 | 
         
             
                num_memory_tokens=4096
         
     | 
| 213 | 
         | 
| 214 | 
         
            +
                output = []
         
     | 
| 
         | 
|
| 215 | 
         | 
| 216 | 
         
            +
                idx = 0
         
     | 
| 
         | 
|
| 217 | 
         | 
| 218 | 
         
            +
                for c in chords[:input_num_tokens]:
         
     | 
| 
         | 
|
| 
         | 
|
| 219 | 
         | 
| 220 | 
         
            +
                  try:
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 221 | 
         | 
| 222 | 
         
            +
                    output.append(c)
         
     | 
| 223 | 
         | 
| 224 | 
         
            +
                    if input_conditioning_type == 'Chords-Times' or input_conditioning_type == 'Chords-Times-Durations':
         
     | 
| 225 | 
         
            +
                      output.append(times[idx])
         
     | 
| 226 | 
         
            +
                
         
     | 
| 227 | 
         
            +
                    if input_conditioning_type == 'Chords-Times-Durations':
         
     | 
| 228 | 
         
            +
                      output.append(durs[idx])
         
     | 
| 229 | 
         
            +
             
     | 
| 230 | 
         
            +
                    chords = output
         
     | 
| 231 | 
         | 
| 232 | 
         
            +
                    x = torch.tensor([chords] * 1, dtype=torch.long, device='cuda')
         
     | 
| 233 | 
         
            +
                    
         
     | 
| 234 | 
         
            +
                    o = 0
         
     | 
| 235 | 
         
            +
                    
         
     | 
| 236 | 
         
            +
                    ncount = 0
         
     | 
| 237 | 
         
            +
                    
         
     | 
| 238 | 
         
            +
                    while o < 384 and ncount < max_chords_limit:
         
     | 
| 239 | 
         
             
                      with ctx:
         
     | 
| 240 | 
         
             
                        out = model.generate(x[-num_memory_tokens:],
         
     | 
| 241 | 
         
             
                                            1,
         
     | 
| 242 | 
         
             
                                            temperature=temperature,
         
     | 
| 243 | 
         
             
                                            return_prime=False,
         
     | 
| 244 | 
         
             
                                            verbose=False)
         
     | 
| 245 | 
         
            +
                    
         
     | 
| 246 | 
         
             
                      o = out.tolist()[0][0]
         
     | 
| 247 | 
         
            +
                    
         
     | 
| 248 | 
         
            +
                      if 256 <= o < 384:
         
     | 
| 249 | 
         
            +
                        ncount += 1
         
     | 
| 250 | 
         
            +
                    
         
     | 
| 251 | 
         
            +
                      if o < 384:
         
     | 
| 
         | 
|
| 
         | 
|
| 252 | 
         
             
                        x = torch.cat((x, out), 1)
         
     | 
| 253 | 
         
            +
                    
         
     | 
| 254 | 
         
            +
                    outy =  x.tolist()[0][len(chords):]
         
     | 
| 255 | 
         
            +
                      
         
     | 
| 256 | 
         
            +
                    output.extend(outy)
         
     | 
| 257 | 
         | 
| 258 | 
         
            +
                    idx += 1
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 259 | 
         | 
| 260 | 
         
             
                print('=' * 70)
         
     | 
| 261 | 
         
             
                print('Done!')
         
     | 
| 
         | 
|
| 265 | 
         
             
                print('Rendering results...')
         
     | 
| 266 | 
         | 
| 267 | 
         
             
                print('=' * 70)
         
     | 
| 268 | 
         
            +
                print('Sample INTs', output[:12])
         
     | 
| 269 | 
         
             
                print('=' * 70)
         
     | 
| 270 | 
         | 
| 271 | 
         
            +
                out1 = output
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 272 | 
         | 
| 273 | 
         
             
                if len(out1) != 0:
         
     | 
| 274 | 
         |