---
base_model: UW-Madison-Lee-Lab/Llama-PRM800K
library_name: peft
license: llama3.1
tags:
  - generated_from_trainer
model-index:
  - name: VersaPRM-Math-Subset
    results: []
---

VersaPRM-Math-Subset

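This model is a fine-tuned version of UW-Madison-Lee-Lab/Llama-PRM800K, trained on the math-category subset of UW-Madison-Lee-Lab/MMLU-Pro-CoT-Train-Labeled. It is a process reward model (PRM): given a question and a step-by-step solution, it assigns a reward to each step.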

Get rewards

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def get_tokenizer(model_id):
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.pad_token = tokenizer.eos_token  # the Llama 3 tokenizer has no pad token by default
    tokenizer.padding_side = 'left'            # pad and truncate on the left so the end of a solution is kept
    tokenizer.truncation_side = 'left'
    return tokenizer

device = 'cuda' if torch.cuda.is_available() else 'cpu'
tokenizer = get_tokenizer('UW-Madison-Lee-Lab/VersaPRM-Math-Subset')
# The repo stores a PEFT adapter (library_name: peft); recent transformers versions
# load it onto the base model automatically when the `peft` package is installed.
model = AutoModelForCausalLM.from_pretrained('UW-Madison-Lee-Lab/VersaPRM-Math-Subset')
model.to(device)

# Label tokens read at every position: id 12 is '-' and id 10 is '+' in the Llama 3 vocabulary.
candidate_tokens = [12, 10]
# Token id of the ' \n\n\n\n' step separator; one reward is read at each occurrence.
step_tag_id = 23535

question = 'Question: In Python 3, which of the following function convert a string to an int in python?\nA. short(x)\nB. float(x)\nC. integer(x [,base])\nD. double(x)\nE. int(x [,base])\nF. long(x [,base] )\nG. num(x)\nH. str(x)\nI. char(x)\nJ. digit(x [,base])'
solution = ["To convert a string to an integer in Python 3, we use the built-in function int().",
            "The int() function takes two arguments: the string to be converted and an optional base (default is 10, which is for decimal).",
            "For example: int(\"123\", 10) converts the string \"123\" to the integer 123.",
            "Looking at the options, we can see that the correct function is option E: int(x [,base]).",
            "The answer is (E)."]
# Solution steps are separated by ' \n\n\n\n'; a trailing separator closes the last step.
input_text = question + ' \n\n' + ' \n\n\n\n'.join(solution) + ' \n\n\n\n'
input_id = torch.tensor([tokenizer.encode(input_text)]).to(device)

with torch.no_grad():
    # Keep only the logits of the two label tokens at every position.
    logits = model(input_id).logits[:, :, candidate_tokens]
    # Probability of '+' (index 1) under a softmax over the two labels.
    scores = logits.softmax(dim=-1)[:, :, 1]
    # One reward per step: the scores at the separator positions.
    step_scores = scores[input_id == step_tag_id]
    step_probs = step_scores.tolist()
```
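The scoring snippet builds its input by hand: the question, then ' \n\n', then the solution steps joined by ' \n\n\n\n', with one more separator after the last step so that every step ends in a separator. The sketch below wraps that string construction in a helper. `format_prm_input` and `STEP_SEPARATOR` are hypothetical names introduced here, not part of any VersaPRM package, and the check that no step already contains the separator is an added assumption: a stray separator inside a step would most likely produce an extra scored position.

```python
from typing import Sequence

# Separator placed after every solution step, matching the string format in the snippet.
STEP_SEPARATOR = " \n\n\n\n"


def format_prm_input(question: str, steps: Sequence[str]) -> str:
    """Hypothetical helper that builds the PRM input string.

    Layout: the question, a ' ' plus blank-line gap, then the steps joined by
    STEP_SEPARATOR, followed by one trailing STEP_SEPARATOR.
    """
    for i, step in enumerate(steps):
        # Assumption: a separator inside a step would add an extra scored position.
        if STEP_SEPARATOR in step:
            raise ValueError(f"step {i} already contains the step separator")
    return question + " \n\n" + STEP_SEPARATOR.join(steps) + STEP_SEPARATOR
```

For a non-empty list of steps, `format_prm_input(question, solution)` returns the same string as `input_text` in the snippet, and the separator occurs exactly once per step.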
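In the snippet, the logits are restricted to the candidate tokens [12, 10], passed through a softmax, and only index 1 is kept. The reward at each position is therefore p = exp(z_10) / (exp(z_12) + exp(z_10)) = sigmoid(z_10 - z_12), where z_k is the logit of token id k, and rewards are read only where the input id equals 23535. The torch-free sketch below reproduces that arithmetic on made-up logits so it can be checked without downloading the model. `step_rewards`, `two_way_softmax_pos`, and the one-dict-per-position logit format are illustrative assumptions, not the model's API.

```python
import math
from typing import Mapping, Sequence

# Ids taken from the snippet: candidate_tokens = [12, 10] and the separator id 23535.
NEG_ID, POS_ID = 12, 10
STEP_TAG_ID = 23535


def two_way_softmax_pos(z_neg: float, z_pos: float) -> float:
    """Probability of the second label under a softmax over two logits.

    Equal to sigmoid(z_pos - z_neg); the branch avoids overflow in math.exp.
    """
    d = z_pos - z_neg
    if d >= 0:
        return 1.0 / (1.0 + math.exp(-d))
    e = math.exp(d)
    return e / (1.0 + e)


def step_rewards(input_ids: Sequence[int],
                 logits: Sequence[Mapping[int, float]]) -> list:
    """Hypothetical helper: one reward per separator position.

    `logits[t]` maps a token id to its logit at position t (only ids 12 and 10 are read).
    """
    if len(input_ids) != len(logits):
        raise ValueError("input_ids and logits must have the same length")
    return [
        two_way_softmax_pos(row[NEG_ID], row[POS_ID])
        for tok, row in zip(input_ids, logits)
        if tok == STEP_TAG_ID
    ]
```

Taking the sigmoid of the logit difference gives the same value as the two-way softmax in the snippet, while staying finite for very large or very small logits.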