Update modeling_quiet.py
modeling_quiet.py  +104 -30  CHANGED
@@ -1425,6 +1425,54 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
         logits = self.lm_head(mixed_hidden_states)
         return logits
 
+    @torch.no_grad()
+    def generate(
+        self,
+        inputs: Optional[torch.Tensor] = None,
+        max_length: Optional[int] = None,
+        min_length: Optional[int] = None,
+        do_sample: Optional[bool] = None,
+        early_stopping: Optional[bool] = None,
+        num_beams: Optional[int] = None,
+        temperature: Optional[float] = None,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+        repetition_penalty: Optional[float] = None,
+        pad_token_id: Optional[int] = None,
+        bos_token_id: Optional[int] = None,
+        eos_token_id: Optional[int] = None,
+        length_penalty: Optional[float] = None,
+        no_repeat_ngram_size: Optional[int] = None,
+        bad_words_ids: Optional[Iterable[int]] = None,
+        num_return_sequences: Optional[int] = None,
+        decoder_start_token_id: Optional[int] = None,
+        use_cache: Optional[bool] = None,
+        stopping_criteria: Optional["StoppingCriteriaList"] = None,
+        **model_kwargs,
+    ) -> torch.LongTensor:
+        # Validate stopping criteria
+        stopping_criteria = validate_stopping_criteria(stopping_criteria)
+
+        # Prepare inputs
+        input_ids = inputs["input_ids"] if "input_ids" in inputs else inputs
+        attention_mask = inputs["attention_mask"] if "attention_mask" in inputs else None
+        position_ids = inputs["position_ids"] if "position_ids" in inputs else None
+        past_key_values = inputs["past_key_values"] if "past_key_values" in inputs else None
+        inputs_embeds = inputs["inputs_embeds"] if "inputs_embeds" in inputs else None
+
+        # Call the infer function
+        logits = self.infer(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+        )
+
+        # Return the generated logits
+        return logits
+
 
 
     @add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
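The new `generate` bypasses the usual sampling loop: it validates any `StoppingCriteriaList`, unpacks a dict-like `inputs` (a `BatchEncoding` works, since it supports `in` and item access), hands everything to `self.infer`, and returns the resulting logits rather than sampled token ids, so decoding stays on the caller's side. A minimal usage sketch under those assumptions; the checkpoint name and the greedy argmax step are illustrative, not part of the commit:

from transformers import AutoTokenizer, AutoModelForCausalLM

# Hypothetical checkpoint; trust_remote_code is assumed to resolve to the custom QuietForCausalLM class.
tok = AutoTokenizer.from_pretrained("org/quiet-star-model")
model = AutoModelForCausalLM.from_pretrained("org/quiet-star-model", trust_remote_code=True)

enc = tok("The capital of France is", return_tensors="pt")   # dict-like BatchEncoding
logits = model.generate(inputs=enc)                          # logits, assumed shape (batch, seq, vocab)
next_id = logits[:, -1, :].argmax(dim=-1)                    # greedy pick from the last position
print(tok.decode(next_id))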
@@ -2159,36 +2207,62 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
         return rare_token_ids
 
 
-    def prepare_inputs_for_generation(
-        ...  (remaining 29 lines of the previous implementation are not reproduced here)
+    def prepare_inputs_for_generation(
+        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
+    ):
+        # Omit tokens covered by past_key_values
+        if past_key_values is not None:
+            if isinstance(past_key_values, Cache):
+                cache_length = past_key_values.get_seq_length()
+                past_length = past_key_values.seen_tokens
+                max_cache_length = past_key_values.get_max_length()
+            else:
+                cache_length = past_length = past_key_values[0][0].shape[2]
+                max_cache_length = None
+
+            # Keep only the unprocessed tokens:
+            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
+            #     some of the inputs are exclusively passed as part of the cache (e.g. when passing inputs_embeds as
+            #     input)
+            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
+                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
+            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
+            #     input_ids based on the past_length.
+            elif past_length < input_ids.shape[1]:
+                input_ids = input_ids[:, past_length:]
+            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
+
+            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
+            if (
+                max_cache_length is not None
+                and attention_mask is not None
+                and cache_length + input_ids.shape[1] > max_cache_length
+            ):
+                attention_mask = attention_mask[:, -max_cache_length:]
+
+        position_ids = kwargs.get("position_ids", None)
+        if attention_mask is not None and position_ids is None:
+            # create position_ids on the fly for batch generation
+            position_ids = attention_mask.long().cumsum(-1) - 1
+            position_ids.masked_fill_(attention_mask == 0, 1)
+            if past_key_values:
+                position_ids = position_ids[:, -input_ids.shape[1] :]
+
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None and past_key_values is None:
+            model_inputs = {"inputs_embeds": inputs_embeds}
+        else:
+            model_inputs = {"input_ids": input_ids}
+
+        model_inputs.update(
+            {
+                "position_ids": position_ids,
+                "past_key_values": past_key_values,
+                "use_cache": kwargs.get("use_cache"),
+                "attention_mask": attention_mask,
+            }
+        )
+        return model_inputs
 
     @staticmethod
     def _reorder_cache(past_key_values, beam_idx):
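When the cache is a legacy tuple of per-layer `(key, value)` tensors, `past_length` is read off the key tensor's sequence dimension and then used to drop the prompt tokens the cache already covers. A self-contained sketch of that cropping rule, using dummy tensors instead of the model (the usual `(batch, num_heads, cached_len, head_dim)` layout is assumed):

import torch

batch, heads, cached, head_dim = 1, 8, 5, 64
key = torch.zeros(batch, heads, cached, head_dim)
value = torch.zeros(batch, heads, cached, head_dim)
past_key_values = ((key, value),)                            # a single layer, for illustration

cache_length = past_length = past_key_values[0][0].shape[2]  # 5 tokens already cached

input_ids = torch.arange(7).unsqueeze(0)                     # 7 tokens handed in by the caller
if past_length < input_ids.shape[1]:
    input_ids = input_ids[:, past_length:]                   # keep only the 2 unprocessed tokens
print(input_ids)                                             # tensor([[5, 6]])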
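The on-the-fly `position_ids` above exist for left-padded batch generation: a cumulative sum of the attention mask gives every real token its position, and padding slots get a dummy value of 1 (they are masked out of attention anyway). A quick standalone check of that arithmetic:

import torch

attention_mask = torch.tensor([[0, 0, 1, 1, 1],   # left-padded sequence
                               [1, 1, 1, 1, 1]])  # full-length sequence
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
print(position_ids)
# tensor([[1, 1, 0, 1, 2],
#         [0, 1, 2, 3, 4]])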