Crystalcareai committed · verified
Commit bca4d85 · 1 Parent(s): a3dfbb3

Update modeling_quiet.py

Files changed (1): modeling_quiet.py +56 -46
modeling_quiet.py CHANGED
@@ -1423,61 +1423,71 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
         logits = self.lm_head(mixed_hidden_states)
         return logits
 
-    def custom_generate(
+
+    def custom_generate(model, input_ids, attention_mask, max_length, streamer=None, **kwargs):
+        # Set up some variables
+        batch_size, seq_len = input_ids.shape
+        max_length = max_length if max_length is not None else model.config.max_length
+        max_new_tokens = max_length - seq_len
+        temperature = kwargs.get("temperature", 1.0)
+
+        with torch.no_grad():
+            for cur_token_idx in range(max_new_tokens):
+                # Run a forward pass to get the logits for the next token
+                outputs = model(
+                    input_ids=input_ids,
+                    attention_mask=attention_mask,
+                    use_cache=True,
+                )
+
+                logits = outputs.logits[:, -1, :]
+
+                # Sample the next token from the logits
+                next_token_logits = logits / temperature
+                next_token_id = torch.multinomial(torch.nn.functional.softmax(next_token_logits, dim=-1), num_samples=1)
+
+                # Append the new token to the input sequence
+                input_ids = torch.cat([input_ids, next_token_id], dim=-1)
+                attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1), dtype=attention_mask.dtype, device=attention_mask.device)], dim=-1)
+
+                # Stream the new token if a streamer is provided
+                if streamer is not None:
+                    streamer.put(next_token_id)
+
+                # Check if the end token is generated for all sequences in the batch
+                if next_token_id.eq(model.config.eos_token_id).all():
+                    break
+
+        return input_ids
+
+
+    # Add this to QuietForCausalLM forward method to support custom generate
+
+    @torch.no_grad()
+    def generate(
         self,
         input_ids,
         attention_mask=None,
-        position_ids=None,
-        past_key_values=None,
-        inputs_embeds=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-        max_new_tokens=512,
-        temperature=1.1,
+        max_length=None,
         streamer=None,
         **kwargs,
     ):
+        # Prepare inputs
         batch_size, seq_len = input_ids.shape
-
-        assert past_key_values is None, "past_key_values not supported yet"
-        assert position_ids is None, "position_ids not supported yet"
+        if attention_mask is None:
+            attention_mask = torch.ones_like(input_ids)
 
-        # Generate up to max_new_tokens
-        for _ in range(max_new_tokens):
-            model_inputs = self.prepare_inputs_for_generation(
-                input_ids,
-                attention_mask=attention_mask,
-                inputs_embeds=inputs_embeds,
-                use_cache=use_cache,
-            )
-
-            outputs = self.model(**model_inputs)
-            next_token_logits = self.infer(
-                input_ids=input_ids,
-                attention_mask=attention_mask,
-                position_ids=position_ids,
-                past_key_values=outputs.past_key_values,
-                inputs_embeds=inputs_embeds,
-                use_cache=use_cache,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-                return_dict=return_dict,
-            )
-
-            next_token_logits = next_token_logits[:, -1, :]
-            next_tokens = torch.argmax(next_token_logits, dim=-1)
-
-            input_ids = torch.cat([input_ids, next_tokens.unsqueeze(-1)], dim=-1)
-
-            if streamer is not None:
-                streamer.put(next_tokens)
+        # Call the custom generate function
+        output_ids = custom_generate(
+            self,
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            max_length=max_length,
+            streamer=streamer,
+            **kwargs,
+        )
 
-            if next_tokens == self.tokenizer.convert_tokens_to_ids("<|/assistant|>"):
-                break
-
-        return input_ids
+        return output_ids
 
     @add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
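
For context, here is a minimal usage sketch of the new entry point. It is not part of the commit: the checkpoint id below is a placeholder, and it assumes the model is loaded with trust_remote_code=True so that this modeling_quiet.py supplies the overriding generate method. Note that temperature reaches custom_generate through **kwargs.

# Usage sketch (hypothetical checkpoint id), not part of the commit itself.
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer

model_id = "Crystalcareai/quiet-star"  # placeholder; substitute the real repo id
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)

inputs = tokenizer("The capital of France is", return_tensors="pt")
streamer = TextStreamer(tokenizer)  # the loop puts only newly sampled tokens

# max_length bounds prompt + generated tokens; temperature is read inside
# custom_generate via kwargs.get("temperature", 1.0).
output_ids = model.generate(
    inputs.input_ids,
    attention_mask=inputs.attention_mask,
    max_length=64,
    streamer=streamer,
    temperature=0.9,
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))

Two behavioral notes on the rewritten loop: greedy argmax decoding is replaced by temperature sampling (softmax over logits / temperature, then torch.multinomial), and generation now stops when every sequence in the batch emits model.config.eos_token_id rather than on the hard-coded "<|/assistant|>" token. As committed, each step re-runs the model over the full sequence and never feeds outputs.past_key_values back in, so use_cache=True has no effect and per-step cost grows with sequence length.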