from contextlib import nullcontext
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers import UMT5Model
from .configuration_rankingprompter import RankingPrompterConfig


@dataclass
class RankingPrompterForPreTrainingOutput:
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None


class RankingPrompterForPreTraining(UMT5Model):
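    """A UMT5-based document ranker.

    Documents are flattened and encoded with the shared encoder, and the question
    is decoded once per document while cross-attending to that document's
    encoding. The mean-pooled decoder states are projected by a linear ranking
    head to one logit per document, and training uses a cross-entropy loss over
    each question's candidate documents.
    """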
    config_class = RankingPrompterConfig

    _tied_weights_keys = [
        "encoder.embed_tokens.weight",
        "decoder.embed_tokens.weight",
    ]

    def __init__(self, config):
        # encoder, decoder and shared are from UMT5Model
        super().__init__(config)

        # add ranking head
        self.ranking_head = nn.Linear(config.d_model, 1)

        # Initialize weights and apply final processing
        self.post_init()

        # ctx for mixed precision training
        self.ctx = nullcontext()

    def enable_amp_ctx(self, device_type="cuda", dtype=torch.bfloat16):
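        """Run subsequent encoder/decoder forward passes under ``torch.amp.autocast``."""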
        self.ctx = torch.amp.autocast(device_type=device_type, dtype=dtype)

    def disable_amp_ctx(self):
        self.ctx = nullcontext()

    def forward(
        self,
        document_input_ids: Optional[torch.LongTensor] = None,
        document_attention_mask: Optional[torch.FloatTensor] = None,
        question_input_ids: Optional[torch.LongTensor] = None,
        question_attention_mask: Optional[torch.BoolTensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.FloatTensor], RankingPrompterForPreTrainingOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[-100, 0, ...,
            config.vocab_size - 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for
            labels in `[0, ..., config.vocab_size]`

        Returns:
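
        Example (a minimal sketch, not tied to any released checkpoint: `model` is assumed
        to be an already constructed `RankingPrompterForPreTraining` whose config exposes
        UMT5-style fields such as `vocab_size`, and the tensor shapes are illustrative):

        ```python
        >>> import torch
        >>> batch_size, num_doc, doc_len, q_len = 2, 4, 32, 16
        >>> docs = torch.randint(0, model.config.vocab_size, (batch_size, num_doc, doc_len))
        >>> doc_mask = torch.ones_like(docs)
        >>> question = torch.randint(0, model.config.vocab_size, (batch_size, q_len))
        >>> q_mask = torch.ones_like(question)
        >>> labels = torch.tensor([0, 2])  # index of the relevant document per question
        >>> outputs = model(
        ...     document_input_ids=docs,
        ...     document_attention_mask=doc_mask,
        ...     question_input_ids=question,
        ...     question_attention_mask=q_mask,
        ...     labels=labels,
        ... )
        >>> outputs.logits.shape
        torch.Size([2, 4])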

        ```"""
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        # document_input_ids: [batch_size, num_doc, doc_seq_len]
        batch_size, num_doc, doc_seq_len = document_input_ids.shape
        # flatten the document dimension to [batch_size * num_doc, doc_seq_len]
        document_input_ids = document_input_ids.view(-1, doc_seq_len)
        document_attention_mask = document_attention_mask.view(-1, doc_seq_len)

        # encode every document with the shared encoder
        with self.ctx:
            encoder_outputs = self.encoder(
                input_ids=document_input_ids,
                attention_mask=document_attention_mask,
                return_dict=return_dict,
            )

        document_embeds = encoder_outputs[0]

        # repeat question inputs for each document
        # question_input_ids: [batch_size, question_seq_len]
        question_seq_len = question_input_ids.shape[1]
        question_input_ids = (
            question_input_ids.unsqueeze(1)
            .expand(-1, num_doc, -1)
            .reshape(-1, question_seq_len)
        )  # [batch_size * num_doc, question_seq_len]
        question_attention_mask = (
            question_attention_mask.unsqueeze(1)
            .expand(-1, num_doc, -1)
            .reshape(-1, question_seq_len)
        )  # [batch_size * num_doc, question_seq_len]

        # Decode
        with self.ctx:
            decoder_outputs = self.decoder(
                input_ids=question_input_ids,
                attention_mask=question_attention_mask,
                past_key_values=past_key_values,
                encoder_hidden_states=document_embeds,
                encoder_attention_mask=document_attention_mask,
                use_cache=use_cache,
                return_dict=return_dict,
            )
        # [batch_size * num_doc, question_seq_len, hidden_size]
        sequence_output = decoder_outputs[0]
        question_seq_len = sequence_output.size(1)
        # [batch_size, num_doc, question_seq_len, hidden_size]
        question_output = sequence_output.view(
            batch_size, num_doc, question_seq_len, -1
        )

        # mean-pool over the question tokens, then project each document to a
        # single ranking logit: [batch_size, num_doc, 1]
        ranking_logits = self.ranking_head(question_output.mean(dim=2))

        # drop the trailing singleton dimension so logits are always
        # [batch_size, num_doc], with or without labels
        ranking_logits = ranking_logits.view(batch_size, num_doc)

        # ranking loss: cross-entropy over the candidate documents of each question
        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(ranking_logits, labels)

        if not return_dict:
            output = (ranking_logits,) + decoder_outputs[1:] + encoder_outputs
            return ((loss,) + output) if loss is not None else output

        return RankingPrompterForPreTrainingOutput(
            loss=loss,
            logits=ranking_logits
        )