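"""Natural-language-to-LTL translation with a fine-tuned Mistral-7B model.

Loads a quantized (bitsandbytes 8-bit or 4-bit) fine-tuned checkpoint,
prompts it with a chat template, extracts the LTL formula from the decoded
output, and validates it with LTLChecker. Run as a script to evaluate the
translator on a jsonl test split with ROUGE.
"""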
import os
import re

import numpy as np
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

from NL2HLTLTranslator.utils.util import LTLChecker, Task2Preplacer

os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'


class Mistral_NL2TL_translator:
    def __init__(self,
                 output_dir=os.path.join(os.path.dirname(__file__), '../../'),
                 tuned_model_name="mistral7b_quat8",
                 quat=True,
                 replacer=Task2Preplacer) -> None:
        self.device_map = "auto"
        self.model_dir = os.path.join(output_dir, tuned_model_name)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Quantized loading: 8-bit when `quat` is set (matching the default
        # "quat8" checkpoint name), otherwise 4-bit NF4 with fp16 compute.
        if quat:
            self.bnb_config = BitsAndBytesConfig(
                load_in_8bit=True,
            )
        else:
            self.bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=False,
                bnb_4bit_quant_type='nf4',
                bnb_4bit_compute_dtype=torch.float16,
            )

        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_dir,
            from_tf=bool(".ckpt" in self.model_dir),
            quantization_config=self.bnb_config,
            device_map=self.device_map,
            trust_remote_code=True,
            use_auth_token=True,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir)

        # Mistral ships without a dedicated pad token; reuse EOS and pad right.
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.padding_side = 'right'
        print(self.tokenizer.eos_token_id)
        print(self.tokenizer.bos_token_id)
        print("NL2TL model loaded")

        self.replacer = replacer
        self.ltlChecker = LTLChecker()

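    # Three prompt variants follow; translate() below uses evaluate_model3.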
    def evaluate_model(self, input_text):
        """Prompt variant 1: bracket-focused wording; extracts the formula
        after the phrase 'linear temproal logic is'."""
        # 'temproal' (sic) is left as-is throughout: the extraction regex is
        # spelled to match the output phrasing of the fine-tuned model.
        self.pattern = re.compile(r"linear temproal logic is ([\S ]*).")
        messages = [
            {"role": "user", "content": "translate natural description to linear temproal logic, first translate into a logical way, and then translate into linear temproal logic, pay specific attention to brackets '()', natural language task: {}".format(input_text.strip())},
        ]

        encodeds = self.tokenizer.apply_chat_template(messages, return_tensors="pt").to(self.device)
        outputs = self.model.generate(encodeds, max_new_tokens=512, pad_token_id=self.tokenizer.eos_token_id)

        p = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        print('model output:', p)
        transLTL = self.pattern.findall(p)[0]
        if transLTL[-1] == '.':
            transLTL = transLTL[:-1].strip()
        else:
            transLTL = transLTL.strip()
        transLTL = self.ltlChecker.right_barkets_remover(transLTL)
        print('transLTL:\n', transLTL)
        return transLTL

    def evaluate_model2(self, input_text):
        """Prompt variant 2: 'logical expression' wording; extracts the
        formula after the phrase 'LTL is'."""
        self.pattern = re.compile(r"LTL is ([\S ]*).")
        messages = [
            {"role": "user", "content": "translate natural description to linear temproal logic, first translate into a logical expression, and then translate into linear temproal logic, the natural language task is {}".format(input_text.strip())},
        ]
        encodeds = self.tokenizer.apply_chat_template(messages, return_tensors="pt").to(self.device)
        outputs = self.model.generate(encodeds, max_new_tokens=512, pad_token_id=self.tokenizer.eos_token_id)
        p = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        print('---model output 1:\n', p)

        transLTL = self.pattern.findall(p)[0]
        if transLTL[-1] == '.':
            transLTL = transLTL[:-1].strip()
        else:
            transLTL = transLTL.strip()
        transLTL = self.ltlChecker.right_barkets_remover(transLTL)
        print('transLTL:\n', transLTL)
        return transLTL

    def evaluate_model3(self, input_text):
        """Prompt variant 3 (used by translate()): grammar-focused wording.
        Returns the extracted formula, or False if none was found."""
        self.pattern = re.compile(r"LTL is ([^\.]*)\.")
        messages = [
            {"role": "user", "content": "translate natural description to linear temproal logic, first translate into a logical expression, and then translate into linear temproal logic, please pay specific attention to logic grammar, the natural language task is {}".format(input_text.strip())},
        ]
        encodeds = self.tokenizer.apply_chat_template(messages, return_tensors="pt").to(self.device)
        outputs = self.model.generate(encodeds, max_new_tokens=512, pad_token_id=self.tokenizer.eos_token_id)
        p = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        print('---model output 1:\n', p)

        transLTL = self.pattern.findall(p)
        if len(transLTL) == 0:
            return False
        # The capture group excludes '.', so a plain strip suffices here.
        transLTL = transLTL[0].strip()
        transLTL = self.ltlChecker.right_barkets_remover(transLTL)
        print('transLTL:\n', transLTL)
        return transLTL

def translate(self,input_prompt:str=""): |
|
print('input_prompt:\n',input_prompt) |
|
replacer=self.replacer() |
|
input_prompt=replacer.reTask2P(input_prompt) |
|
|
|
|
|
|
|
|
|
|
|
flag_check_false_count=0 |
|
flag_check=False |
|
while not flag_check and flag_check_false_count<10: |
|
flag_check_false_count+=1 |
|
flag_check=True |
|
transLTL=self.evaluate_model3(input_prompt) |
|
transLTL=transLTL.replace('Or','And') |
|
transLTL=transLTL.replace('Globally','Finally') |
|
if isinstance(transLTL,bool): |
|
flag_check=False |
|
elif not self.ltlChecker.AP_CorrCheck(input_prompt,transLTL): |
|
print('AP_CorrCheck false') |
|
flag_check=False |
|
elif not self.ltlChecker.brackets_Check(transLTL): |
|
print('brackets_Check false') |
|
flag_check=False |
|
|
|
return replacer.reP2Task(transLTL) |
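
    # Minimal usage sketch (hypothetical prompt; Task2Preplacer is expected
    # to swap Task_x.y atoms for P-style placeholders and back):
    #
    #   translator = Mistral_NL2TL_translator()
    #   ltl = translator.translate("Finally reach Task_1.1, and always avoid Task_2.2")
    #   print(ltl)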


if __name__ == "__main__":

    class p2preplacer:
        """Identity replacer used for evaluation: prompts pass through unchanged."""

        def reTask2P(self, input):
            return input

        def reP2Task(self, input):
            return input

    translater = Mistral_NL2TL_translator(replacer=p2preplacer)

    import evaluate
    import pandas as pd

    metric = evaluate.load("rouge")
    datapath = 'path/to/NL2TL-dataset/collect2'
    tokenized_dataset = load_dataset(
        "json",
        data_files={
            "train": os.path.join(datapath, "ltl_eng_train_mid_ascii_gptAuged.jsonl"),
            "test": os.path.join(datapath, "ltl_eng_test_mid_ascii_gptAuged.jsonl"),
        },
    )
    print(tokenized_dataset)

    predictions, references, input_sentence = [], [], []

    for idx in range(len(tokenized_dataset['test']['natural'])):
        nl = tokenized_dataset['test']['natural'][idx]
        transLTL = translater.translate(nl)

        input_sentence.append(nl)
        predictions.append(transLTL)
        references.append(tokenized_dataset['test']['raw_ltl'][idx].strip())
        print(idx, '\n', input_sentence[-1],
              '\npre::\n', predictions[-1],
              '\nref::\n', references[-1], '\n', '-' * 20, '\n')

    rouge = metric.compute(predictions=predictions, references=references, use_stemmer=True)

    print(f"rouge1: {rouge['rouge1'] * 100:.2f}%")
    print(f"rouge2: {rouge['rouge2'] * 100:.2f}%")
    print(f"rougeL: {rouge['rougeL'] * 100:.2f}%")
    print(f"rougeLsum: {rouge['rougeLsum'] * 100:.2f}%")

    eval_output = np.array([input_sentence, predictions, references]).T
    eval_output = pd.DataFrame(eval_output, columns=['input', 'prediction', 'reference'])
    eval_output.to_csv("path/to/model_weight/mistral7b_mid_ascii_0327_eos_2aug1_quat8" + '/output')

    exit()

    # Interactive demo (unreachable while the exit() above is present):
    # read a multi-line prompt from stdin until EOF and translate it.
    flag = True
    while flag:
        lines = [""]
        try:
            lines.append(input())
            while True:
                lines.append(input())
        except EOFError:
            pass
        ret = "".join(lines)
        print(ret)
        if ret == "":
            break

        print(translater.translate(ret))