nguyenbh Eraa commited on
Commit
4943451
·
verified ·
1 Parent(s): 6cf9696

Support this model device-independent. (#52)

Browse files

- Support this model device-independent. (61ddce6553b8beffc5a859ba31726aab1f9b6979)


Co-authored-by: yejinglai <[email protected]>

Files changed (1) hide show
  1. speech_conformer_encoder.py +4 -5
speech_conformer_encoder.py CHANGED
@@ -2477,9 +2477,8 @@ class TransformerEncoderBase(abc.ABC, nn.Module):
2477
  seq_len, batch_size, self.chunk_size, self.left_chunk
2478
  )
2479
 
2480
- if xs_pad.is_cuda:
2481
- enc_streaming_mask = enc_streaming_mask.cuda()
2482
- xs_pad = xs_pad.cuda()
2483
 
2484
  input_tensor = xs_pad
2485
  input_tensor, masks = self._forward_embeddings_core(input_tensor, masks)
@@ -2496,8 +2495,8 @@ class TransformerEncoderBase(abc.ABC, nn.Module):
2496
  enc_streaming_mask_nc = self._streaming_mask(
2497
  seq_len, batch_size, chunk_size_nc, left_chunk_nc
2498
  )
2499
- if xs_pad.is_cuda:
2500
- enc_streaming_mask_nc = enc_streaming_mask_nc.cuda()
2501
  if masks is not None:
2502
  hs_mask_nc = masks & enc_streaming_mask_nc
2503
  else:
 
2477
  seq_len, batch_size, self.chunk_size, self.left_chunk
2478
  )
2479
 
2480
+ if xs_pad.device != "cpu":
2481
+ enc_streaming_mask = enc_streaming_mask.to(xs_pad.device)
 
2482
 
2483
  input_tensor = xs_pad
2484
  input_tensor, masks = self._forward_embeddings_core(input_tensor, masks)
 
2495
  enc_streaming_mask_nc = self._streaming_mask(
2496
  seq_len, batch_size, chunk_size_nc, left_chunk_nc
2497
  )
2498
+ if xs_pad.device != "cpu":
2499
+ enc_streaming_mask_nc = enc_streaming_mask_nc.to(xs_pad.device)
2500
  if masks is not None:
2501
  hs_mask_nc = masks & enc_streaming_mask_nc
2502
  else: