Crystalcareai committed (verified)
Commit fc8abcb · 1 Parent(s): fea5225

Update modeling_quiet.py

Files changed (1): modeling_quiet.py (+62 -171)
modeling_quiet.py CHANGED
@@ -1169,60 +1169,7 @@ def nonzero_mean(x, axis=None):
 def loss_mean(x):
     return x.sum() / (x != 0).sum()
 
-class QuietGenerationMixin(GenerationMixin):
-    def generate(self, input_ids, attention_mask=None, **generate_kwargs):
-        if attention_mask is None:
-            attention_mask = torch.ones_like(input_ids)
-
-        max_length = generate_kwargs.get("max_length", 20)
-        temp = generate_kwargs.get("temperature", 1.0)
-
-        finished_generating = torch.zeros(len(input_ids), dtype=torch.bool, device=input_ids.device)
-
-        for cur_token_idx in range(max_length):
-            # Sample the next token
-            new_ids = self(
-                input_ids[~finished_generating],
-                attention_mask=attention_mask[~finished_generating]
-            )['logits']
-
-            # Mask out the start and end thought tokens so we don't accidentally sample them
-            new_ids[:, :, self.tokenizer.vocab_size:] = -float("inf")
-
-            for list_idx, answer_idx in enumerate((~finished_generating).nonzero(as_tuple=True)[0]):
-                # Find the index of the last token that is not padding
-                base_answer_ids = input_ids[answer_idx]
-                new_answer_ids = new_ids[list_idx]
-                last_token_idx = (base_answer_ids != self.tokenizer.pad_token_id).nonzero(as_tuple=True)[0].max()
-
-                new_ids_sampled = torch.multinomial(
-                    torch.nn.functional.softmax(new_answer_ids[last_token_idx] / temp, dim=-1), 1)
-
-                # Assign the new id to the last token
-                if last_token_idx + 1 >= len(base_answer_ids):
-                    # Add padding everywhere
-                    new_padding = torch.full((len(input_ids), 1), self.tokenizer.pad_token_id, dtype=torch.long,
-                                             device=input_ids.device)
-                    input_ids = torch.cat([input_ids, new_padding], dim=-1)
-                    attention_mask = torch.cat([attention_mask, torch.zeros_like(new_padding)], dim=-1)
-
-                attention_mask[answer_idx, last_token_idx + 1] = 1
-                input_ids[answer_idx, last_token_idx + 1] = new_ids_sampled
-
-                if new_ids_sampled == self.tokenizer.eos_token_id or new_ids_sampled == self.tokenizer.bos_token_id or new_ids_sampled == self.tokenizer.pad_token_id:
-                    finished_generating[answer_idx] = 1
-
-            if finished_generating.all():
-                break
-
-        streamer = generate_kwargs.get("streamer")
-        if streamer is not None:
-            streamer.put(input_ids)
-            streamer.end()
-
-        return input_ids
-
-class QuietForCausalLM(QuietPreTrainedModel, QuietGenerationMixin):
+class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["lm_head.weight"]
 
     def __init__(self, config):
@@ -1377,91 +1324,46 @@ class QuietForCausalLM(QuietPreTrainedModel, QuietGenerationMixin):
         elif isinstance(module, nn.Embedding):
             nn.init.xavier_uniform_(module.weight)
 
-    # @torch.no_grad()
-    # def generate(
-    #     self,
-    #     input_ids: torch.LongTensor,
-    #     attention_mask: Optional[torch.LongTensor] = None,
-    #     **generate_kwargs,
-    # ) -> torch.LongTensor:
-    #     n_ahead = 8
-    #     n_ahead_talk = 4
-    #     merged_talk_heads = True
-
-    #     if attention_mask is None:
-    #         attention_mask = torch.ones_like(input_ids)
-
-    #     generate_kwargs.update({
-    #         "max_thoughts": n_ahead + n_ahead_talk + 1,
-    #         "merged_talk_heads": merged_talk_heads,
-    #         "merged_lm_and_talk_heads": False,
-    #         "merged_lm_and_think_heads": True,
-    #         "use_concat_talk_head": True,
-    #         "use_shallow_think": True,
-    #         "use_shallow_talk": False,
-    #         "use_complex_think_head": False,
-    #         "use_complex_talk_head": True,
-    #         "use_weighted_talk_head": True,
-    #     })
+    @torch.no_grad()
+    def generate(self, input_ids, attention_mask=None, streamer=None, **kwargs):
+        if attention_mask is None:
+            attention_mask = torch.ones_like(input_ids)
 
-    #     # Validate stopping criteria
-    #     stopping_criteria = generate_kwargs.pop("stopping_criteria", None)
-    #     if stopping_criteria is not None:
-    #         stopping_criteria = validate_stopping_criteria(
-    #             stopping_criteria,
-    #             self.config,
-    #         )
-    #         stopping_criteria = StoppingCriteriaList(stopping_criteria)
-    #     else:
-    #         stopping_criteria = StoppingCriteriaList()
-
-    #     streamer = generate_kwargs.pop("streamer", None)
-    #     temp = generate_kwargs.pop("temperature", 1.0)
-    #     max_length = generate_kwargs.pop("max_length", 20)
-
-    #     finished_generating = torch.zeros(len(input_ids), dtype=torch.bool, device=input_ids.device)
-
-    #     for cur_token_idx in range(max_length):
-    #         # Sample the next token
-    #         new_ids = self(
-    #             input_ids[~finished_generating],
-    #             attention_mask=attention_mask[~finished_generating]
-    #         )['logits']
-
-    #         # Mask out the start and end thought tokens so we don't accidentally sample them
-    #         new_ids[:, :, self.tokenizer.vocab_size:] = -float("inf")
-
-    #         for list_idx, answer_idx in enumerate((~finished_generating).nonzero(as_tuple=True)[0]):
-    #             # Find the index of the last token that is not padding
-    #             base_answer_ids = input_ids[answer_idx]
-    #             new_answer_ids = new_ids[list_idx]
-    #             last_token_idx = (base_answer_ids != self.tokenizer.pad_token_id).nonzero(as_tuple=True)[0].max()
-
-    #             new_ids_sampled = torch.multinomial(
-    #                 torch.nn.functional.softmax(new_answer_ids[last_token_idx] / temp, dim=-1), 1)
-
-    #             # Assign the new id to the last token
-    #             if last_token_idx + 1 >= len(base_answer_ids):
-    #                 # Add padding everywhere
-    #                 new_padding = torch.full((len(input_ids), 1), self.tokenizer.pad_token_id, dtype=torch.long,
-    #                                          device=input_ids.device)
-    #                 input_ids = torch.cat([input_ids, new_padding], dim=-1)
-    #                 attention_mask = torch.cat([attention_mask, torch.zeros_like(new_padding)], dim=-1)
-
-    #             attention_mask[answer_idx, last_token_idx + 1] = 1
-    #             input_ids[answer_idx, last_token_idx + 1] = new_ids_sampled
-
-    #             if new_ids_sampled == self.tokenizer.eos_token_id or new_ids_sampled == self.tokenizer.bos_token_id or new_ids_sampled == self.tokenizer.pad_token_id:
-    #                 finished_generating[answer_idx] = 1
-
-    #             if finished_generating.all():
-    #                 break
-
-    #         if streamer is not None:
-    #             streamer.put(input_ids)
-    #             streamer.end()
-
-    #     return input_ids
+        max_length = kwargs.get("max_length", 20)
+        temp = kwargs.get("temperature", 1.0)
+
+        with torch.no_grad():
+            finished_generating = torch.zeros(len(input_ids), dtype=torch.bool, device=input_ids.device)
+            for cur_token_idx in range(max_length):
+                # Sample the next token
+                new_ids = self(
+                    input_ids[~finished_generating],
+                    attention_mask=attention_mask[~finished_generating]
+                )['logits']
+                # Mask out the start and end thought tokens so we don't accidentally sample them
+                new_ids[:, :, self.tokenizer.vocab_size:] = -float("inf")
+                for list_idx, answer_idx in enumerate((~finished_generating).nonzero(as_tuple=True)[0]):
+                    # Find the index of the last token that is not padding
+                    base_answer_ids = input_ids[answer_idx]
+                    new_answer_ids = new_ids[list_idx]
+                    last_token_idx = (base_answer_ids != self.tokenizer.pad_token_id).nonzero(as_tuple=True)[0].max()
+
+                    new_ids_sampled = torch.multinomial(
+                        torch.nn.functional.softmax(new_answer_ids[last_token_idx] / temp, dim=-1), 1)
+                    # Assign the new id to the last token
+                    if last_token_idx + 1 >= len(base_answer_ids):
+                        # Add padding everywhere
+                        new_padding = torch.full((len(input_ids), 1), self.tokenizer.pad_token_id, dtype=torch.long,
+                                                 device=input_ids.device)
+                        input_ids = torch.cat([input_ids, new_padding], dim=-1)
+                        attention_mask = torch.cat([attention_mask, torch.zeros_like(new_padding)], dim=-1)
+                    attention_mask[answer_idx, last_token_idx + 1] = 1
+                    input_ids[answer_idx, last_token_idx + 1] = new_ids_sampled
+                    if new_ids_sampled == self.tokenizer.eos_token_id or new_ids_sampled == self.tokenizer.bos_token_id or new_ids_sampled == self.tokenizer.pad_token_id:
+                        finished_generating[answer_idx] = 1
+                if finished_generating.all():
+                    break
+        return input_ids, attention_mask
 
     @add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
@@ -1474,41 +1376,31 @@ class QuietForCausalLM(QuietPreTrainedModel, QuietGenerationMixin):
         inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
+        # output_router_logits: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-        **kwargs,
-    ):
-        n_ahead = 8
-        n_ahead_talk = 4
-        merged_talk_heads = True
-
-        kwargs.update({
-            "max_thoughts": n_ahead + n_ahead_talk + 1,
-            "merged_talk_heads": merged_talk_heads,
-            "merged_lm_and_talk_heads": False,
-            "merged_lm_and_think_heads": True,
-            "use_concat_talk_head": True,
-            "use_shallow_think": True,
-            "use_shallow_talk": False,
-            "use_complex_think_head": False,
-            "use_complex_talk_head": True,
-            "use_weighted_talk_head": True,
-        })
-
-        return super().forward(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_values=past_key_values,
-            inputs_embeds=inputs_embeds,
-            labels=labels,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-            **kwargs,
-        )
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
+        r"""
+        Args:
+            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+        Returns:
+        Example:
+        ```python
+        >>> from transformers import AutoTokenizer, QuietForCausalLM
+        >>> model = QuietForCausalLM.from_pretrained("quietai/Quiet-7B-v0.1")
+        >>> tokenizer = AutoTokenizer.from_pretrained("quietai/Quiet-7B-v0.1")
+        >>> prompt = "Hey, are you conscious? Can you talk to me?"
+        >>> inputs = tokenizer(prompt, return_tensors="pt")
+        >>> # Generate
+        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+        ```"""
+
         if not self.training:
            n_ahead_talk_to_restore = self.n_ahead_talk
            n_passes_to_restore = self.n_passes
@@ -2280,7 +2172,6 @@
     """,
     QUIET_START_DOCSTRING,
 )
-
 # Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Quiet, LLAMA->QUIET
 class QuietForSequenceClassification(QuietPreTrainedModel):
     def __init__(self, config):
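
With this commit, `generate` moves onto `QuietForCausalLM` itself and returns the padded `input_ids` together with the updated `attention_mask`, rather than token ids alone. A minimal usage sketch of the new signature, assuming the placeholder checkpoint name from the docstring example, that the model is loaded with `trust_remote_code=True`, and that a tokenizer is attached as `model.tokenizer` (the method reads `self.tokenizer` internally to mask thought tokens and find padding):

```python
# Sketch only: "quietai/Quiet-7B-v0.1" is the hypothetical checkpoint name used in the
# forward() docstring example, not a confirmed published model id.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "quietai/Quiet-7B-v0.1"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)
model.tokenizer = tokenizer  # generate() uses self.tokenizer for pad/eos ids and vocab size

inputs = tokenizer("Hey, are you conscious? Can you talk to me?", return_tensors="pt")

# Unlike the removed QuietGenerationMixin.generate, which returned only ids,
# the new method returns (input_ids, attention_mask), so unpack before decoding.
output_ids, output_mask = model.generate(
    inputs.input_ids,
    attention_mask=inputs.attention_mask,
    max_length=30,
    temperature=1.0,
)
print(tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0])
```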