Crystalcareai committed on
Commit 0c70c53 · verified · 1 Parent(s): ed908ff

Update modeling_quiet.py

Files changed (1)
  1. modeling_quiet.py +104 -30
modeling_quiet.py CHANGED
@@ -1425,6 +1425,54 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
        logits = self.lm_head(mixed_hidden_states)
        return logits

+    @torch.no_grad()
+    def generate(
+        self,
+        inputs: Optional[torch.Tensor] = None,
+        max_length: Optional[int] = None,
+        min_length: Optional[int] = None,
+        do_sample: Optional[bool] = None,
+        early_stopping: Optional[bool] = None,
+        num_beams: Optional[int] = None,
+        temperature: Optional[float] = None,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+        repetition_penalty: Optional[float] = None,
+        pad_token_id: Optional[int] = None,
+        bos_token_id: Optional[int] = None,
+        eos_token_id: Optional[int] = None,
+        length_penalty: Optional[float] = None,
+        no_repeat_ngram_size: Optional[int] = None,
+        bad_words_ids: Optional[Iterable[int]] = None,
+        num_return_sequences: Optional[int] = None,
+        decoder_start_token_id: Optional[int] = None,
+        use_cache: Optional[bool] = None,
+        stopping_criteria: Optional["StoppingCriteriaList"] = None,
+        **model_kwargs,
+    ) -> torch.LongTensor:
+        # Validate stopping criteria
+        stopping_criteria = validate_stopping_criteria(stopping_criteria)
+
+        # Prepare inputs
+        input_ids = inputs["input_ids"] if "input_ids" in inputs else inputs
+        attention_mask = inputs["attention_mask"] if "attention_mask" in inputs else None
+        position_ids = inputs["position_ids"] if "position_ids" in inputs else None
+        past_key_values = inputs["past_key_values"] if "past_key_values" in inputs else None
+        inputs_embeds = inputs["inputs_embeds"] if "inputs_embeds" in inputs else None
+
+        # Call the infer function
+        logits = self.infer(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+        )
+
+        # Return the generated logits
+        return logits
+


    @add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
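Note that although `inputs` is annotated as `Optional[torch.Tensor]`, the body of the new `generate` indexes it by string keys, so in practice a mapping such as the tokenizer's `BatchEncoding` is what flows through this branch cleanly. A standalone toy sketch of that unpacking step (the `unpack` helper and the tensors here are illustrative only, not part of the model code):

    import torch

    # Mimics `inputs["input_ids"] if "input_ids" in inputs else inputs` from the diff:
    # a dict-like encoding yields its fields; keys that are absent fall back to None.
    def unpack(inputs):
        input_ids = inputs["input_ids"] if "input_ids" in inputs else inputs
        attention_mask = inputs["attention_mask"] if "attention_mask" in inputs else None
        return input_ids, attention_mask

    enc = {"input_ids": torch.tensor([[1, 2, 3]]), "attention_mask": torch.tensor([[1, 1, 1]])}
    input_ids, attention_mask = unpack(enc)
    print(input_ids.shape, attention_mask.shape)  # torch.Size([1, 3]) torch.Size([1, 3])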
 
@@ -2159,36 +2207,62 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
        return rare_token_ids


-    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
-        attention_mask = kwargs.get("attention_mask", None)
-        position_ids = kwargs.get("position_ids", None)
-        inputs_embeds = kwargs.get("inputs_embeds", None)
-        use_cache = kwargs.get("use_cache", None)
-        output_attentions = kwargs.get("output_attentions", None)
-        output_hidden_states = kwargs.get("output_hidden_states", None)
-        return_dict = kwargs.get("return_dict", None)
-
-        # Call the custom infer method
-        logits = self.infer(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_values=past_key_values,
-            inputs_embeds=inputs_embeds,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-        )
-
-        # Return the prepared inputs for generation
-        return {
-            "input_ids": input_ids,
-            "logits": logits,
-            "past_key_values": past_key_values,
-            "attention_mask": attention_mask,
-            "position_ids": position_ids,
-        }
+    def prepare_inputs_for_generation(
+        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
+    ):
+        # Omit tokens covered by past_key_values
+        if past_key_values is not None:
+            if isinstance(past_key_values, Cache):
+                cache_length = past_key_values.get_seq_length()
+                past_length = past_key_values.seen_tokens
+                max_cache_length = past_key_values.get_max_length()
+            else:
+                cache_length = past_length = past_key_values[0][0].shape[2]
+                max_cache_length = None
+
+            # Keep only the unprocessed tokens:
+            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
+            # some of the inputs are exclusively passed as part of the cache (e.g. when passing inputs_embeds as
+            # input)
+            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
+                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
+            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
+            # input_ids based on the past_length.
+            elif past_length < input_ids.shape[1]:
+                input_ids = input_ids[:, past_length:]
+            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
+
+            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
+            if (
+                max_cache_length is not None
+                and attention_mask is not None
+                and cache_length + input_ids.shape[1] > max_cache_length
+            ):
+                attention_mask = attention_mask[:, -max_cache_length:]
+
+        position_ids = kwargs.get("position_ids", None)
+        if attention_mask is not None and position_ids is None:
+            # create position_ids on the fly for batch generation
+            position_ids = attention_mask.long().cumsum(-1) - 1
+            position_ids.masked_fill_(attention_mask == 0, 1)
+            if past_key_values:
+                position_ids = position_ids[:, -input_ids.shape[1] :]
+
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None and past_key_values is None:
+            model_inputs = {"inputs_embeds": inputs_embeds}
+        else:
+            model_inputs = {"input_ids": input_ids}
+
+        model_inputs.update(
+            {
+                "position_ids": position_ids,
+                "past_key_values": past_key_values,
+                "use_cache": kwargs.get("use_cache"),
+                "attention_mask": attention_mask,
+            }
+        )
+        return model_inputs

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
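The reworked `prepare_inputs_for_generation` mirrors the usual Hugging Face pattern of feeding the model only the tokens the KV cache has not yet processed. A standalone sketch of the trimming rule from branch 2 above, assuming a legacy tuple cache where `past_length` is simply the cached sequence length (toy tensors, not a call into the model itself):

    import torch

    # Toy values: a sequence of 5 tokens, 4 of which are already covered by the KV cache.
    input_ids = torch.tensor([[5, 6, 7, 8, 9]])
    past_length = 4

    # Branch 2 of the method: the cache is shorter than input_ids, so only the
    # unprocessed suffix is passed to the next forward call.
    if past_length < input_ids.shape[1]:
        input_ids = input_ids[:, past_length:]

    print(input_ids)  # tensor([[9]]) -- only the newest token is fed forward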