AbstractPhil commited on
Commit
9228a8c
·
verified ·
1 Parent(s): 37b55c9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -1
app.py CHANGED
@@ -133,7 +133,8 @@ def encode_and_trace(text, selected_roles):
133
  masked_input = ids.where(mask_flags, MASK_ID)
134
 
135
  encoded_m = encode(masked_input, attn)
136
- logits = mlm_head(encoded_m).logits[0]
 
137
  preds = logits.argmax(-1)
138
 
139
  masked_positions = (~mask_flags[0]).nonzero(as_tuple=False).squeeze(-1)
 
133
  masked_input = ids.where(mask_flags, MASK_ID)
134
 
135
  encoded_m = encode(masked_input, attn)
136
+ logits = mlm_head(encoded_m)[0] # shape: (S, V)
137
+
138
  preds = logits.argmax(-1)
139
 
140
  masked_positions = (~mask_flags[0]).nonzero(as_tuple=False).squeeze(-1)