Crystalcareai committed
Commit 1a7d227 · verified · 1 Parent(s): 9b89e6b

Update modeling_quiet.py

Files changed (1): modeling_quiet.py (+2, -3)
modeling_quiet.py CHANGED
@@ -1430,7 +1430,7 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
         input_ids: torch.LongTensor,
         attention_mask: Optional[torch.Tensor] = None,
         max_new_tokens: Optional[int] = None,
-        temperature: float = 0.9,
+        temperature: float = 1.1,
         **kwargs,
     ):
         if attention_mask is None:
@@ -1438,8 +1438,7 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
             attention_mask = torch.ones_like(input_ids)
 
         from .generate import generate
-        generated_token_ids, attention_mask = generate(self, input_ids=input_ids, attention_mask=attention_mask, max_new_tokens=max_new_tokens, temperature=temperature, **kwargs)
-        return generated_token_ids, attention_mask
+        return generate(self, input_ids=input_ids, attention_mask=attention_mask, max_new_tokens=max_new_tokens, temperature=temperature, **kwargs)
 
     @add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
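For context, a minimal usage sketch of the changed code path follows. This is an illustration, not the repository's documented usage: the hunk header only shows the enclosing QuietForCausalLM class, so treating the edited method as the model's generate entry point is an assumption, and the checkpoint id below is hypothetical. After this commit the call returns whatever the custom .generate helper returns, rather than an explicit (generated_token_ids, attention_mask) tuple, and the default temperature is 1.1 instead of 0.9.

# Minimal sketch under the assumptions above; the checkpoint id is hypothetical
# and must be a repo that ships this modeling_quiet.py as remote code.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "Crystalcareai/quiet-star"  # hypothetical checkpoint name
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id, trust_remote_code=True, torch_dtype=torch.bfloat16
)

inputs = tokenizer("The capital of France is", return_tensors="pt")

# Before 1a7d227: the method returned a (generated_token_ids, attention_mask) tuple.
# After 1a7d227: it returns whatever the custom .generate helper returns, and the
# sampling temperature defaults to 1.1 unless overridden by the caller.
output = model.generate(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    max_new_tokens=32,
    temperature=0.9,  # pass explicitly to keep the pre-commit default behavior
)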