Crystalcareai committed
Commit fea5225 · verified · 1 parent: 9446754

Update modeling_quiet.py

Files changed (1):
  modeling_quiet.py  +53 -52
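Note: the change moves QuietGenerationMixin from below QuietForCausalLM (old lines 2231-2282) to just above it (new lines 1172-1224). Since QuietForCausalLM lists the mixin among its bases, Python has to resolve that name at class-definition time, so the mixin must appear earlier in the module. A minimal, self-contained sketch of that ordering rule (illustrative names only, not code from this repository):

# Illustrative only: base classes are looked up when the `class` statement executes.
class Mixin:
    def generate(self):
        return "ok"

class Model(Mixin):        # works: Mixin is already defined
    pass

print(Model().generate())  # -> ok

# Referencing a base that is only defined later in the module fails at import time:
try:
    class BrokenModel(LateMixin):   # NameError: name 'LateMixin' is not defined
        pass
except NameError as err:
    print(err)

class LateMixin:
    pass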
modeling_quiet.py CHANGED
@@ -1169,6 +1169,59 @@ def nonzero_mean(x, axis=None):
 def loss_mean(x):
     return x.sum() / (x != 0).sum()
 
+class QuietGenerationMixin(GenerationMixin):
+    def generate(self, input_ids, attention_mask=None, **generate_kwargs):
+        if attention_mask is None:
+            attention_mask = torch.ones_like(input_ids)
+
+        max_length = generate_kwargs.get("max_length", 20)
+        temp = generate_kwargs.get("temperature", 1.0)
+
+        finished_generating = torch.zeros(len(input_ids), dtype=torch.bool, device=input_ids.device)
+
+        for cur_token_idx in range(max_length):
+            # Sample the next token
+            new_ids = self(
+                input_ids[~finished_generating],
+                attention_mask=attention_mask[~finished_generating]
+            )['logits']
+
+            # Mask out the start and end thought tokens so we don't accidentally sample them
+            new_ids[:, :, self.tokenizer.vocab_size:] = -float("inf")
+
+            for list_idx, answer_idx in enumerate((~finished_generating).nonzero(as_tuple=True)[0]):
+                # Find the index of the last token that is not padding
+                base_answer_ids = input_ids[answer_idx]
+                new_answer_ids = new_ids[list_idx]
+                last_token_idx = (base_answer_ids != self.tokenizer.pad_token_id).nonzero(as_tuple=True)[0].max()
+
+                new_ids_sampled = torch.multinomial(
+                    torch.nn.functional.softmax(new_answer_ids[last_token_idx] / temp, dim=-1), 1)
+
+                # Assign the new id to the last token
+                if last_token_idx + 1 >= len(base_answer_ids):
+                    # Add padding everywhere
+                    new_padding = torch.full((len(input_ids), 1), self.tokenizer.pad_token_id, dtype=torch.long,
+                                             device=input_ids.device)
+                    input_ids = torch.cat([input_ids, new_padding], dim=-1)
+                    attention_mask = torch.cat([attention_mask, torch.zeros_like(new_padding)], dim=-1)
+
+                attention_mask[answer_idx, last_token_idx + 1] = 1
+                input_ids[answer_idx, last_token_idx + 1] = new_ids_sampled
+
+                if new_ids_sampled == self.tokenizer.eos_token_id or new_ids_sampled == self.tokenizer.bos_token_id or new_ids_sampled == self.tokenizer.pad_token_id:
+                    finished_generating[answer_idx] = 1
+
+            if finished_generating.all():
+                break
+
+        streamer = generate_kwargs.get("streamer")
+        if streamer is not None:
+            streamer.put(input_ids)
+            streamer.end()
+
+        return input_ids
+
 class QuietForCausalLM(QuietPreTrainedModel, QuietGenerationMixin):
     _tied_weights_keys = ["lm_head.weight"]
 
@@ -2228,58 +2281,6 @@ class QuietForCausalLM(QuietPreTrainedModel, QuietGenerationMixin):
     QUIET_START_DOCSTRING,
 )
 
-class QuietGenerationMixin(GenerationMixin):
-    def generate(self, input_ids, attention_mask=None, **generate_kwargs):
-        if attention_mask is None:
-            attention_mask = torch.ones_like(input_ids)
-
-        max_length = generate_kwargs.get("max_length", 20)
-        temp = generate_kwargs.get("temperature", 1.0)
-
-        finished_generating = torch.zeros(len(input_ids), dtype=torch.bool, device=input_ids.device)
-
-        for cur_token_idx in range(max_length):
-            # Sample the next token
-            new_ids = self(
-                input_ids[~finished_generating],
-                attention_mask=attention_mask[~finished_generating]
-            )['logits']
-
-            # Mask out the start and end thought tokens so we don't accidentally sample them
-            new_ids[:, :, self.tokenizer.vocab_size:] = -float("inf")
-
-            for list_idx, answer_idx in enumerate((~finished_generating).nonzero(as_tuple=True)[0]):
-                # Find the index of the last token that is not padding
-                base_answer_ids = input_ids[answer_idx]
-                new_answer_ids = new_ids[list_idx]
-                last_token_idx = (base_answer_ids != self.tokenizer.pad_token_id).nonzero(as_tuple=True)[0].max()
-
-                new_ids_sampled = torch.multinomial(
-                    torch.nn.functional.softmax(new_answer_ids[last_token_idx] / temp, dim=-1), 1)
-
-                # Assign the new id to the last token
-                if last_token_idx + 1 >= len(base_answer_ids):
-                    # Add padding everywhere
-                    new_padding = torch.full((len(input_ids), 1), self.tokenizer.pad_token_id, dtype=torch.long,
-                                             device=input_ids.device)
-                    input_ids = torch.cat([input_ids, new_padding], dim=-1)
-                    attention_mask = torch.cat([attention_mask, torch.zeros_like(new_padding)], dim=-1)
-
-                attention_mask[answer_idx, last_token_idx + 1] = 1
-                input_ids[answer_idx, last_token_idx + 1] = new_ids_sampled
-
-                if new_ids_sampled == self.tokenizer.eos_token_id or new_ids_sampled == self.tokenizer.bos_token_id or new_ids_sampled == self.tokenizer.pad_token_id:
-                    finished_generating[answer_idx] = 1
-
-            if finished_generating.all():
-                break
-
-        streamer = generate_kwargs.get("streamer")
-        if streamer is not None:
-            streamer.put(input_ids)
-            streamer.end()
-
-        return input_ids
 # Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Quiet, LLAMA->QUIET
 class QuietForSequenceClassification(QuietPreTrainedModel):
     def __init__(self, config):
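For reference, a hedged usage sketch of the relocated generate loop (not part of this commit): the checkpoint id, the prompt, and the step of attaching a tokenizer to the model are assumptions; the loop itself reads max_length, temperature, and an optional streamer from its keyword arguments and relies on self.tokenizer for the pad/eos/bos ids and for the vocab_size cutoff used to mask thought tokens. Whether model.generate dispatches to this mixin or to the stock transformers implementation depends on the method resolution order of QuietForCausalLM elsewhere in the file.

# Hedged sketch, not from the repository. The checkpoint id below is
# illustrative; any Quiet-STaR checkpoint that ships this modeling file
# via trust_remote_code would be called the same way.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "Crystalcareai/quiet-star-checkpoint"   # assumption, not a verified repo id
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)

# The generate() loop above reads self.tokenizer (pad/eos/bos ids, vocab_size),
# so the tokenizer has to be attached to the model before calling it.
model.tokenizer = tokenizer

inputs = tokenizer(["2 + 2 ="], return_tensors="pt")
with torch.no_grad():
    out_ids = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_length=32,      # number of sampling steps taken by the loop above
        temperature=0.7,    # softmax temperature applied before torch.multinomial
    )
print(tokenizer.decode(out_ids[0], skip_special_tokens=True))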