# NL2HLTL/NL2HLTLTranslator/T5_XXL/t5_realtime_evaluate.py
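"""Real-time NL-to-LTL translation with a LoRA-finetuned T5-XXL model.

The script loads the PEFT (LoRA) adapter together with its 8-bit quantized
base seq2seq model (whichever checkpoint the adapter's PeftConfig records),
runs a few built-in test prompts, and then reads task descriptions from stdin
and prints the corresponding LTL formulas.
"""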
import torch
from peft import PeftModel, PeftConfig
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import sys
# sys.path.append("..")
# sys.path.append("../../")
# Package-relative import of the shared utilities (provides Task2Preplacer);
# this only resolves when the file is run as a module inside its package.
from ... import utils as util
# Load the PEFT config for the pre-trained LoRA checkpoint, then build the model.
class T5XXL_NL2TL_translator():
    """Wraps the LoRA-finetuned T5-XXL seq2seq checkpoint for NL-to-LTL translation."""

    def __init__(self) -> None:
        # exp_name="_mid_ascii"
        peft_model_id = "model_weight/tf-ltl_eng_test_mid_ascii_gptAuged"
        self.max_target_length = 128
        self.config = PeftConfig.from_pretrained(peft_model_id)
        # Load the base seq2seq model in 8-bit together with its tokenizer.
        self.model = AutoModelForSeq2SeqLM.from_pretrained(
            self.config.base_model_name_or_path, load_in_8bit=True, device_map="auto"
        )
        self.tokenizer = AutoTokenizer.from_pretrained(self.config.base_model_name_or_path)
        # Attach the LoRA adapter weights on top of the quantized base model.
        self.model = PeftModel.from_pretrained(self.model, peft_model_id, device_map="auto")
        self.model.eval()
        print("Peft model loaded")
    def translate(self, input: str = ""):
        """Translate one natural-language task description into an LTL formula."""
        input_prompt = "Generate LTL: " + input
        # Replace Task_i.j identifiers with the placeholder tokens the model was
        # trained on, and map them back in the decoded output afterwards.
        replacer = util.Task2Preplacer()
        input_prompt = replacer.reTask2P(input_prompt)
        print(input_prompt)
        input_ids = self.tokenizer(input_prompt, return_tensors="pt", truncation=True).input_ids.cuda()
        outputs = self.model.generate(
            input_ids=input_ids, max_new_tokens=self.max_target_length, do_sample=True, top_p=0.9
        )
        output_txt = self.tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True)[0]
        print(output_txt)
        return replacer.reP2Task(output_txt)
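# Minimal usage sketch (assumes the LoRA checkpoint directory above is present
# locally and a CUDA device with enough memory for the 8-bit T5-XXL is available):
#
#   translator = T5XXL_NL2TL_translator()
#   print(translator.translate("Task_1.1 must be completed before Task_1.2 starts."))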
if __name__=="__main__":
test_prompts=[
"Task_1.1.1 must precede Task_1.1.2, which in turn should precede Task_1.1.3, ensuring that arranging fruits happens before preparing vegetables and prepping eggs and meats is done last.",
"Task_1.1 must be completed before Task_1.2 starts, and Task_1.2 must be completed before Task_1.3 starts."
]
    translator = T5XXL_NL2TL_translator()
    # Sanity-check the model on the built-in test prompts first.
    for prompt in test_prompts:
        print(translator.translate(prompt))
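    # Interactive mode: read a (possibly multi-line) task description from stdin
    # until EOF (Ctrl-D), translate it, and repeat; an empty submission exits.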
    while True:
        lines = []
        try:
            # Collect input lines until EOF terminates the block.
            while True:
                lines.append(input())
        except EOFError:
            pass
        ret = " ".join(lines).strip()
        print(ret)
        # An empty submission ends the interactive session.
        if ret == "":
            break
        print(translator.translate(ret))