---
library_name: transformers
license: other
base_model: SimpleBerry/LLaMA-O1-Base-1127
tags:
  - llama-factory
  - full
  - generated_from_trainer
model-index:
  - name: SimpleBerry/LLaMA-O1-Supervised-1129
    results: []
---

# SimpleBerry/LLaMA-O1-Supervised-1129

This model is a fine-tuned version of [SimpleBerry/LLaMA-O1-Base-1127](https://huggingface.co/SimpleBerry/LLaMA-O1-Base-1127) on the [SimpleBerry/OpenLongCoT-SFT](https://huggingface.co/datasets/SimpleBerry/OpenLongCoT-SFT) dataset.
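For a quick sanity check, the model should load straight from the Hub with the standard `transformers` pipeline. This is a minimal sketch (the example question is illustrative); the prompt template the model was actually trained on is shown under [Inference](#inference) below.

```python
from transformers import pipeline

# Minimal smoke test; for faithful behavior, wrap the problem in the
# OpenLongCoT template shown in the Inference section below.
generator = pipeline(
    "text-generation",
    model="SimpleBerry/LLaMA-O1-Supervised-1129",
    device_map="auto",
)
print(generator("What is the greatest common divisor of 24 and 36?", max_new_tokens=128)[0]["generated_text"])
```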

## Inference

```python
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load from the Hub (the original snippet pointed at a local checkpoint path).
model_id = "SimpleBerry/LLaMA-O1-Supervised-1129"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")

# Llama tokenizers ship without a pad token; reuse EOS so batched inputs pad cleanly.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# OpenLongCoT prompt format: the problem is wrapped in tree-structured special
# tokens, and the model continues the reasoning from the <expansion> node.
template = "<start_of_father_id>-1<end_of_father_id><start_of_local_id>0<end_of_local_id><start_of_thought><problem>{content}<end_of_thought><start_of_rating><positive_rating><end_of_rating>\n<start_of_father_id>0<end_of_father_id><start_of_local_id>1<end_of_local_id><start_of_thought><expansion>"


def llama_o1_template(query):
    return template.format(content=query)


def batch_predict(input_texts):
    # Strip any stray end-of-text markers before tokenizing.
    input_texts = [text.replace("<|end_of_text|>", "") for text in input_texts]
    inputs = tokenizer(input_texts, return_tensors="pt", padding=True).to(model.device)
    responses = model.generate(**inputs, max_new_tokens=1024)
    # Keep special tokens visible so the tree structure can be inspected.
    return tokenizer.batch_decode(responses, skip_special_tokens=False)


query = input()
assistant_responses = batch_predict([llama_o1_template(query)])
print(assistant_responses)
```
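
The decoded strings returned by `batch_predict` include the prompt itself (the original snippet left a prompt-stripping line commented out). If you only want the newly generated continuation, one sketch, continuing from the snippet above, is to slice off the prompt tokens before decoding; with left padding this also works for multi-prompt batches:

```python
def batch_predict_continuations(input_texts):
    # Sketch, not from the original card: return only the generated
    # continuation by dropping the prompt tokens before decoding.
    tokenizer.padding_side = "left"  # left-pad so every row starts generating at the same offset
    input_texts = [text.replace("<|end_of_text|>", "") for text in input_texts]
    inputs = tokenizer(input_texts, return_tensors="pt", padding=True).to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=1024)
    prompt_len = inputs["input_ids"].shape[1]
    return tokenizer.batch_decode(outputs[:, prompt_len:], skip_special_tokens=False)
```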