Update modeling_quiet.py
modeling_quiet.py CHANGED (+56 -46)
@@ -1423,61 +1423,71 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
         logits = self.lm_head(mixed_hidden_states)
         return logits
 
-    def generate(
+
+    def custom_generate(model, input_ids, attention_mask, max_length, streamer=None, **kwargs):
+        # Set up some variables
+        batch_size, seq_len = input_ids.shape
+        max_length = max_length if max_length is not None else model.config.max_length
+        max_new_tokens = max_length - seq_len
+        temperature = kwargs.get("temperature", 1.0)
+
+        with torch.no_grad():
+            for cur_token_idx in range(max_new_tokens):
+                # Run a forward pass to get the logits for the next token
+                outputs = model(
+                    input_ids=input_ids,
+                    attention_mask=attention_mask,
+                    use_cache=True,
+                )
+
+                logits = outputs.logits[:, -1, :]
+
+                # Sample the next token from the logits
+                next_token_logits = logits / temperature
+                next_token_id = torch.multinomial(torch.nn.functional.softmax(next_token_logits, dim=-1), num_samples=1)
+
+                # Append the new token to the input sequence
+                input_ids = torch.cat([input_ids, next_token_id], dim=-1)
+                attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1), dtype=attention_mask.dtype, device=attention_mask.device)], dim=-1)
+
+                # Stream the new token if a streamer is provided
+                if streamer is not None:
+                    streamer.put(next_token_id)
+
+                # Check if the end token is generated for all sequences in the batch
+                if next_token_id.eq(model.config.eos_token_id).all():
+                    break
+
+        return input_ids
+
+
+    # Add this to QuietForCausalLM forward method to support custom generate
+
+    @torch.no_grad()
+    def generate(
         self,
         input_ids,
         attention_mask=None,
-        position_ids=None,
-        past_key_values=None,
-        inputs_embeds=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-        max_new_tokens=512,
-        temperature=1.1,
+        max_length=None,
         streamer=None,
         **kwargs,
     ):
+        # Prepare inputs
         batch_size, seq_len = input_ids.shape
-
-
-        assert position_ids is None, "position_ids not supported yet"
+        if attention_mask is None:
+            attention_mask = torch.ones_like(input_ids)
 
-        #
-
-
-
-
-
-
-
-
-            outputs = self.model(**model_inputs)
-            next_token_logits = self.infer(
-                input_ids=input_ids,
-                attention_mask=attention_mask,
-                position_ids=position_ids,
-                past_key_values=outputs.past_key_values,
-                inputs_embeds=inputs_embeds,
-                use_cache=use_cache,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-                return_dict=return_dict,
-            )
-
-            next_token_logits = next_token_logits[:, -1, :]
-            next_tokens = torch.argmax(next_token_logits, dim=-1)
-
-            input_ids = torch.cat([input_ids, next_tokens.unsqueeze(-1)], dim=-1)
-
-            if streamer is not None:
-                streamer.put(next_tokens)
+        # Call the custom generate function
+        output_ids = custom_generate(
+            self,
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            max_length=max_length,
+            streamer=streamer,
+            **kwargs,
+        )
 
-
-            break
-
-        return input_ids
+        return output_ids
 
     @add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)