Crystalcareai
commited on
Update modeling_quiet.py
Browse files- modeling_quiet.py +17 -20
modeling_quiet.py
CHANGED
@@ -1110,34 +1110,31 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
|
|
1110 |
# Apply the language model head to get the final logits
|
1111 |
logits = self.lm_head(mixed_hidden_states)
|
1112 |
return logits
|
1113 |
-
|
1114 |
-
|
1115 |
-
def generate(
|
1116 |
-
self,
|
1117 |
-
input_ids: torch.LongTensor = torch.LongTensor(),
|
1118 |
-
attention_mask: Optional[torch.Tensor] = None,
|
1119 |
-
max_new_tokens: Optional[int] = None,
|
1120 |
-
temperature: float = 1.1,
|
1121 |
-
**kwargs,
|
1122 |
-
):
|
1123 |
if isinstance(input_ids, str):
|
1124 |
input_ids = self.tokenizer(input_ids, return_tensors="pt").input_ids
|
1125 |
|
1126 |
if attention_mask is None:
|
1127 |
-
# Create a default attention mask if not provided
|
1128 |
attention_mask = torch.ones_like(input_ids)
|
1129 |
|
1130 |
from .generate import generate
|
1131 |
generated_token_ids, generated_text = generate(self, input_ids, attention_mask=attention_mask, max_new_tokens=max_new_tokens, temperature=temperature, **kwargs)
|
1132 |
-
|
1133 |
-
|
1134 |
-
|
1135 |
-
|
1136 |
-
|
1137 |
-
|
1138 |
-
|
1139 |
-
|
1140 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
1141 |
|
1142 |
@add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
|
1143 |
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
|
|
1110 |
# Apply the language model head to get the final logits
|
1111 |
logits = self.lm_head(mixed_hidden_states)
|
1112 |
return logits
|
1113 |
+
|
1114 |
+
def generate_with_callback(self, input_ids: torch.LongTensor = torch.LongTensor(), attention_mask: Optional[torch.Tensor] = None, max_new_tokens: Optional[int] = None, temperature: float = 1.1, callback=None, **kwargs):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1115 |
if isinstance(input_ids, str):
|
1116 |
input_ids = self.tokenizer(input_ids, return_tensors="pt").input_ids
|
1117 |
|
1118 |
if attention_mask is None:
|
|
|
1119 |
attention_mask = torch.ones_like(input_ids)
|
1120 |
|
1121 |
from .generate import generate
|
1122 |
generated_token_ids, generated_text = generate(self, input_ids, attention_mask=attention_mask, max_new_tokens=max_new_tokens, temperature=temperature, **kwargs)
|
1123 |
+
|
1124 |
+
if callback is not None:
|
1125 |
+
callback(generated_text)
|
1126 |
+
|
1127 |
+
return generated_text
|
1128 |
+
|
1129 |
+
@torch.no_grad()
|
1130 |
+
def generate(self, input_ids: torch.LongTensor = torch.LongTensor(), attention_mask: Optional[torch.Tensor] = None, max_new_tokens: Optional[int] = None, temperature: float = 1.1, **kwargs):
|
1131 |
+
return self.generate_with_callback(input_ids, attention_mask, max_new_tokens, temperature, callback=None, **kwargs)
|
1132 |
+
|
1133 |
+
def generate_with_streaming(self, input_ids: torch.LongTensor = torch.LongTensor(), attention_mask: Optional[torch.Tensor] = None, max_new_tokens: Optional[int] = None, temperature: float = 1.1, **kwargs):
|
1134 |
+
def callback(generated_text):
|
1135 |
+
yield generated_text
|
1136 |
+
|
1137 |
+
return self.generate_with_callback(input_ids, attention_mask, max_new_tokens, temperature, callback=callback, **kwargs)
|
1138 |
|
1139 |
@add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
|
1140 |
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|