File size: 2,996 Bytes
9ff79dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import torch
from transformers import Trainer


class ContrastiveTrainer(Trainer):
    """Trainer whose loss is computed from separate query and document forward passes.

    Each batch is expected to carry ``query_*`` and ``doc_*`` entries. The model
    is called once on the query inputs and once on the document inputs, and the
    two results are handed to ``loss_func``.

    Args:
        loss_func: Callable invoked as ``loss_func(query_outputs, doc_outputs)``
            that returns the loss.
        is_vision_model: When true, ``doc_pixel_values`` (and
            ``doc_pixel_attention_mask`` if present in the batch) are forwarded
            to the model together with the document token inputs.
        *args: Forwarded unchanged to ``Trainer.__init__``.
        **kwargs: Forwarded unchanged to ``Trainer.__init__``.
    """

    def __init__(self, loss_func, is_vision_model, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.loss_func = loss_func
        self.is_vision_model = is_vision_model

    def _forward_queries(self, model, inputs):
        """Run the model on the query side of the batch and return its output."""
        return model(
            input_ids=inputs["query_input_ids"],
            attention_mask=inputs["query_attention_mask"],
        )

    def _forward_docs(self, model, inputs):
        """Run the model on the document side of the batch and return its output.

        For vision models the pixel values are always passed; the pixel
        attention mask is passed only when the batch provides one.
        """
        doc_kwargs = {
            "input_ids": inputs["doc_input_ids"],
            "attention_mask": inputs["doc_attention_mask"],
        }
        if self.is_vision_model:
            doc_kwargs["pixel_values"] = inputs["doc_pixel_values"]
            # Not every batch carries a pixel attention mask, so it is optional.
            if "doc_pixel_attention_mask" in inputs:
                doc_kwargs["pixel_attention_mask"] = inputs["doc_pixel_attention_mask"]
        return model(**doc_kwargs)

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the contrastive loss for one batch.

        Args:
            model: The model to call on the query and document inputs.
            inputs: Mapping holding the ``query_*`` and ``doc_*`` batch entries.
            return_outputs: When true, also return the raw model outputs.
            num_items_in_batch: Accepted only so that ``Trainer`` versions which
                pass this keyword do not fail with a ``TypeError``; it is not
                used by this loss.

        Returns:
            The loss, or ``(loss, (query_outputs, doc_outputs))`` when
            ``return_outputs`` is true.
        """
        query_outputs = self._forward_queries(model, inputs)
        doc_outputs = self._forward_docs(model, inputs)

        loss = self.loss_func(query_outputs, doc_outputs)
        return (loss, (query_outputs, doc_outputs)) if return_outputs else loss

    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=True):
        """Compute the evaluation loss for one batch without tracking gradients.

        Args:
            model: The model to call on the query and document inputs.
            inputs: Mapping holding the ``query_*`` and ``doc_*`` batch entries.
            prediction_loss_only: Must be true; returning predictions is not
                supported.
            ignore_keys: Unused; kept for signature compatibility.

        Returns:
            A ``(loss, None, None)`` tuple.

        Raises:
            ValueError: If ``prediction_loss_only`` is false.
        """
        if not prediction_loss_only:
            raise ValueError("prediction_step is only called with prediction_loss_only=True")

        with torch.no_grad():
            query_outputs = self._forward_queries(model, inputs)
            doc_outputs = self._forward_docs(model, inputs)
            loss = self.loss_func(query_outputs, doc_outputs)

        return loss, None, None