Update modeling_quiet.py
Browse files- modeling_quiet.py +80 -123
modeling_quiet.py
CHANGED
|
@@ -1022,9 +1022,7 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
|
|
| 1022 |
seq_len += 1
|
| 1023 |
|
| 1024 |
# Update the attention mask
|
| 1025 |
-
if attention_mask is None:
|
| 1026 |
-
attention_mask = torch.ones_like(input_ids)
|
| 1027 |
-
else:
|
| 1028 |
attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1)
|
| 1029 |
|
| 1030 |
# Generate the continuation
|
|
@@ -1059,11 +1057,12 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
|
|
| 1059 |
next_token_id = torch.argmax(next_token_logits, dim=-1)
|
| 1060 |
|
| 1061 |
# Append the generated token to the input sequence
|
| 1062 |
-
input_ids = torch.cat([input_ids, next_token_id.unsqueeze(-1).to(input_ids.device)], dim=-1)
|
| 1063 |
seq_len += 1
|
| 1064 |
|
| 1065 |
# Update the attention mask
|
| 1066 |
-
|
|
|
|
| 1067 |
|
| 1068 |
# Append the end thought token to the input sequence
|
| 1069 |
end_thought_token_id = self.tokenizer.convert_tokens_to_ids("<|endthought|>")
|
|
@@ -1071,7 +1070,8 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
|
|
| 1071 |
seq_len += 1
|
| 1072 |
|
| 1073 |
# Update the attention mask
|
| 1074 |
-
|
|
|
|
| 1075 |
|
| 1076 |
# Get the hidden states before and after the thought
|
| 1077 |
outputs_before = self.model(
|
|
@@ -1090,7 +1090,7 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
|
|
| 1090 |
# two new tokens: last continuation token and end thought token
|
| 1091 |
outputs_after = self.model(
|
| 1092 |
input_ids=torch.cat([next_token_id.unsqueeze(-1).to(input_ids.device), torch.tensor([[end_thought_token_id]] * batch_size).to(input_ids.device)], dim=-1),
|
| 1093 |
-
attention_mask=torch.cat([attention_mask[:, -
|
| 1094 |
position_ids=position_ids,
|
| 1095 |
past_key_values=new_key_values,
|
| 1096 |
inputs_embeds=inputs_embeds,
|
|
@@ -1110,127 +1110,25 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
|
|
| 1110 |
# Apply the language model head to get the final logits
|
| 1111 |
logits = self.lm_head(mixed_hidden_states)
|
| 1112 |
return logits
|
| 1113 |
-
|
| 1114 |
-
def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **kwargs):
|
| 1115 |
-
if attention_mask is None:
|
| 1116 |
-
attention_mask = torch.ones_like(input_ids)
|
| 1117 |
-
return {"input_ids": input_ids, "attention_mask": attention_mask}
|
| 1118 |
-
|
| 1119 |
-
def _generate_no_beam_search(
|
| 1120 |
-
self,
|
| 1121 |
-
input_ids,
|
| 1122 |
-
cur_len,
|
| 1123 |
-
max_length,
|
| 1124 |
-
min_length,
|
| 1125 |
-
do_sample,
|
| 1126 |
-
temperature,
|
| 1127 |
-
top_k,
|
| 1128 |
-
top_p,
|
| 1129 |
-
repetition_penalty,
|
| 1130 |
-
no_repeat_ngram_size,
|
| 1131 |
-
bad_words_ids,
|
| 1132 |
-
pad_token_id,
|
| 1133 |
-
eos_token_id,
|
| 1134 |
-
batch_size,
|
| 1135 |
-
attention_mask,
|
| 1136 |
-
use_cache,
|
| 1137 |
-
model_kwargs,
|
| 1138 |
-
):
|
| 1139 |
-
finished_generating = torch.zeros(batch_size, dtype=torch.bool, device=input_ids.device)
|
| 1140 |
-
for cur_token_idx in range(max_length):
|
| 1141 |
-
# Sample the next token
|
| 1142 |
-
new_ids = self(
|
| 1143 |
-
input_ids[~finished_generating],
|
| 1144 |
-
attention_mask=attention_mask[~finished_generating]
|
| 1145 |
-
)['logits']
|
| 1146 |
-
# Mask out the start and end thought tokens so we don't accidentally sample them
|
| 1147 |
-
new_ids[:, :, self.tokenizer.vocab_size:] = -float("inf")
|
| 1148 |
-
for list_idx, answer_idx in enumerate((~finished_generating).nonzero(as_tuple=True)[0]):
|
| 1149 |
-
# Find the index of the last token that is not padding
|
| 1150 |
-
base_answer_ids = input_ids[answer_idx]
|
| 1151 |
-
new_answer_ids = new_ids[list_idx]
|
| 1152 |
-
last_token_idx = (base_answer_ids != self.tokenizer.pad_token_id).nonzero(as_tuple=True)[0].max()
|
| 1153 |
-
|
| 1154 |
-
new_ids_sampled = torch.multinomial(
|
| 1155 |
-
torch.nn.functional.softmax(new_answer_ids[last_token_idx] / temperature, dim=-1), 1)
|
| 1156 |
-
# Assign the new id to the last token
|
| 1157 |
-
if last_token_idx + 1 >= len(base_answer_ids):
|
| 1158 |
-
# Add padding everywhere
|
| 1159 |
-
new_padding = torch.full((batch_size, 1), self.tokenizer.pad_token_id, dtype=torch.long,
|
| 1160 |
-
device=input_ids.device)
|
| 1161 |
-
input_ids = torch.cat([input_ids, new_padding], dim=-1)
|
| 1162 |
-
attention_mask = torch.cat([attention_mask, torch.zeros_like(new_padding)], dim=-1)
|
| 1163 |
-
attention_mask[answer_idx, last_token_idx + 1] = 1
|
| 1164 |
-
input_ids[answer_idx, last_token_idx + 1] = new_ids_sampled
|
| 1165 |
-
if new_ids_sampled == self.tokenizer.eos_token_id or new_ids_sampled == self.tokenizer.bos_token_id or new_ids_sampled == self.tokenizer.pad_token_id:
|
| 1166 |
-
finished_generating[answer_idx] = 1
|
| 1167 |
-
# Check if the end token is generated
|
| 1168 |
-
if new_ids_sampled == self.tokenizer.convert_tokens_to_ids("<|/assistant|>"):
|
| 1169 |
-
finished_generating[answer_idx] = 1
|
| 1170 |
-
if finished_generating.all():
|
| 1171 |
-
break
|
| 1172 |
-
return input_ids
|
| 1173 |
-
|
| 1174 |
@torch.no_grad()
|
| 1175 |
def generate(
|
| 1176 |
self,
|
| 1177 |
-
input_ids=
|
| 1178 |
-
|
| 1179 |
-
|
| 1180 |
-
|
| 1181 |
-
|
| 1182 |
-
|
| 1183 |
-
|
| 1184 |
-
|
| 1185 |
-
top_p=None,
|
| 1186 |
-
repetition_penalty=None,
|
| 1187 |
-
bad_words_ids=None,
|
| 1188 |
-
bos_token_id=None,
|
| 1189 |
-
pad_token_id=None,
|
| 1190 |
-
eos_token_id=None,
|
| 1191 |
-
length_penalty=None,
|
| 1192 |
-
no_repeat_ngram_size=None,
|
| 1193 |
-
num_return_sequences=None,
|
| 1194 |
-
attention_mask=None,
|
| 1195 |
-
decoder_start_token_id=None,
|
| 1196 |
-
use_cache=None,
|
| 1197 |
-
**model_kwargs,
|
| 1198 |
-
):
|
| 1199 |
-
max_length = max_length if max_length is not None else self.config.max_length
|
| 1200 |
-
min_length = min_length if min_length is not None else self.config.min_length
|
| 1201 |
-
do_sample = do_sample if do_sample is not None else self.config.do_sample
|
| 1202 |
-
temperature = temperature if temperature is not None else self.config.temperature
|
| 1203 |
-
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
|
| 1204 |
-
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
|
| 1205 |
-
|
| 1206 |
-
# if input_ids is None:
|
| 1207 |
-
# raise ValueError("You have to specify either input_ids")
|
| 1208 |
|
| 1209 |
-
|
| 1210 |
-
|
|
|
|
| 1211 |
|
| 1212 |
-
|
| 1213 |
-
|
| 1214 |
-
|
| 1215 |
-
return self._generate_no_beam_search(
|
| 1216 |
-
input_ids,
|
| 1217 |
-
cur_len=cur_len,
|
| 1218 |
-
max_length=max_length,
|
| 1219 |
-
min_length=min_length,
|
| 1220 |
-
do_sample=do_sample,
|
| 1221 |
-
temperature=temperature,
|
| 1222 |
-
top_k=top_k,
|
| 1223 |
-
top_p=top_p,
|
| 1224 |
-
repetition_penalty=repetition_penalty,
|
| 1225 |
-
no_repeat_ngram_size=no_repeat_ngram_size,
|
| 1226 |
-
bad_words_ids=bad_words_ids,
|
| 1227 |
-
pad_token_id=pad_token_id,
|
| 1228 |
-
eos_token_id=eos_token_id,
|
| 1229 |
-
batch_size=batch_size,
|
| 1230 |
-
attention_mask=attention_mask,
|
| 1231 |
-
use_cache=use_cache,
|
| 1232 |
-
model_kwargs=model_kwargs,
|
| 1233 |
-
)
|
| 1234 |
|
| 1235 |
@add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
|
| 1236 |
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
|
@@ -1971,6 +1869,65 @@ class QuietForCausalLM(QuietPreTrainedModel, GenerationMixin):
|
|
| 1971 |
hidden_states=outputs.hidden_states,
|
| 1972 |
attentions=outputs.attentions,
|
| 1973 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1974 |
|
| 1975 |
@staticmethod
|
| 1976 |
def _reorder_cache(past_key_values, beam_idx):
|
|
|
|
| 1022 |
seq_len += 1
|
| 1023 |
|
| 1024 |
# Update the attention mask
|
| 1025 |
+
if attention_mask is not None:
|
|
|
|
|
|
|
| 1026 |
attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1)
|
| 1027 |
|
| 1028 |
# Generate the continuation
|
|
|
|
| 1057 |
next_token_id = torch.argmax(next_token_logits, dim=-1)
|
| 1058 |
|
| 1059 |
# Append the generated token to the input sequence
|
| 1060 |
+
# input_ids = torch.cat([input_ids, next_token_id.unsqueeze(-1).to(input_ids.device)], dim=-1)
|
| 1061 |
seq_len += 1
|
| 1062 |
|
| 1063 |
# Update the attention mask
|
| 1064 |
+
if attention_mask is not None:
|
| 1065 |
+
attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1)
|
| 1066 |
|
| 1067 |
# Append the end thought token to the input sequence
|
| 1068 |
end_thought_token_id = self.tokenizer.convert_tokens_to_ids("<|endthought|>")
|
|
|
|
| 1070 |
seq_len += 1
|
| 1071 |
|
| 1072 |
# Update the attention mask
|
| 1073 |
+
if attention_mask is not None:
|
| 1074 |
+
attention_mask = torch.cat([attention_mask, torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1)
|
| 1075 |
|
| 1076 |
# Get the hidden states before and after the thought
|
| 1077 |
outputs_before = self.model(
|
|
|
|
| 1090 |
# two new tokens: last continuation token and end thought token
|
| 1091 |
outputs_after = self.model(
|
| 1092 |
input_ids=torch.cat([next_token_id.unsqueeze(-1).to(input_ids.device), torch.tensor([[end_thought_token_id]] * batch_size).to(input_ids.device)], dim=-1),
|
| 1093 |
+
attention_mask=torch.cat([attention_mask[:, -1:], torch.ones((batch_size, 1)).to(attention_mask.device)], dim=-1),
|
| 1094 |
position_ids=position_ids,
|
| 1095 |
past_key_values=new_key_values,
|
| 1096 |
inputs_embeds=inputs_embeds,
|
|
|
|
| 1110 |
# Apply the language model head to get the final logits
|
| 1111 |
logits = self.lm_head(mixed_hidden_states)
|
| 1112 |
return logits
|
| 1113 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1114 |
@torch.no_grad()
|
| 1115 |
def generate(
|
| 1116 |
self,
|
| 1117 |
+
input_ids: torch.LongTensor = torch.LongTensor(),
|
| 1118 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1119 |
+
max_new_tokens: Optional[int] = None,
|
| 1120 |
+
temperature: float = 1.1,
|
| 1121 |
+
**kwargs,
|
| 1122 |
+
):
|
| 1123 |
+
if isinstance(input_ids, str):
|
| 1124 |
+
input_ids = self.tokenizer(input_ids, return_tensors="pt").input_ids
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1125 |
|
| 1126 |
+
if attention_mask is None:
|
| 1127 |
+
# Create a default attention mask if not provided
|
| 1128 |
+
attention_mask = torch.ones_like(input_ids)
|
| 1129 |
|
| 1130 |
+
from .generate import custom_generate
|
| 1131 |
+
return custom_generate(self, input_ids, attention_mask=attention_mask, max_new_tokens=max_new_tokens, temperature=temperature, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1132 |
|
| 1133 |
@add_start_docstrings_to_model_forward(QUIET_INPUTS_DOCSTRING)
|
| 1134 |
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
|
|
|
|
| 1869 |
hidden_states=outputs.hidden_states,
|
| 1870 |
attentions=outputs.attentions,
|
| 1871 |
)
|
| 1872 |
+
|
| 1873 |
+
|
| 1874 |
+
|
| 1875 |
+
def prepare_inputs_for_generation(
|
| 1876 |
+
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
|
| 1877 |
+
):
|
| 1878 |
+
# Omit tokens covered by past_key_values
|
| 1879 |
+
if past_key_values is not None:
|
| 1880 |
+
if isinstance(past_key_values, Cache):
|
| 1881 |
+
cache_length = past_key_values.get_seq_length()
|
| 1882 |
+
past_length = past_key_values.seen_tokens
|
| 1883 |
+
max_cache_length = past_key_values.get_max_length()
|
| 1884 |
+
else:
|
| 1885 |
+
cache_length = past_length = past_key_values[0][0].shape[2]
|
| 1886 |
+
max_cache_length = None
|
| 1887 |
+
|
| 1888 |
+
# Keep only the unprocessed tokens:
|
| 1889 |
+
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
|
| 1890 |
+
# some of the inputs are exclusively passed as part of the cache (e.g. when passing inputs_embeds as
|
| 1891 |
+
# input)
|
| 1892 |
+
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
|
| 1893 |
+
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
|
| 1894 |
+
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
|
| 1895 |
+
# input_ids based on the past_length.
|
| 1896 |
+
elif past_length < input_ids.shape[1]:
|
| 1897 |
+
input_ids = input_ids[:, past_length:]
|
| 1898 |
+
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
|
| 1899 |
+
|
| 1900 |
+
# If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
|
| 1901 |
+
if (
|
| 1902 |
+
max_cache_length is not None
|
| 1903 |
+
and attention_mask is not None
|
| 1904 |
+
and cache_length + input_ids.shape[1] > max_cache_length
|
| 1905 |
+
):
|
| 1906 |
+
attention_mask = attention_mask[:, -max_cache_length:]
|
| 1907 |
+
|
| 1908 |
+
position_ids = kwargs.get("position_ids", None)
|
| 1909 |
+
if attention_mask is not None and position_ids is None:
|
| 1910 |
+
# create position_ids on the fly for batch generation
|
| 1911 |
+
position_ids = attention_mask.long().cumsum(-1) - 1
|
| 1912 |
+
position_ids.masked_fill_(attention_mask == 0, 1)
|
| 1913 |
+
if past_key_values:
|
| 1914 |
+
position_ids = position_ids[:, -input_ids.shape[1] :]
|
| 1915 |
+
|
| 1916 |
+
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
| 1917 |
+
if inputs_embeds is not None and past_key_values is None:
|
| 1918 |
+
model_inputs = {"inputs_embeds": inputs_embeds}
|
| 1919 |
+
else:
|
| 1920 |
+
model_inputs = {"input_ids": input_ids}
|
| 1921 |
+
|
| 1922 |
+
model_inputs.update(
|
| 1923 |
+
{
|
| 1924 |
+
"position_ids": position_ids,
|
| 1925 |
+
"past_key_values": past_key_values,
|
| 1926 |
+
"use_cache": kwargs.get("use_cache"),
|
| 1927 |
+
"attention_mask": attention_mask,
|
| 1928 |
+
}
|
| 1929 |
+
)
|
| 1930 |
+
return model_inputs
|
| 1931 |
|
| 1932 |
@staticmethod
|
| 1933 |
def _reorder_cache(past_key_values, beam_idx):
|