---
base_model: meta-llama/Llama-3.1-8B-Instruct
library_name: peft
license: llama3.1
tags:
- generated_from_trainer
model-index:
- name: VersaPRM-Base-8B
  results: []
---


# VersaPRM-Base-8B

This model is a fine-tuned version of [meta-llama/Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct) on [UW-Madison-Lee-Lab/MMLU-Pro-CoT-Train-Labeled](https://huggingface.co/datasets/UW-Madison-Lee-Lab/MMLU-Pro-CoT-Train-Labeled).

## Get rewards
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

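def get_tokenizer(model_id):
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # Reuse the end-of-sequence token for padding
    tokenizer.pad_token = tokenizer.eos_token
    # Pad and truncate on the left so the end of the sequence is preserved
    tokenizer.padding_side = 'left'
    tokenizer.truncation_side = 'left'
    return tokenizer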

device = 'cuda' if torch.cuda.is_available() else 'cpu'
tokenizer = get_tokenizer('UW-Madison-Lee-Lab/VersaPRM-Base-8B')
model = AutoModelForCausalLM.from_pretrained('UW-Madison-Lee-Lab/VersaPRM-Base-8B')
candidate_tokens = [12, 10]  # ids of the two label tokens; index 1 (id 10) is read out as the step score below
model.to(device)

question = 'Question: In Python 3, which of the following function convert a string to an int in python?\nA. short(x)\nB. float(x)\nC. integer(x [,base])\nD. double(x)\nE. int(x [,base])\nF. long(x [,base] )\nG. num(x)\nH. str(x)\nI. char(x)\nJ. digit(x [,base])'
solution = ["To convert a string to an integer in Python 3, we use the built-in function int().",
            "The int() function takes two arguments: the string to be converted and an optional base (default is 10, which is for decimal).",
            "For example: int(\"123\", 10) converts the string \"123\" to the integer 123.",
            "Looking at the options, we can see that the correct function is option E: int(x [,base]).",
            "The answer is (E)."]
input_text = question + ' \n\n' + ' \n\n\n\n'.join(solution) + ' \n\n\n\n' # solution steps are separated by ' \n\n\n\n'
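# Encode the full prompt as a batch containing one sequence
input_id = torch.tensor([tokenizer.encode(input_text)]).to(device)

with torch.no_grad():
    # Keep only the logits of the two candidate tokens at every position
    logits = model(input_id).logits[:, :, candidate_tokens]
    # Softmax over the two candidates; column 1 corresponds to token id 10
    scores = logits.softmax(dim=-1)[:, :, 1]
    # Select scores at positions holding token id 23535 (presumably the ' \n\n\n\n' step separator)
    step_scores = scores[input_id == 23535]
    # One score per solution step, as a plain Python list
    step_probs = step_scores.tolist()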
```
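
The snippet above yields one score per solution step. The sketch below shows two small helpers that may be convenient around it: one that builds the input string in the same format, and one that collapses `step_probs` into a single solution-level score. Both `build_prm_input` and `aggregate_step_scores` are hypothetical names introduced here for illustration, and the aggregation methods (`min`, `last`, `mean`, `prod`) are common choices for process reward models, not necessarily the ones used to evaluate this model.

```python
import math
from typing import List

# Separator placed after every solution step, as in the snippet above
STEP_SEPARATOR = ' \n\n\n\n'


def build_prm_input(question: str, steps: List[str]) -> str:
    """Join a question and its solution steps in the format used above."""
    return question + ' \n\n' + STEP_SEPARATOR.join(steps) + STEP_SEPARATOR


def aggregate_step_scores(step_probs: List[float], method: str = 'min') -> float:
    """Collapse per-step scores into one score (hypothetical helper, assumed methods)."""
    if not step_probs:
        raise ValueError('step_probs must contain at least one score')
    if method == 'min':
        # The weakest step decides the score
        return min(step_probs)
    if method == 'last':
        # Only the final step is considered
        return step_probs[-1]
    if method == 'mean':
        return sum(step_probs) / len(step_probs)
    if method == 'prod':
        # Treats step scores as independent probabilities
        return math.prod(step_probs)
    raise ValueError(f'unknown aggregation method: {method}')
```

With several candidate solutions to the same question, one could score each with `aggregate_step_scores(step_probs)` and keep the highest-scoring candidate; which aggregation works best is an empirical question.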