import os
from typing import Optional

import torch
from torch import nn
from safetensors.torch import save_file
from transformers import (
    AutoConfig,
    AutoModel,
    AutoModelForSequenceClassification,
    AutoTokenizer,
)
class SignalDetector(nn.Module):
    """Wraps a pretrained sequence-classification model used to detect causal signal phrases."""

    def __init__(self, model_and_tokenizer_path) -> None:
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(model_and_tokenizer_path)
        self.signal_detector = AutoModelForSequenceClassification.from_pretrained(model_and_tokenizer_path)
        self.signal_detector.eval()
        self.signal_detector.cuda()

    def predict(self, text: str) -> int:
        input_ids = self.tokenizer.encode(text)
        input_ids = torch.tensor([input_ids]).cuda()
        with torch.no_grad():  # inference only; no gradients needed
            outputs = self.signal_detector(input_ids)
        return outputs[0].argmax().item()
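

# Illustrative usage of SignalDetector (not part of the original file). The checkpoint path
# and the example sentence are assumptions; a CUDA device is required by the class above.
#
#     detector = SignalDetector("path/to/signal-detector-checkpoint")
#     has_signal = detector.predict("The game was cancelled because of the rain.")
#     # `has_signal` is the argmax class index of the sequence classifier.
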
class ST2ModelV2(nn.Module):
    def __init__(self, args):
        super(ST2ModelV2, self).__init__()
        self.args = args

        self.config = AutoConfig.from_pretrained("roberta-large")
        self.model = AutoModel.from_pretrained("roberta-large")
        self.tokenizer = AutoTokenizer.from_pretrained("roberta-large")

        classifier_dropout = self.args.dropout
        self.dropout = nn.Dropout(classifier_dropout)
        # Token-level head: 6 logits per token (start/end for the cause, effect, and signal spans).
        self.classifier = nn.Linear(self.config.hidden_size, 6)

        if args.mlp:
            self.classifier = nn.Sequential(
                nn.Linear(self.config.hidden_size, self.config.hidden_size),
                nn.ReLU(),
                nn.Linear(self.config.hidden_size, 6),
                nn.Tanh(),
                nn.Linear(6, 6),
            )

        if args.add_signal_bias:
            # Learnable bias vector added to tokens that fall inside a signal phrase.
            self.signal_phrases_layer = nn.Parameter(
                torch.normal(
                    mean=self.model.embeddings.word_embeddings.weight.data.mean(),
                    std=self.model.embeddings.word_embeddings.weight.data.std(),
                    size=(1, self.config.hidden_size),
                )
            )

        if self.args.signal_classification and not self.args.pretrained_signal_detector:
            self.signal_classifier = nn.Linear(self.config.hidden_size, 2)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        signal_bias_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        start_positions: Optional[torch.Tensor] = None,  # [batch_size, 3]
        end_positions: Optional[torch.Tensor] = None,  # [batch_size, 3]
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        r"""
        start_positions (`torch.LongTensor` of shape `(batch_size, 3)`, *optional*):
            Labels for the position (index) of the start of each labelled span, used to compute the
            token classification loss. Positions are clamped to the length of the sequence
            (`sequence_length`); positions outside the sequence are not taken into account.
        end_positions (`torch.LongTensor` of shape `(batch_size, 3)`, *optional*):
            Labels for the position (index) of the end of each labelled span, used to compute the
            token classification loss. Positions are clamped to the length of the sequence
            (`sequence_length`); positions outside the sequence are not taken into account.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if signal_bias_mask is not None and not self.args.signal_bias_on_top_of_lm:
            # Inject the signal bias at the embedding layer; pass inputs_embeds instead of input_ids.
            inputs_embeds = self.signal_phrases_bias(input_ids, signal_bias_mask)

            outputs = self.model(
                # input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                position_ids=position_ids,
                head_mask=head_mask,
                inputs_embeds=inputs_embeds,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        else:
            outputs = self.model(
                input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                position_ids=position_ids,
                head_mask=head_mask,
                inputs_embeds=inputs_embeds,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

        sequence_output = outputs[0]
        if signal_bias_mask is not None and self.args.signal_bias_on_top_of_lm:
            # Alternative: add the signal bias on top of the encoder output rather than the embeddings.
            sequence_output[signal_bias_mask == 1] += self.signal_phrases_layer

        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)  # [batch_size, max_seq_length, 6]

        # One start/end logit per token for each of the three spans: cause (arg0), effect (arg1), signal.
        start_arg0_logits, end_arg0_logits, start_arg1_logits, end_arg1_logits, start_sig_logits, end_sig_logits = logits.split(1, dim=-1)
        start_arg0_logits = start_arg0_logits.squeeze(-1).contiguous()
        end_arg0_logits = end_arg0_logits.squeeze(-1).contiguous()
        start_arg1_logits = start_arg1_logits.squeeze(-1).contiguous()
        end_arg1_logits = end_arg1_logits.squeeze(-1).contiguous()
        start_sig_logits = start_sig_logits.squeeze(-1).contiguous()
        end_sig_logits = end_sig_logits.squeeze(-1).contiguous()

        # start_arg0_logits -= (1 - attention_mask) * 1e4
        # end_arg0_logits -= (1 - attention_mask) * 1e4
        # start_arg1_logits -= (1 - attention_mask) * 1e4
        # end_arg1_logits -= (1 - attention_mask) * 1e4

        # start_arg0_logits[:, 0] = -1e4
        # end_arg0_logits[:, 0] = -1e4
        # start_arg1_logits[:, 0] = -1e4
        # end_arg1_logits[:, 0] = -1e4

        signal_classification_logits = None
        if self.args.signal_classification and not self.args.pretrained_signal_detector:
            # Sentence-level signal classification from the first-token (<s>/[CLS]) representation.
            signal_classification_logits = self.signal_classifier(sequence_output[:, 0, :])
        # start_logits = start_logits.squeeze(-1).contiguous()
        # end_logits = end_logits.squeeze(-1).contiguous()

        arg0_loss = None
        arg1_loss = None
        sig_loss = None
        total_loss = None
        signal_classification_loss = None
        if start_positions is not None and end_positions is not None:
            loss_fct = nn.CrossEntropyLoss()

            start_arg0_loss = loss_fct(start_arg0_logits, start_positions[:, 0])
            end_arg0_loss = loss_fct(end_arg0_logits, end_positions[:, 0])
            arg0_loss = (start_arg0_loss + end_arg0_loss) / 2

            start_arg1_loss = loss_fct(start_arg1_logits, start_positions[:, 1])
            end_arg1_loss = loss_fct(end_arg1_logits, end_positions[:, 1])
            arg1_loss = (start_arg1_loss + end_arg1_loss) / 2

            # sig_loss = 0.
            start_sig_loss = loss_fct(start_sig_logits, start_positions[:, 2])
            end_sig_loss = loss_fct(end_sig_logits, end_positions[:, 2])
            sig_loss = (start_sig_loss + end_sig_loss) / 2

            # If every signal target in the batch is ignored (-100), the mean is NaN; zero it out.
            if sig_loss.isnan():
                sig_loss = 0.

            if self.args.signal_classification and not self.args.pretrained_signal_detector:
                signal_classification_labels = end_positions[:, 2] != -100
                signal_classification_loss = loss_fct(signal_classification_logits, signal_classification_labels.long())
                total_loss = (arg0_loss + arg1_loss + sig_loss + signal_classification_loss) / 4
            else:
                total_loss = (arg0_loss + arg1_loss + sig_loss) / 3

        return {
            'start_arg0_logits': start_arg0_logits,
            'end_arg0_logits': end_arg0_logits,
            'start_arg1_logits': start_arg1_logits,
            'end_arg1_logits': end_arg1_logits,
            'start_sig_logits': start_sig_logits,
            'end_sig_logits': end_sig_logits,
            'signal_classification_logits': signal_classification_logits,
            'arg0_loss': arg0_loss,
            'arg1_loss': arg1_loss,
            'sig_loss': sig_loss,
            'signal_classification_loss': signal_classification_loss,
            'loss': total_loss,
        }
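
    # Illustrative training step (not in the original file), assuming `batch`, `model`, and
    # `optimizer` exist and that the position labels are shaped [batch_size, 3]:
    #
    #     outputs = model(
    #         input_ids=batch["input_ids"],
    #         attention_mask=batch["attention_mask"],
    #         start_positions=batch["start_positions"],
    #         end_positions=batch["end_positions"],
    #     )
    #     outputs["loss"].backward()
    #     optimizer.step()
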
""" | |
def save_pretrained(self, save_directory): | |
#Save model state dict as safetensor, configuration, and tokenizer files. | |
# Ensure the directory exists | |
os.makedirs(save_directory, exist_ok=True) | |
# Save model state dict as safetensor (use torch.save for PyTorch model) | |
model_path = os.path.join(save_directory, "model.safetensor") | |
save_file(self.state_dict(), model_path) | |
# Save config if available | |
config_save_path = os.path.join(save_directory, 'config.json') | |
self.config.to_json_file(config_save_path) | |
# Save tokenizer | |
if hasattr(self, 'tokenizer') and self.tokenizer is not None: | |
tokenizer_save_path = os.path.join(save_directory, 'tokenizer') | |
self.tokenizer.save_pretrained(tokenizer_save_path) | |
""" | |
    def save_pretrained(self, save_directory):
        """
        Save the model state dict in safetensors and PyTorch .bin formats, along with the configuration.
        """
        # Ensure the directory exists
        os.makedirs(save_directory, exist_ok=True)

        # Save model state dict as safetensors
        model_path_safetensor = os.path.join(save_directory, "model.safetensors")
        save_file(self.state_dict(), model_path_safetensor)  # Save as .safetensors

        # Save model state dict as PyTorch .bin (traditional format)
        model_path_bin = os.path.join(save_directory, "pytorch_model.bin")
        torch.save(self.state_dict(), model_path_bin)  # Save as .bin using torch.save()

        # Save config if available
        config_save_path = os.path.join(save_directory, 'config.json')
        self.config.to_json_file(config_save_path)

        """
        # Save tokenizer if it exists
        if hasattr(self, 'tokenizer') and self.tokenizer is not None:
            tokenizer_save_path = os.path.join(save_directory, 'tokenizer')
            self.tokenizer.save_pretrained(tokenizer_save_path)
        """
    def signal_phrases_bias(self, input_ids, signal_bias_mask):
        inputs_embeds = self.model.get_input_embeddings()(input_ids)
        # Add the learned bias vector to the embeddings of tokens flagged as part of a signal phrase.
        inputs_embeds[signal_bias_mask == 1] += self.signal_phrases_layer  # self.signal_phrases_layer(inputs_embeds[signal_bias_mask == 1])

        return inputs_embeds
    def position_selector(
        self,
        start_cause_logits,
        start_effect_logits,
        end_cause_logits,
        end_effect_logits,
        attention_mask,
        word_ids,
    ):
        # basic post processing (removing logits from [CLS], [SEP], [PAD])
        start_cause_logits -= (1 - attention_mask) * 1e4
        end_cause_logits -= (1 - attention_mask) * 1e4
        start_effect_logits -= (1 - attention_mask) * 1e4
        end_effect_logits -= (1 - attention_mask) * 1e4

        start_cause_logits[0] = -1e4
        end_cause_logits[0] = -1e4
        start_effect_logits[0] = -1e4
        end_effect_logits[0] = -1e4

        start_cause_logits[len(word_ids) - 1] = -1e4
        end_cause_logits[len(word_ids) - 1] = -1e4
        start_effect_logits[len(word_ids) - 1] = -1e4
        end_effect_logits[len(word_ids) - 1] = -1e4

        start_cause_logits = torch.log(torch.softmax(start_cause_logits, dim=-1))
        end_cause_logits = torch.log(torch.softmax(end_cause_logits, dim=-1))
        start_effect_logits = torch.log(torch.softmax(start_effect_logits, dim=-1))
        end_effect_logits = torch.log(torch.softmax(end_effect_logits, dim=-1))

        # Case 1: the cause span ends before the effect span starts.
        max_arg0_before_arg1 = None
        for i in range(len(end_cause_logits)):
            if attention_mask[i] == 0:
                break
            for j in range(i + 1, len(start_effect_logits)):
                if attention_mask[j] == 0:
                    break
                if max_arg0_before_arg1 is None:
                    max_arg0_before_arg1 = ((i, j), end_cause_logits[i] + start_effect_logits[j])
                else:
                    if end_cause_logits[i] + start_effect_logits[j] > max_arg0_before_arg1[1]:
                        max_arg0_before_arg1 = ((i, j), end_cause_logits[i] + start_effect_logits[j])

        # Case 2: the effect span ends before the cause span starts.
        max_arg0_after_arg1 = None
        for i in range(len(end_effect_logits)):
            if attention_mask[i] == 0:
                break
            for j in range(i + 1, len(start_cause_logits)):
                if attention_mask[j] == 0:
                    break
                if max_arg0_after_arg1 is None:
                    max_arg0_after_arg1 = ((i, j), start_cause_logits[j] + end_effect_logits[i])
                else:
                    if start_cause_logits[j] + end_effect_logits[i] > max_arg0_after_arg1[1]:
                        max_arg0_after_arg1 = ((i, j), start_cause_logits[j] + end_effect_logits[i])

        # Commit to the better ordering, then pick the two remaining boundaries greedily.
        if max_arg0_before_arg1[1].item() > max_arg0_after_arg1[1].item():
            end_cause, start_effect = max_arg0_before_arg1[0]
            start_cause_logits[end_cause + 1:] = -1e4
            start_cause = start_cause_logits.argmax().item()
            end_effect_logits[:start_effect] = -1e4
            end_effect = end_effect_logits.argmax().item()
        else:
            end_effect, start_cause = max_arg0_after_arg1[0]
            end_cause_logits[:start_cause] = -1e4
            end_cause = end_cause_logits.argmax().item()
            start_effect_logits[end_effect + 1:] = -1e4
            start_effect = start_effect_logits.argmax().item()

        return start_cause, end_cause, start_effect, end_effect
    def beam_search_position_selector(
        self,
        start_cause_logits,
        start_effect_logits,
        end_cause_logits,
        end_effect_logits,
        attention_mask,
        word_ids,
        topk=5
    ):
        # basic post processing (removing logits from [CLS], [SEP], [PAD])
        start_cause_logits -= (1 - attention_mask) * 1e4
        end_cause_logits -= (1 - attention_mask) * 1e4
        start_effect_logits -= (1 - attention_mask) * 1e4
        end_effect_logits -= (1 - attention_mask) * 1e4

        start_cause_logits[0] = -1e4
        end_cause_logits[0] = -1e4
        start_effect_logits[0] = -1e4
        end_effect_logits[0] = -1e4

        start_cause_logits[len(word_ids) - 1] = -1e4
        end_cause_logits[len(word_ids) - 1] = -1e4
        start_effect_logits[len(word_ids) - 1] = -1e4
        end_effect_logits[len(word_ids) - 1] = -1e4

        start_cause_logits = torch.log(torch.softmax(start_cause_logits, dim=-1))
        end_cause_logits = torch.log(torch.softmax(end_cause_logits, dim=-1))
        start_effect_logits = torch.log(torch.softmax(start_effect_logits, dim=-1))
        end_effect_logits = torch.log(torch.softmax(end_effect_logits, dim=-1))

        # Score every (end_cause, start_effect) pair for the "cause before effect" ordering
        # and every (end_effect, start_cause) pair for the "effect before cause" ordering.
        scores = dict()
        for i in range(len(end_cause_logits)):
            if attention_mask[i] == 0:
                break
            for j in range(i + 1, len(start_effect_logits)):
                if attention_mask[j] == 0:
                    break
                scores[str((i, j, "before"))] = end_cause_logits[i].item() + start_effect_logits[j].item()

        for i in range(len(end_effect_logits)):
            if attention_mask[i] == 0:
                break
            for j in range(i + 1, len(start_cause_logits)):
                if attention_mask[j] == 0:
                    break
                scores[str((i, j, "after"))] = start_cause_logits[j].item() + end_effect_logits[i].item()

        # Expand the top-k orderings with the top-k candidates for the two remaining boundaries.
        topk_scores = dict()
        for i, (index, score) in enumerate(sorted(scores.items(), key=lambda x: x[1], reverse=True)[:topk]):
            if eval(index)[2] == 'before':
                end_cause = eval(index)[0]
                start_effect = eval(index)[1]

                this_start_cause_logits = start_cause_logits.clone()
                this_start_cause_logits[end_cause + 1:] = -1e9
                start_cause_values, start_cause_indices = this_start_cause_logits.topk(topk)

                this_end_effect_logits = end_effect_logits.clone()
                this_end_effect_logits[:start_effect] = -1e9
                end_effect_values, end_effect_indices = this_end_effect_logits.topk(topk)

                for m in range(len(start_cause_values)):
                    for n in range(len(end_effect_values)):
                        topk_scores[str((start_cause_indices[m].item(), end_cause, start_effect, end_effect_indices[n].item()))] = score + start_cause_values[m].item() + end_effect_values[n].item()

            elif eval(index)[2] == 'after':
                start_cause = eval(index)[1]
                end_effect = eval(index)[0]

                this_end_cause_logits = end_cause_logits.clone()
                this_end_cause_logits[:start_cause] = -1e9
                end_cause_values, end_cause_indices = this_end_cause_logits.topk(topk)

                this_start_effect_logits = start_effect_logits.clone()
                this_start_effect_logits[end_effect + 1:] = -1e9
                start_effect_values, start_effect_indices = this_start_effect_logits.topk(topk)

                for m in range(len(end_cause_values)):
                    for n in range(len(start_effect_values)):
                        topk_scores[str((start_cause, end_cause_indices[m].item(), start_effect_indices[n].item(), end_effect))] = score + end_cause_values[m].item() + start_effect_values[n].item()

        # Return the best and second-best (start_cause, end_cause, start_effect, end_effect) tuples.
        first, second = sorted(topk_scores.items(), key=lambda x: x[1], reverse=True)[:2]
        return eval(first[0]), eval(second[0]), first[1], second[1], topk_scores
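

# Illustrative end-to-end sketch (not part of the original file). The attribute names below
# mirror what this module reads from `args` (dropout, mlp, add_signal_bias,
# signal_classification, pretrained_signal_detector, signal_bias_on_top_of_lm); the flag
# values and the example sentence are assumptions chosen for demonstration only.
if __name__ == "__main__":
    from types import SimpleNamespace

    args = SimpleNamespace(
        dropout=0.1,
        mlp=False,
        add_signal_bias=False,
        signal_classification=False,
        pretrained_signal_detector=False,
        signal_bias_on_top_of_lm=False,
    )
    model = ST2ModelV2(args)
    model.eval()

    text = "The flight was delayed because of the storm."
    encoded = model.tokenizer(text, return_tensors="pt")

    with torch.no_grad():
        outputs = model(
            input_ids=encoded["input_ids"],
            attention_mask=encoded["attention_mask"],
        )

    # Each logits entry is [batch_size, seq_len]; losses are None because no labels were passed.
    print(outputs["start_arg0_logits"].shape, outputs["loss"])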