Update modeling_quiet.py
Browse files- modeling_quiet.py +2 -3
modeling_quiet.py
CHANGED
|
@@ -1430,7 +1430,7 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
|
|
| 1430 |
input_ids: torch.LongTensor,
|
| 1431 |
attention_mask: Optional[torch.Tensor] = None,
|
| 1432 |
max_new_tokens: Optional[int] = None,
|
| 1433 |
-
temperature: float =
|
| 1434 |
**kwargs,
|
| 1435 |
):
|
| 1436 |
if attention_mask is None:
|
|
@@ -1438,8 +1438,7 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
|
|
| 1438 |
attention_mask = torch.ones_like(input_ids)
|
| 1439 |
|
| 1440 |
from .generate import generate
|
| 1441 |
-
|
| 1442 |
-
return generated_token_ids, attention_mask
|
| 1443 |
|
| 1444 |
@add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
|
| 1445 |
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
|
|
|
| 1430 |
input_ids: torch.LongTensor,
|
| 1431 |
attention_mask: Optional[torch.Tensor] = None,
|
| 1432 |
max_new_tokens: Optional[int] = None,
|
| 1433 |
+
temperature: float = 1.1,
|
| 1434 |
**kwargs,
|
| 1435 |
):
|
| 1436 |
if attention_mask is None:
|
|
|
|
| 1438 |
attention_mask = torch.ones_like(input_ids)
|
| 1439 |
|
| 1440 |
from .generate import generate
|
| 1441 |
+
return generate(self, input_ids=input_ids, attention_mask=attention_mask, max_new_tokens=max_new_tokens, temperature=temperature, **kwargs)
|
|
|
|
| 1442 |
|
| 1443 |
@add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
|
| 1444 |
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|