from dataclasses import dataclass
from tqdm import tqdm
from graphgen.models.text.text_pair import TextPair


@dataclass
class RewardEvaluator:
    """
    Reward Model Evaluator.
    OpenAssistant/reward-model-deberta-v3-large-v2: 分数范围为[-inf, inf],越高越好
    """
    reward_name: str = "OpenAssistant/reward-model-deberta-v3-large-v2"
    max_length: int = 2560
    results: list[float] = None

    def __post_init__(self):
        # Import torch lazily so constructing the evaluator does not force
        # CUDA initialisation in the parent process.
        import torch
        self.num_gpus = torch.cuda.device_count()

    @staticmethod
    def process_chunk(rank, pairs, reward_name, max_length, return_dict):
        """Score one chunk of pairs on GPU `rank` in a worker process."""
        import torch
        from transformers import AutoModelForSequenceClassification, AutoTokenizer

        device = f"cuda:{rank}"
        torch.cuda.set_device(rank)

        # Each worker loads its own copy of the reward model onto its GPU.
        rank_model = AutoModelForSequenceClassification.from_pretrained(reward_name)
        tokenizer = AutoTokenizer.from_pretrained(reward_name)
        rank_model.to(device)
        rank_model.eval()

        results = []
        with torch.no_grad():
            for pair in tqdm(pairs):
                # Tokenize question and answer together; the reward model
                # emits a single logit as the score for the pair.
                inputs = tokenizer(
                    pair.question,
                    pair.answer,
                    return_tensors="pt",
                    max_length=max_length,
                    truncation=True,
                )
                inputs = {k: v.to(device) for k, v in inputs.items()}
                score = rank_model(**inputs).logits[0].item()
                results.append(score)

        return_dict[rank] = results

    def evaluate(self, pairs: list[TextPair]) -> list[float]:
        import torch.multiprocessing as mp

        # CUDA cannot be re-initialised in a forked child process, so use
        # the "spawn" start method explicitly.
        ctx = mp.get_context("spawn")

        # Split the pairs into one contiguous chunk per GPU; the last chunk
        # absorbs any remainder.
        chunk_size = len(pairs) // self.num_gpus
        chunks = []
        for i in range(self.num_gpus):
            start = i * chunk_size
            end = start + chunk_size
            if i == self.num_gpus - 1:
                end = len(pairs)
            chunks.append(pairs[start:end])

        # One worker process per GPU; results come back via a shared dict.
        manager = ctx.Manager()
        return_dict = manager.dict()
        processes = []

        for rank, chunk in enumerate(chunks):
            p = ctx.Process(
                target=self.process_chunk,
                args=(rank, chunk, self.reward_name, self.max_length, return_dict)
            )
            p.start()
            processes.append(p)

        for p in processes:
            p.join()

        # Merge per-chunk results in rank order so scores line up with pairs.
        results = []
        for rank in range(len(chunks)):
            results.extend(return_dict[rank])

        # Defensive cleanup: terminate any worker that somehow outlived join().
        for p in processes:
            if p.is_alive():
                p.terminate()
                p.join()

        return results

    def get_average_score(self, pairs: list[TextPair]) -> float:
        """
        Get the average score of a batch of question-answer pairs.
        The per-pair scores are cached in `self.results`.
        """
        results = self.evaluate(pairs)
        self.results = results
        return sum(self.results) / len(pairs)

    def get_min_max_score(self, pairs: list[TextPair]) -> tuple[float, float]:
        """
        Get the min and max score of a batch of question-answer pairs,
        reusing scores cached by a previous `get_average_score` call.
        """
        if self.results is None:
            self.get_average_score(pairs)
        return min(self.results), max(self.results)
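

# A minimal usage sketch (illustrative, not part of the module's public
# surface). It assumes TextPair exposes `question` and `answer` fields, as
# used above, and that at least one CUDA device is visible. The __main__
# guard is required because the "spawn" start method re-imports this module
# in every worker process.
if __name__ == "__main__":
    sample_pairs = [
        TextPair(
            question="What is the capital of France?",
            answer="The capital of France is Paris.",
        ),
        TextPair(
            question="What is the capital of France?",
            answer="I am not sure.",
        ),
    ]
    evaluator = RewardEvaluator()
    print("average:", evaluator.get_average_score(sample_pairs))
    low, high = evaluator.get_min_max_score(sample_pairs)
    print("min:", low, "max:", high)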