# graphgen/models/evaluate/reward_evaluator.py
from dataclasses import dataclass
from typing import Optional

from tqdm import tqdm

from graphgen.models.text.text_pair import TextPair

@dataclass
class RewardEvaluator:
"""
Reward Model Evaluator.
    OpenAssistant/reward-model-deberta-v3-large-v2: scores range over [-inf, inf]; higher is better.
"""
reward_name: str = "OpenAssistant/reward-model-deberta-v3-large-v2"
max_length: int = 2560
    results: Optional[list[float]] = None
def __post_init__(self):
import torch
self.num_gpus = torch.cuda.device_count()
@staticmethod
def process_chunk(rank, pairs, reward_name, max_length, return_dict):
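        """
        Score one shard of question/answer pairs on GPU `rank`.

        Runs in a worker process: loads a separate copy of the reward model,
        scores each pair, and writes the scores to return_dict[rank].
        """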
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
device = f'cuda:{rank}'
torch.cuda.set_device(rank)
rank_model = AutoModelForSequenceClassification.from_pretrained(reward_name)
tokenizer = AutoTokenizer.from_pretrained(reward_name)
rank_model.to(device)
rank_model.eval()
results = []
with torch.no_grad():
for pair in tqdm(pairs):
inputs = tokenizer(
pair.question,
pair.answer,
return_tensors="pt",
max_length=max_length,
truncation=True
)
inputs = {k: v.to(device) for k, v in inputs.items()}
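                # The sequence-classification head emits a single logit,
                # which is used directly as the reward score.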
score = rank_model(**inputs).logits[0].item()
results.append(score)
return_dict[rank] = results
    def evaluate(self, pairs: list[TextPair]) -> list[float]:
        """
        Score every pair with the reward model, sharding the work across GPUs.
        """
        import torch.multiprocessing as mp

        if self.num_gpus == 0:
            raise RuntimeError("RewardEvaluator requires at least one CUDA device.")

        # Split pairs into one contiguous chunk per GPU; the last chunk
        # absorbs any remainder.
        chunk_size = len(pairs) // self.num_gpus
        chunks = []
        for i in range(self.num_gpus):
            start = i * chunk_size
            end = len(pairs) if i == self.num_gpus - 1 else start + chunk_size
            chunks.append(pairs[start:end])
        # Multi-process scoring: CUDA state does not survive fork(), so use
        # the spawn start method for the worker processes.
        ctx = mp.get_context("spawn")
        manager = ctx.Manager()
        return_dict = manager.dict()
        processes = []
        for rank, chunk in enumerate(chunks):
            p = ctx.Process(
                target=self.process_chunk,
                args=(rank, chunk, self.reward_name, self.max_length, return_dict),
            )
            p.start()
            processes.append(p)
for p in processes:
p.join()
        # Merge the per-rank results in rank order.
        results = []
        for rank in range(len(chunks)):
            results.extend(return_dict[rank])

        # Defensive cleanup in case any worker is still alive.
        for p in processes:
            if p.is_alive():
                p.terminate()
                p.join()
return results
def get_average_score(self, pairs: list[TextPair]) -> float:
"""
        Get the average reward score over a batch of question/answer pairs.
"""
results = self.evaluate(pairs)
self.results = results
return sum(self.results) / len(pairs)
def get_min_max_score(self, pairs: list[TextPair]) -> tuple[float, float]:
"""
        Get the minimum and maximum reward scores over a batch of question/answer pairs.
"""
if self.results is None:
self.get_average_score(pairs)
return min(self.results), max(self.results)
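

# Minimal usage sketch (not part of the original module). It assumes TextPair
# can be constructed with `question` and `answer` fields, matching how
# process_chunk reads pair.question / pair.answer, and that at least one CUDA
# device is available. The __main__ guard matters because evaluate() starts
# worker processes.
if __name__ == "__main__":
    sample_pairs = [
        TextPair(
            question="What is the capital of France?",
            answer="The capital of France is Paris.",
        ),
        TextPair(
            question="What is the capital of France?",
            answer="I don't know.",
        ),
    ]

    evaluator = RewardEvaluator()
    print("average reward:", evaluator.get_average_score(sample_pairs))
    print("min/max reward:", evaluator.get_min_max_score(sample_pairs))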