Crystalcareai commited on
Commit
9530965
·
verified ·
1 Parent(s): f0d7787

Update modeling_quiet.py

Browse files
Files changed (1) hide show
  1. modeling_quiet.py +17 -20
modeling_quiet.py CHANGED
@@ -1110,34 +1110,31 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
1110
  # Apply the language model head to get the final logits
1111
  logits = self.lm_head(mixed_hidden_states)
1112
  return logits
1113
-
1114
- @torch.no_grad()
1115
- def generate(
1116
- self,
1117
- input_ids: torch.LongTensor = torch.LongTensor(),
1118
- attention_mask: Optional[torch.Tensor] = None,
1119
- max_new_tokens: Optional[int] = None,
1120
- temperature: float = 1.1,
1121
- **kwargs,
1122
- ):
1123
  if isinstance(input_ids, str):
1124
  input_ids = self.tokenizer(input_ids, return_tensors="pt").input_ids
1125
 
1126
  if attention_mask is None:
1127
- # Create a default attention mask if not provided
1128
  attention_mask = torch.ones_like(input_ids)
1129
 
1130
  from .generate import generate
1131
  generated_token_ids, generated_text = generate(self, input_ids, attention_mask=attention_mask, max_new_tokens=max_new_tokens, temperature=temperature, **kwargs)
1132
-
1133
- # Convert the generated token IDs to a tensor
1134
- generated_token_ids = torch.tensor(generated_token_ids)
1135
-
1136
- # Return the generated text if it's a string, otherwise return the token IDs
1137
- if isinstance(generated_text, str):
1138
- return generated_text
1139
- else:
1140
- return generated_token_ids
 
 
 
 
 
 
1141
 
1142
  @add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
1143
  @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
 
1110
  # Apply the language model head to get the final logits
1111
  logits = self.lm_head(mixed_hidden_states)
1112
  return logits
1113
+
1114
+ def generate_with_callback(self, input_ids: torch.LongTensor = torch.LongTensor(), attention_mask: Optional[torch.Tensor] = None, max_new_tokens: Optional[int] = None, temperature: float = 1.1, callback=None, **kwargs):
 
 
 
 
 
 
 
 
1115
  if isinstance(input_ids, str):
1116
  input_ids = self.tokenizer(input_ids, return_tensors="pt").input_ids
1117
 
1118
  if attention_mask is None:
 
1119
  attention_mask = torch.ones_like(input_ids)
1120
 
1121
  from .generate import generate
1122
  generated_token_ids, generated_text = generate(self, input_ids, attention_mask=attention_mask, max_new_tokens=max_new_tokens, temperature=temperature, **kwargs)
1123
+
1124
+ if callback is not None:
1125
+ callback(generated_text)
1126
+
1127
+ return generated_text
1128
+
1129
+ @torch.no_grad()
1130
+ def generate(self, input_ids: torch.LongTensor = torch.LongTensor(), attention_mask: Optional[torch.Tensor] = None, max_new_tokens: Optional[int] = None, temperature: float = 1.1, **kwargs):
1131
+ return self.generate_with_callback(input_ids, attention_mask, max_new_tokens, temperature, callback=None, **kwargs)
1132
+
1133
+ def generate_with_streaming(self, input_ids: torch.LongTensor = torch.LongTensor(), attention_mask: Optional[torch.Tensor] = None, max_new_tokens: Optional[int] = None, temperature: float = 1.1, **kwargs):
1134
+ def callback(generated_text):
1135
+ yield generated_text
1136
+
1137
+ return self.generate_with_callback(input_ids, attention_mask, max_new_tokens, temperature, callback=callback, **kwargs)
1138
 
1139
  @add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
1140
  @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)