# from huggingface_hub import login
# login()
import sys,os
from datasets import load_dataset
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, AutoTokenizer, TrainingArguments
# from peft import LoraConfig
# from trl import SFTTrainer
# from accelerate import infer_auto_device_map,init_empty_weights
# sys.path.append(os.path.join(os.path.dirname(__file__), '../../'))
from NL2HLTLTranslator.utils.util import Task2Preplacer
from NL2HLTLTranslator.utils.util import LTLChecker
import re
from datasets import concatenate_datasets
import numpy as np
from peft import AutoPeftModelForCausalLM
os.environ['CUDA_DEVICE_ORDER']='PCI_BUS_ID'
# os.environ['CUDA_VISIBLE_DEVICES']='3'
class Mistral_NL2TL_translator():
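    """Translate natural-language task descriptions into linear temporal logic (LTL)
    with a fine-tuned Mistral-7B checkpoint located at output_dir/tuned_model_name.
    The checkpoint is loaded with 8-bit quantization and queried via the chat template.
    """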
def __init__(self,
output_dir = os.path.join(os.path.dirname(__file__),'../../'),
tuned_model_name="mistral7b_quat8",
# CUDA_device='0',
quat=True,
replacer=Task2Preplacer) -> None:
# os.environ['CUDA_VISIBLE_DEVICES']=CUDA_device
self.device_map="auto"
self.model_dir = os.path.join(output_dir, tuned_model_name)
# check
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# AutoPeftModelForCausalLM.from_pretrained("meta-llama/Llama-2-13b-hf")
# quantconfig = BitsAndBytesConfig(
# load_in_8bit=True,
# bnb_8bit_quant_type="nf4",
# bnb_8bit_use_double_quant=True,
# bnb_8bit_compute_dtype=torch.bfloat16,
# )
# if quat==False:
# self.model = AutoPeftModelForCausalLM.from_pretrained(self.output_dir, device_map=self.device_map, torch_dtype=torch.bfloat16)
# # for ICL the large model can be left unquantized
# else:
# self.model = AutoPeftModelForCausalLM.from_pretrained(self.output_dir,device_map=self.device_map, torch_dtype=torch.float16,
# load_in_8bit=True)
# # quantization_config=quantconfig)
        # 4-bit NF4 config kept for reference; it would be overridden by the 8-bit config below
        # self.bnb_config = BitsAndBytesConfig(
        #     load_in_4bit=True,
        #     bnb_4bit_use_double_quant=False,
        #     bnb_4bit_quant_type='nf4',
        #     bnb_4bit_compute_dtype=torch.float16,
        # )
        self.bnb_config = BitsAndBytesConfig(
            load_in_8bit = True,
            # llm_int8_threshold=200.0
        )
# self.bnb_config = BitsAndBytesConfig(
# load_in_8bit = False,
# load_in_4bit = False,
# # llm_int8_threshold=200.0
# # bnb_4bit_use_double_quant = False,
# # bnb_4bit_quant_type = 'nf4',
# # bnb_4bit_compute_dtype = getattr(torch, "float16")
# )
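        # load the fine-tuned checkpoint with the 8-bit quantization config above;
        # AutoModelForCausalLM (not AutoPeftModelForCausalLM) is used, so the
        # directory presumably contains a merged, adapter-free model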
self.model = AutoModelForCausalLM.from_pretrained(
self.model_dir,
from_tf=bool(".ckpt" in self.model_dir),
quantization_config=self.bnb_config,
device_map=self.device_map,
trust_remote_code=True,
use_auth_token=True
)
self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir)
# , trust_remote_code=True,add_eos_token=True,)
# tokenizer = AutoTokenizer.from_pretrained(base_model_name, add_eos_token=True,trust_remote_code=True)
        # NOTE: the documentation does not say whether add_eos_token is required, but without an EOS
        # token generation keeps going until max_new_tokens is reached.
        # In prediction mode do not pass add_eos_token=True: the tokenizer would append </s> to the
        # input itself, which makes the output irregular; with add_eos_token set, generation always failed.
self.tokenizer.pad_token = self.tokenizer.eos_token
self.tokenizer.padding_side = 'right'
print(self.tokenizer.eos_token_id)
# 2
print(self.tokenizer.bos_token_id)
# 1
# print(tokenizer._convert_token_to_id(tokenizer.bos_token))
print("NL2TL model loaded")
self.replacer=replacer
self.ltlChecker=LTLChecker()
pass
# print('NL2TL llama translate test:')
# self.translate("Task_1.1 must be done, and Task_1.2 should be finished before Task_1.1")
def evaluate_model(self, input_text):
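        """Prompt variant 1: a single chat turn asking for an LTL translation with attention
        to brackets; the formula is taken from the text following 'linear temproal logic is'
        (the prompt spelling is kept as-is to match the fine-tuned prompt format)."""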
self.pattern=re.compile("linear temproal logic is ([\S ]*).")
messages=[
{"role": "user", "content": "translate natural description to linear temproal logic, first translate into a logical way, and then translate into linear temproal logic, pay specific attention to brackets '()', natural language task: {}".format(input_text.strip())},
]
encodeds = self.tokenizer.apply_chat_template(messages, return_tensors="pt").to(self.device)
outputs = self.model.generate(encodeds, max_new_tokens=512, pad_token_id=self.tokenizer.eos_token_id)
p=self.tokenizer.decode(outputs[0], skip_special_tokens=True)
print('model output:',p)
transLTL=self.pattern.findall(p)[0]
if transLTL[-1]=='.':
transLTL=transLTL[:-1].strip()
else:
transLTL=transLTL.strip()
transLTL=self.ltlChecker.right_barkets_remover(transLTL)
print('transLTL:\n',transLTL)
return transLTL
def evaluate_model2(self, input_text):
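        """Prompt variant 2: ask for a logical expression first, then LTL; the formula
        is taken from the text following 'LTL is'."""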
self.pattern=re.compile("LTL is ([\S ]*).")
messages=[
{"role": "user", "content": "translate natural description to linear temproal logic, first translate into a logical expression, and then translate into linear temproal logic, the natural language task is {}".format(input_text.strip())},
]
encodeds = self.tokenizer.apply_chat_template(messages, return_tensors="pt").to(self.device)
outputs = self.model.generate(encodeds, max_new_tokens=512, pad_token_id=self.tokenizer.eos_token_id)
p=self.tokenizer.decode(outputs[0], skip_special_tokens=True)
print('---model output 1:\n',p)
# messages=[
# {"role": "user", "content": "translate natural description to linear temproal logic, first translate into a logical expression, and then translate into linear temproal logic, the natural language task is {}".format(input_text.strip())},
# {"role": "assistant", "content":p
# },
# {"role": "user", "content": " pay specific attention to brackets '()', given your linear temproal logic translation"},
# ]
# encodeds = self.tokenizer.apply_chat_template(messages, return_tensors="pt").to(self.device)
# outputs = self.model.generate(encodeds, max_new_tokens=512, pad_token_id=self.tokenizer.eos_token_id)
# p=self.tokenizer.decode(outputs[0], skip_special_tokens=True)
# print('---model output 2:\n',p)
transLTL=self.pattern.findall(p)[0]
if transLTL[-1]=='.':
transLTL=transLTL[:-1].strip()
else:
transLTL=transLTL.strip()
transLTL=self.ltlChecker.right_barkets_remover(transLTL)
print('transLTL:\n',transLTL)
return transLTL
def evaluate_model3(self, input_text):
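        """Prompt variant 3 (used by translate()): like evaluate_model2, but the regex
        stops at the first period and False is returned when nothing matches."""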
# "LTL is a larger language model . . . . . . "
# self.pattern=re.compile("LTL is ([\S ]*)\.")
self.pattern=re.compile("LTL is ([^\.]*)\.")
messages=[
{"role": "user", "content": "translate natural description to linear temproal logic, first translate into a logical expression, and then translate into linear temproal logic, please pay specific attention to logic grammar, the natural language task is {}".format(input_text.strip())},
]
encodeds = self.tokenizer.apply_chat_template(messages, return_tensors="pt").to(self.device)
outputs = self.model.generate(encodeds, max_new_tokens=512, pad_token_id=self.tokenizer.eos_token_id)
p=self.tokenizer.decode(outputs[0], skip_special_tokens=True)
print('---model output 1:\n',p)
# messages=[
# {"role": "user", "content": "translate natural description to linear temproal logic, first translate into a logical expression, and then translate into linear temproal logic, the natural language task is {}".format(input_text.strip())},
# {"role": "assistant", "content":p
# },
# {"role": "user", "content": " pay specific attention to brackets '()', given your linear temproal logic translation"},
# ]
# encodeds = self.tokenizer.apply_chat_template(messages, return_tensors="pt").to(self.device)
# outputs = self.model.generate(encodeds, max_new_tokens=512, pad_token_id=self.tokenizer.eos_token_id)
# p=self.tokenizer.decode(outputs[0], skip_special_tokens=True)
# print('---model output 2:\n',p)
transLTL=self.pattern.findall(p)
if len(transLTL)==0:
return False
transLTL=transLTL[0]
if transLTL[-1]=='.':
transLTL=transLTL[:-1].strip()
else:
transLTL=transLTL.strip()
transLTL=self.ltlChecker.right_barkets_remover(transLTL)
print('transLTL:\n',transLTL)
return transLTL
def translate(self,input_prompt:str=""):
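        """Translate a natural-language task description into an LTL formula:
        map task names to propositions via the replacer, query the model (up to 10
        attempts), check proposition coverage and bracket balance, then map back."""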
print('input_prompt:\n',input_prompt)
replacer=self.replacer()
input_prompt=replacer.reTask2P(input_prompt)
# print(predicter( replace.reTask2P(input_prompt)))
# print(input_prompt)
# print(p)
flag_check_false_count=0
flag_check=False
while not flag_check and flag_check_false_count<10:
flag_check_false_count+=1
flag_check=True
            transLTL=self.evaluate_model3(input_prompt)
            if isinstance(transLTL,bool):
                # no LTL formula could be extracted from the model output; retry
                flag_check=False
                continue
            # post-processing must happen after the bool check,
            # otherwise .replace() would fail on a False return value
            transLTL=transLTL.replace('Or','And')
            transLTL=transLTL.replace('Globally','Finally')
            if not self.ltlChecker.AP_CorrCheck(input_prompt,transLTL):
                print('AP_CorrCheck false')
                flag_check=False
            elif not self.ltlChecker.brackets_Check(transLTL):
                print('brackets_Check false')
                flag_check=False
# print(p)
return replacer.reP2Task(transLTL)
if __name__=="__main__":
# translater=Mistral_NL2TL_translator()
# test_prompts=[
# "Task_1.1.1 must precede Task_1.1.2, which in turn should precede Task_1.1.3, ",
# "Task_1.1 must be completed before Task_1.2 starts, and Task_1.2 must be completed before Task_1.3 starts." ,
# "Task_1.1 can be executed independently, after which Task_1.2 can be executed.",
# "Task_1.2.4 must be completed first, followed by Task_1.2.2, then Task_1.2.3, and finally Task_1.2.1.",
# "Task_1.2.4 is always executed first, followed by Task_1.2.3, then Task_1.2.2, and finally Task_1.2.1.",
# "Task_1.2.1 and Task_1.2.2 can be executed independently, and both should eventually be completed.",
# ]
# for ret in test_prompts:
# print(translater.translate(ret))
# print('\n','-'*20,'\n')
# exit()
class p2preplacer():
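        """Identity replacer for evaluation: inputs and outputs are passed through
        unchanged, i.e. no Task_x.y <-> proposition mapping is applied."""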
def reTask2P(self,input):
return input
def reP2Task(self,input):
return input
translater=Mistral_NL2TL_translator(replacer=p2preplacer)
import evaluate
import numpy as np
# from datasets import load_from_disk
from tqdm import tqdm
# Metric
metric = evaluate.load("rouge")
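    # ROUGE serves as a string-overlap proxy for translation quality on the test split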
datapath='path/to/NL2TL-dataset/collect2'
tokenized_dataset = load_dataset("json", data_files={"train":os.path.join(datapath,"ltl_eng_train_mid_ascii_gptAuged.jsonl"),"test":os.path.join(datapath,"ltl_eng_test_mid_ascii_gptAuged.jsonl")})
print(tokenized_dataset)
# run predictions
# this can take ~45 minutes
import re
# pattern=re.compile("\[Formal LTL\]:\n([\S ]*)\n")
predictions, references,input_sentence,output_sentence=[], [] , [], []
# with open()
for idx in range(len(tokenized_dataset['test']['natural'])):
# print(sample)
nl=tokenized_dataset['test']['natural'][idx]
transLTL=translater.translate(nl)
# p = translater.evaluate_model(nl)
# # print(p,l)
input_sentence.append(nl)
# transLTL=pattern.findall(p)
# # print(p)
predictions.append(transLTL)
# output_sentence.append(p)
# input_sentence.append(nl)
references.append(tokenized_dataset['test']['raw_ltl'][idx].strip())
print(idx,'\n',input_sentence[-1],
# '\nout::\n',output_sentence[-1],
'\npre::\n',predictions[-1],
'\nref::\n',references[-1],'\n','-'*20,'\n')
# compute metric
    rouge = metric.compute(predictions=predictions, references=references, use_stemmer=True)
    # print results
    print(f"Rouge1: {rouge['rouge1']* 100:2f}%")
    print(f"rouge2: {rouge['rouge2']* 100:2f}%")
    print(f"rougeL: {rouge['rougeL']* 100:2f}%")
    print(f"rougeLsum: {rouge['rougeLsum']* 100:2f}%")
    eval_output=np.array([input_sentence,predictions,references]).T
    import pandas as pd
    eval_output=pd.DataFrame(eval_output)
    eval_output.to_csv("path/to/model_weight/mistral7b_mid_ascii_0327_eos_2aug1_quat8"+'/output')
    # results from the LLaMA model
    # Rouge1: 98.363321%
    # rouge2: 95.987820%
    # rougeL: 97.384820%
    # rougeLsum: 97.382071%
    # results from this model
    # Rouge1: 98.543297%
    # rouge2: 96.575248%
    # rougeL: 97.720560%
    # rougeLsum: 97.724880%
exit()
    # simple interactive mode: read lines from stdin until EOF; an empty input exits
    flag=True
    while flag:
        lines=[]
        try:
            while True:
                lines.append(input())
        except EOFError:
            pass
        ret="".join(lines)
        print(ret)
        if ret=="":
            flag=False
            continue
        print(translater.translate(ret))