# NL2HLTL/finetune/realtime_run.py
import datasets
import evaluate
import numpy as np
import torch
from peft import PeftModel, PeftConfig
from tqdm import tqdm
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
# Load the PEFT config for the fine-tuned checkpoint
peft_model_id = "finetuned_model/results"
config = PeftConfig.from_pretrained(peft_model_id)
# Load the base LLM (8-bit quantized via bitsandbytes to save memory) and its tokenizer
model = AutoModelForSeq2SeqLM.from_pretrained(config.base_model_name_or_path, load_in_8bit=True, device_map={"": 0})
tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
# Load the LoRA adapter weights on top of the base model
model = PeftModel.from_pretrained(model, peft_model_id, device_map={"": 0})
model.eval()
print("Peft model loaded")
# ROUGE metric for scoring generated LTL strings against the references
metric = evaluate.load("rouge")
def evaluate_peft_model(sample, max_target_length=50):
    # Generate the model's LTL translation for a tokenized input
    outputs = model.generate(input_ids=sample["input_ids"].unsqueeze(0).cuda(), do_sample=True, top_p=0.9, max_new_tokens=max_target_length)
    prediction = tokenizer.decode(outputs[0].detach().cpu().numpy(), skip_special_tokens=True)
    # Decode the reference labels; replace -100 (the ignore index) with the
    # pad token id, since -100 cannot be decoded
    labels = np.where(sample["labels"] != -100, sample["labels"], tokenizer.pad_token_id)
    labels = tokenizer.decode(labels, skip_special_tokens=True)
    return prediction, labels
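# Since this is the "realtime_run" script, a minimal helper for translating a
# single natural-language command on the fly (a sketch, not part of the
# original file; the name translate_command and its defaults are assumptions)
def translate_command(text, max_new_tokens=50):
    input_ids = tokenizer(text, return_tensors="pt").input_ids.cuda()
    with torch.no_grad():
        output = model.generate(input_ids=input_ids, do_sample=True, top_p=0.9, max_new_tokens=max_new_tokens)
    return tokenizer.decode(output[0], skip_special_tokens=True)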
# Load the test dataset from disk, or build a small inline example
# test_dataset = datasets.load_from_disk("data/eval/").with_format("torch")
list_input = [{"natural": "go to P03 and then go to P04, remain in P04 until P05", "raw_ltl": "0"}]
test_dataset = datasets.Dataset.from_list(list_input)
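# evaluate_peft_model expects tokenized "input_ids"/"labels" tensors, but
# Dataset.from_list yields only the raw strings above, so the loop below would
# raise a KeyError. A minimal tokenization step, assuming the same
# "natural" -> input, "raw_ltl" -> label mapping as training:
def tokenize_sample(sample):
    model_inputs = tokenizer(sample["natural"], truncation=True)
    model_inputs["labels"] = tokenizer(sample["raw_ltl"], truncation=True)["input_ids"]
    return model_inputs

test_dataset = test_dataset.map(tokenize_sample, remove_columns=["natural", "raw_ltl"]).with_format("torch")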
# Run predictions (this can take ~45 minutes on the full test set)
predictions, references = [], []
for sample in tqdm(test_dataset):
    p, l = evaluate_peft_model(sample)
    print(p, l)
    predictions.append(p)
    references.append(l)
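# The ROUGE metric loaded above is never applied in the original script; a
# minimal scoring sketch over the collected prediction/reference pairs:
rouge = metric.compute(predictions=predictions, references=references, use_stemmer=True)
print(f"rouge1: {rouge['rouge1']:.4f}, rougeL: {rouge['rougeL']:.4f}")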