Min-Li committed on
Commit
40676ad
·
verified ·
1 Parent(s): 28c1b48

Upload 3 files

Browse files
config.json CHANGED
@@ -1,7 +1,10 @@
1
  {
2
  "_name_or_path": "RLHFlow/Decision-Tree-Reward-Llama-3.1-8B",
 
 
 
3
  "architectures": [
4
- "LlamaForSequenceClassification"
5
  ],
6
  "attention_bias": false,
7
  "attention_dropout": 0.0,
 
1
  {
2
  "_name_or_path": "RLHFlow/Decision-Tree-Reward-Llama-3.1-8B",
3
+ "auto_map": {
4
+ "AutoModelForSequenceClassification": "modeling_decision_tree_reward_model.LlamaForDecisionTreeRewardModel"
5
+ },
6
  "architectures": [
7
+ "LlamaForDecisionTreeRewardModel"
8
  ],
9
  "attention_bias": false,
10
  "attention_dropout": 0.0,
decision_tree.pkl CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:5433ef4e775535b16490d8b2e9693a4f46b2f637b0749a089f097c94159814c5
3
  size 2388
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:83f0139429fff38e775af9a281ba5600a46ff852967f6c310667e61710b5bf40
3
  size 2388
modeling_decision_tree_reward_model.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers.models.llama.modeling_llama import LlamaForSequenceClassification
4
+ from sklearn.tree import DecisionTreeClassifier
5
+ import os
6
+ import pickle
7
+ import json
8
+ from huggingface_hub import hf_hub_download
9
+ from typing import List, Dict, Union
10
+ import numpy as np
11
+
12
def convert_to_chat_format(prompt, response=None):
    """Convert a raw prompt string (and optional response) into chat messages.

    Prompts containing the special token ``<extra_id_1>`` (HelpSteer2-style
    multi-turn prompts) are split on that token. The text before the first
    token becomes the first user message. Every later segment is expected to
    look like ``"<Role>\\n<content>"``: a role header line followed by the
    message text. A header of exactly ``"Assistant"`` maps to the
    ``"assistant"`` role; any other header maps to ``"user"``.

    Args:
        prompt: The prompt text, optionally containing ``<extra_id_1>``
            turn separators.
        response: Optional assistant reply appended as the final message.

    Returns:
        A list of ``{"role": ..., "content": ...}`` dictionaries.
    """
    if "<extra_id_1>" in prompt:
        first_turn, *later_turns = prompt.split("<extra_id_1>")
        conversation = [{"role": "user", "content": first_turn}]
        for turn in later_turns:
            # partition() never raises: a turn that has a role header but no
            # newline yields empty content instead of an IndexError.
            role, _, content = turn.partition("\n")
            conversation.append({
                "role": "assistant" if role == "Assistant" else "user",
                "content": content,
            })
    else:
        conversation = [{"role": "user", "content": prompt}]
    if response is not None:
        conversation.append({"role": "assistant", "content": response})
    return conversation
37
+
38
def process_conversation(conversation):
    """Return the conversation with trailing newlines stripped from each message.

    New message dictionaries are built instead of editing the given ones in
    place, so a conversation that shares message objects with the caller's
    data (for example a prompt list extended with ``+``) is left unmodified.

    Args:
        conversation: A list of message dictionaries, each with a string
            ``"content"`` entry.

    Returns:
        A new list of message dictionaries with ``"content"`` right-stripped
        of ``"\\n"`` characters; all other keys are copied unchanged.
    """
    return [
        {**message, "content": message["content"].rstrip('\n')}
        for message in conversation
    ]
42
+
43
class LlamaForDecisionTreeRewardModel(LlamaForSequenceClassification):
    """Llama sequence classifier paired with a decision tree for pairwise preference.

    The linear ``score`` head produces one reward per attribute. ``compare``
    scores two candidate responses to the same prompt and passes the
    difference of their reward vectors to a scikit-learn
    ``DecisionTreeClassifier`` loaded via ``load_decision_tree``.
    """

    def __init__(self, config):
        super().__init__(config)
        # Replace the parent's score head with one that has a bias term.
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=True)
        # The decision tree is loaded lazily through load_decision_tree().
        self.tree = None
        # Define the default attributes (from HelpSteer2)
        self.attributes = ['helpfulness', 'correctness', 'coherence', 'complexity', 'verbosity']
        print("Initialized LlamaForDecisionTreeRewardModel")

    def load_decision_tree(self, repo_id, filename="decision_tree.pkl"):
        """Download and attach the decision tree and attribute names.

        Args:
            repo_id: Hugging Face Hub repository id to download from.
            filename: Name of the pickled decision tree file in that repo.

        Raises:
            AssertionError: If the unpickled object is not a
                ``DecisionTreeClassifier``.
            KeyError: If the repository's ``config.json`` has no
                ``"label2id"`` entry.
        """
        tree_path = hf_hub_download(repo_id=repo_id, filename=filename)
        # SECURITY: pickle.load can execute arbitrary code embedded in the
        # file. Only call this with a repo_id whose contents you trust.
        with open(tree_path, "rb") as f:
            tree = pickle.load(f)
        assert isinstance(tree, DecisionTreeClassifier), f"The tree is not a DecisionTreeClassifier. It is a {type(tree)}"
        # Assign only after validation so a bad file never leaves an
        # invalid object in self.tree.
        self.tree = tree

        config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)
        label2id_map = config["label2id"]
        # Attribute names ordered by their label id.
        self.attributes = [
            label for label, _ in sorted(label2id_map.items(), key=lambda x: x[1])
        ]

    @torch.no_grad()
    def compare(self, prompt: Union[str, List[Dict[str, str]]], response_1: str, response_2: str, tokenizer, device):
        """Score two responses to the same prompt and predict a preference.

        Args:
            prompt: Either a raw prompt string (converted with
                ``convert_to_chat_format``) or a list of chat messages whose
                last message has role ``"user"``.
            response_1: First candidate assistant response.
            response_2: Second candidate assistant response.
            tokenizer: Tokenizer providing ``apply_chat_template``.
            device: Device the tokenized inputs are moved to.

        Returns:
            A dict with:
              * ``"preference"``: first element of the decision tree's
                prediction on ``rewards_2 - rewards_1``. The meaning of the
                predicted label depends on how the tree was trained
                (not visible here).
              * ``"rewards"``: ``np.stack`` of the two reward arrays,
                response 1 first.
              * ``"attributes"``: the attribute names for the reward columns.

        Raises:
            AssertionError: If the tree is not loaded, the conversation is
                empty, or its last message is not from the user.
            ValueError: If ``prompt`` is neither a string nor a list.
        """
        assert self.tree is not None, "The decision tree is not loaded. Please call load_decision_tree(repo_id, filename) first."
        if isinstance(prompt, str):
            conversation = convert_to_chat_format(prompt)
        elif isinstance(prompt, list):
            # Copy each message so the newline cleanup below can never
            # modify the caller's prompt dictionaries.
            conversation = [dict(message) for message in prompt]
        else:
            raise ValueError(f"The prompt must be a string or a list of dictionaries, but got {type(prompt)}")
        assert isinstance(conversation, list), "The conversation must be a list of dictionaries"
        assert len(conversation) >= 1, "The conversation must have at least one message (as prompt)"
        assert conversation[-1]["role"] == "user", "The last message in the conversation must be from the user"

        # The score head is applied in float32 NumPy on the CPU.
        weight = self.score.weight.float().cpu().numpy()
        bias = self.score.bias.float().cpu().numpy()

        def rewards_for(response):
            # Build prompt + response, tokenize with the chat template, and
            # take the final layer's hidden state at the last position.
            messages = process_conversation(
                conversation + [{"role": "assistant", "content": response}]
            )
            input_ids = tokenizer.apply_chat_template(
                messages, tokenize=True, return_tensors="pt"
            ).to(device)
            outputs = self.forward(input_ids, output_hidden_states=True)
            embedding = outputs.hidden_states[-1][:, -1].float().cpu().numpy()
            return embedding @ weight.T + bias

        rewards_1 = rewards_for(response_1)
        rewards_2 = rewards_for(response_2)
        rewards_diff = rewards_2 - rewards_1
        return {
            "preference": self.tree.predict(rewards_diff)[0],
            "rewards": np.stack([rewards_1, rewards_2]),
            "attributes": self.attributes,
        }