# from huggingface_hub import login
# login()
import sys, os
from datasets import load_dataset
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, AutoTokenizer, TrainingArguments
# from peft import LoraConfig
# from trl import SFTTrainer
# from accelerate import infer_auto_device_map, init_empty_weights

# sys.path.append(os.path.join(os.path.dirname(__file__), '../../'))
from NL2HLTLTranslator.utils.util import Task2Preplacer
from NL2HLTLTranslator.utils.util import LTLChecker
import re
from datasets import concatenate_datasets
import numpy as np
from peft import AutoPeftModelForCausalLM

os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
# os.environ['CUDA_VISIBLE_DEVICES'] = '3'


class Mistral_NL2TL_translator():
    def __init__(self,
                 output_dir=os.path.join(os.path.dirname(__file__), '../../'),
                 tuned_model_name="mistral7b_quat8",
                 # CUDA_device='0',
                 quat=True,
                 replacer=Task2Preplacer) -> None:
        # os.environ['CUDA_VISIBLE_DEVICES'] = CUDA_device
        self.device_map = "auto"
        self.model_dir = os.path.join(output_dir, tuned_model_name)  # check
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # AutoPeftModelForCausalLM.from_pretrained("meta-llama/Llama-2-13b-hf")
        # quantconfig = BitsAndBytesConfig(
        #     load_in_8bit=True,
        #     bnb_8bit_quant_type="nf4",
        #     bnb_8bit_use_double_quant=True,
        #     bnb_8bit_compute_dtype=torch.bfloat16,
        # )
        # if quat == False:
        #     self.model = AutoPeftModelForCausalLM.from_pretrained(self.output_dir, device_map=self.device_map, torch_dtype=torch.bfloat16)
        #     # ICL super man: can be left unquantized
        # else:
        #     self.model = AutoPeftModelForCausalLM.from_pretrained(self.output_dir, device_map=self.device_map, torch_dtype=torch.float16,
        #                                                           load_in_8bit=True)
        #     # quantization_config=quantconfig)

        # 4-bit config, kept for reference but immediately overridden by the 8-bit config below
        self.bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=False,
            bnb_4bit_quant_type='nf4',
            bnb_4bit_compute_dtype=getattr(torch, "float16")
        )
        # 8-bit config actually used when loading the model
        self.bnb_config = BitsAndBytesConfig(
            load_in_8bit=True,
            # llm_int8_threshold=200.0
            # bnb_4bit_use_double_quant=False,
            # bnb_4bit_quant_type='nf4',
            # bnb_4bit_compute_dtype=getattr(torch, "float16")
        )
        # self.bnb_config = BitsAndBytesConfig(
        #     load_in_8bit=False,
        #     load_in_4bit=False,
        #     # llm_int8_threshold=200.0
        #     # bnb_4bit_use_double_quant=False,
        #     # bnb_4bit_quant_type='nf4',
        #     # bnb_4bit_compute_dtype=getattr(torch, "float16")
        # )
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_dir,
            from_tf=bool(".ckpt" in self.model_dir),
            quantization_config=self.bnb_config,
            device_map=self.device_map,
            trust_remote_code=True,
            use_auth_token=True
        )
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir)
        # (previously passed: trust_remote_code=True, add_eos_token=True)
        # tokenizer = AutoTokenizer.from_pretrained(base_model_name, add_eos_token=True, trust_remote_code=True)
        # NOTE: it is unclear whether add_eos_token should be set. Without it, generation can
        # run on until max_new_tokens is reached. In inference mode, however, do not pass
        # add_eos_token=True: the tokenizer then appends </s> to the input, which makes the
        # output irregular; with add_eos_token enabled, generation always failed.
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.padding_side = 'right'
        print(self.tokenizer.eos_token_id)  # 2
        print(self.tokenizer.bos_token_id)  # 1
        # print(tokenizer._convert_token_to_id(tokenizer.bos_token))
        print("NL2TL model loaded")
        self.replacer = replacer
        self.ltlChecker = LTLChecker()
        # print('NL2TL llama translate test:')
        # self.translate("Task_1.1 must be done, and Task_1.2 should be finished before Task_1.1")
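    # The evaluate_model* variants below differ mainly in prompt wording and in the regex
    # used to pull the formula out of the chat completion; translate() uses evaluate_model3.
    # The 'temproal' spelling inside the prompts and regexes is kept verbatim, on the
    # assumption that the fine-tuned checkpoint expects exactly this wording.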
    def evaluate_model(self, input_text):
        self.pattern = re.compile(r"linear temproal logic is ([\S ]*).")
        messages = [
            {"role": "user", "content": "translate natural description to linear temproal logic, first translate into a logical way, and then translate into linear temproal logic, pay specific attention to brackets '()', natural language task: {}".format(input_text.strip())},
        ]
        encodeds = self.tokenizer.apply_chat_template(messages, return_tensors="pt").to(self.device)
        outputs = self.model.generate(encodeds, max_new_tokens=512, pad_token_id=self.tokenizer.eos_token_id)
        p = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        print('model output:', p)
        transLTL = self.pattern.findall(p)[0]
        if transLTL[-1] == '.':
            transLTL = transLTL[:-1].strip()
        else:
            transLTL = transLTL.strip()
        transLTL = self.ltlChecker.right_barkets_remover(transLTL)
        print('transLTL:\n', transLTL)
        return transLTL

    def evaluate_model2(self, input_text):
        self.pattern = re.compile(r"LTL is ([\S ]*).")
        messages = [
            {"role": "user", "content": "translate natural description to linear temproal logic, first translate into a logical expression, and then translate into linear temproal logic, the natural language task is {}".format(input_text.strip())},
        ]
        encodeds = self.tokenizer.apply_chat_template(messages, return_tensors="pt").to(self.device)
        outputs = self.model.generate(encodeds, max_new_tokens=512, pad_token_id=self.tokenizer.eos_token_id)
        p = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        print('---model output 1:\n', p)
        # messages = [
        #     {"role": "user", "content": "translate natural description to linear temproal logic, first translate into a logical expression, and then translate into linear temproal logic, the natural language task is {}".format(input_text.strip())},
        #     {"role": "assistant", "content": p},
        #     {"role": "user", "content": " pay specific attention to brackets '()', given your linear temproal logic translation"},
        # ]
        # encodeds = self.tokenizer.apply_chat_template(messages, return_tensors="pt").to(self.device)
        # outputs = self.model.generate(encodeds, max_new_tokens=512, pad_token_id=self.tokenizer.eos_token_id)
        # p = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        # print('---model output 2:\n', p)
        transLTL = self.pattern.findall(p)[0]
        if transLTL[-1] == '.':
            transLTL = transLTL[:-1].strip()
        else:
            transLTL = transLTL.strip()
        transLTL = self.ltlChecker.right_barkets_remover(transLTL)
        print('transLTL:\n', transLTL)
        return transLTL
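    # Illustration (with a made-up model reply; only the regexes come from this file) of why
    # evaluate_model3 bounds the capture at the first period instead of using the greedy
    # pattern from evaluate_model2:
    #
    #   >>> reply = "the LTL is Finally ( P01 ) . I hope this helps."
    #   >>> re.findall(r"LTL is ([\S ]*).", reply)
    #   ['Finally ( P01 ) . I hope this helps']
    #   >>> re.findall(r"LTL is ([^\.]*)\.", reply)
    #   ['Finally ( P01 ) ']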
    def evaluate_model3(self, input_text):
        # e.g. bad reply: "LTL is a larger language model . . . . . . "
        # self.pattern = re.compile(r"LTL is ([\S ]*)\.")
        self.pattern = re.compile(r"LTL is ([^\.]*)\.")
        messages = [
            {"role": "user", "content": "translate natural description to linear temproal logic, first translate into a logical expression, and then translate into linear temproal logic, please pay specific attention to logic grammar, the natural language task is {}".format(input_text.strip())},
        ]
        encodeds = self.tokenizer.apply_chat_template(messages, return_tensors="pt").to(self.device)
        outputs = self.model.generate(encodeds, max_new_tokens=512, pad_token_id=self.tokenizer.eos_token_id)
        p = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        print('---model output 1:\n', p)
        # messages = [
        #     {"role": "user", "content": "translate natural description to linear temproal logic, first translate into a logical expression, and then translate into linear temproal logic, the natural language task is {}".format(input_text.strip())},
        #     {"role": "assistant", "content": p},
        #     {"role": "user", "content": " pay specific attention to brackets '()', given your linear temproal logic translation"},
        # ]
        # encodeds = self.tokenizer.apply_chat_template(messages, return_tensors="pt").to(self.device)
        # outputs = self.model.generate(encodeds, max_new_tokens=512, pad_token_id=self.tokenizer.eos_token_id)
        # p = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        # print('---model output 2:\n', p)
        transLTL = self.pattern.findall(p)
        if len(transLTL) == 0:
            return False
        transLTL = transLTL[0]
        if transLTL[-1] == '.':
            transLTL = transLTL[:-1].strip()
        else:
            transLTL = transLTL.strip()
        transLTL = self.ltlChecker.right_barkets_remover(transLTL)
        print('transLTL:\n', transLTL)
        return transLTL

    def translate(self, input_prompt: str = ""):
        print('input_prompt:\n', input_prompt)
        replacer = self.replacer()
        input_prompt = replacer.reTask2P(input_prompt)
        # print(predicter(replace.reTask2P(input_prompt)))
        # print(input_prompt)
        # print(p)
        flag_check_false_count = 0
        flag_check = False
        while not flag_check and flag_check_false_count < 10:
            flag_check_false_count += 1
            flag_check = True
            transLTL = self.evaluate_model3(input_prompt)
            if isinstance(transLTL, bool):
                # evaluate_model3 returns False when no "LTL is ..." span could be extracted;
                # retry instead of calling str methods on a bool
                flag_check = False
                continue
            transLTL = transLTL.replace('Or', 'And')
            transLTL = transLTL.replace('Globally', 'Finally')
            if not self.ltlChecker.AP_CorrCheck(input_prompt, transLTL):
                print('AP_CorrCheck false')
                flag_check = False
            elif not self.ltlChecker.brackets_Check(transLTL):
                print('brackets_Check false')
                flag_check = False
        # print(p)
        return replacer.reP2Task(transLTL)
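# Assumed behaviour of the helpers imported from NL2HLTLTranslator.utils.util, inferred
# only from how they are used above (not from their implementation):
#   - Task2Preplacer.reTask2P() maps Task_x.y identifiers to atomic-proposition
#     placeholders before prompting; reP2Task() maps them back in the returned formula.
#   - LTLChecker.AP_CorrCheck(nl, ltl) checks that the atomic propositions of the input
#     and of the generated formula correspond.
#   - LTLChecker.brackets_Check(ltl) checks parenthesis balance, and
#     right_barkets_remover(ltl) trims unmatched trailing brackets.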
, # "Task_1.1 can be executed independently, after which Task_1.2 can be executed.", # "Task_1.2.4 must be completed first, followed by Task_1.2.2, then Task_1.2.3, and finally Task_1.2.1.", # "Task_1.2.4 is always executed first, followed by Task_1.2.3, then Task_1.2.2, and finally Task_1.2.1.", # "Task_1.2.1 and Task_1.2.2 can be executed independently, and both should eventually be completed.", # ] # for ret in test_prompts: # print(translater.translate(ret)) # print('\n','-'*20,'\n') # exit() class p2preplacer(): def reTask2P(self,input): return input def reP2Task(self,input): return input translater=Mistral_NL2TL_translator(replacer=p2preplacer) import evaluate import numpy as np # from datasets import load_from_disk from tqdm import tqdm # Metric metric = evaluate.load("rouge") datapath='path/to/NL2TL-dataset/collect2' tokenized_dataset = load_dataset("json", data_files={"train":os.path.join(datapath,"ltl_eng_train_mid_ascii_gptAuged.jsonl"),"test":os.path.join(datapath,"ltl_eng_test_mid_ascii_gptAuged.jsonl")}) print(tokenized_dataset) # run predictions # this can take ~45 minutes import re # pattern=re.compile("\[Formal LTL\]:\n([\S ]*)\n") predictions, references,input_sentence,output_sentence=[], [] , [], [] # with open() for idx in range(len(tokenized_dataset['test']['natural'])): # print(sample) nl=tokenized_dataset['test']['natural'][idx] transLTL=translater.translate(nl) # p = translater.evaluate_model(nl) # # print(p,l) input_sentence.append(nl) # transLTL=pattern.findall(p) # # print(p) predictions.append(transLTL) # output_sentence.append(p) # input_sentence.append(nl) references.append(tokenized_dataset['test']['raw_ltl'][idx].strip()) print(idx,'\n',input_sentence[-1], # '\nout::\n',output_sentence[-1], '\npre::\n',predictions[-1], '\nref::\n',references[-1],'\n','-'*20,'\n') # compute metric rogue = metric.compute(predictions=predictions, references=references, use_stemmer=True) # print results print(f"Rogue1: {rogue['rouge1']* 100:2f}%") print(f"rouge2: {rogue['rouge2']* 100:2f}%") print(f"rougeL: {rogue['rougeL']* 100:2f}%") print(f"rougeLsum: {rogue['rougeLsum']* 100:2f}%") eval_output=np.array([input_sentence,predictions,references]).T import pandas as pd eval_output=pd.DataFrame(eval_output) pd.DataFrame.to_csv(eval_output,"path/to/model_weight/mistral7b_mid_ascii_0327_eos_2aug1_quat8"+'/output') # out llama # Rogue1: 98.363321% # rouge2: 95.987820% # rougeL: 97.384820% # rougeLsum: 97.382071% # this # Rogue1: 98.543297% # rouge2: 96.575248% # rougeL: 97.720560% # rougeLsum: 97.724880% exit() flag=True while flag: lines=[""] try: lines.append(input()) while True: lines.append(input()) except: pass ret ="".join(lines) print(ret) if ret=="": flag=False print(translater.translate(ret))