Crystalcareai committed
Commit 66b0a6e (verified) · Parent: 833b955

Update modeling_quiet.py

Files changed (1):
  modeling_quiet.py  +2 -11
modeling_quiet.py CHANGED

@@ -1072,16 +1072,7 @@ class QuietModel(QuietPreTrainedModel):
         )
 
         if attention_mask is None:
-            if input_ids is not None:
-                attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
-            elif inputs_embeds is not None:
-                attention_mask = torch.ones(
-                    (batch_size, seq_len),
-                    dtype=torch.bool,
-                    device=inputs_embeds.device
-                )
-            else:
-                raise ValueError("Either input_ids or inputs_embeds should be provided.")
+            attention_mask = torch.ones((batch_size, seq_len), dtype=torch.bool, device=inputs_embeds.device)
 
         if attention_mask.dim() == 2:
             attention_mask = attention_mask.view(batch_size, 1, 1, seq_len)
@@ -1091,7 +1082,7 @@ class QuietModel(QuietPreTrainedModel):
         elif attention_mask.dim() != 4:
             raise ValueError(f"Attention mask should be of shape (batch_size, 1, seq_len, seq_len) or (batch_size, 1, 1, seq_len), but got {attention_mask.shape}")
 
-        attention_mask = attention_mask.to(dtype=torch.bool, device=input_ids.device)
+        attention_mask = attention_mask.to(dtype=torch.bool)
 
         hidden_states = inputs_embeds
 
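
A minimal, self-contained sketch (not part of the commit) of the simplified default-mask path introduced above. It assumes batch_size and seq_len are derived from inputs_embeds, as the surrounding model code implies; the tensor shapes below are placeholders chosen for illustration.

import torch

# Stand-in for the embeddings the model would normally compute.
batch_size, seq_len, hidden_size = 2, 5, 8
inputs_embeds = torch.randn(batch_size, seq_len, hidden_size)

# With no attention_mask supplied, build an all-ones boolean mask on the
# embeddings' device (presumably because input_ids can be None when
# inputs_embeds is passed directly).
attention_mask = torch.ones((batch_size, seq_len), dtype=torch.bool, device=inputs_embeds.device)

# Reshape the 2-D mask so it broadcasts over attention heads and query positions.
if attention_mask.dim() == 2:
    attention_mask = attention_mask.view(batch_size, 1, 1, seq_len)

# Cast without forcing a device, mirroring the second hunk of the diff.
attention_mask = attention_mask.to(dtype=torch.bool)

print(attention_mask.shape)  # torch.Size([2, 1, 1, 5])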