|  |  | 
					
						
						|  | from transformers import (AutoModelForSeq2SeqLM, | 
					
						
						|  | AutoTokenizer, | 
					
						
						|  | T5Tokenizer) | 
					
						
						|  | import torch | 
					
						
						|  | import pandas as pd | 
					
						
						|  | from datasets import Dataset, DatasetDict, load_dataset, load_from_disk | 
					
						
						|  | from tqdm import tqdm | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | import sys | 
					
						
						|  | import os | 
					
						
						|  | import argparse | 
					
						
						|  |  | 
					
						
						|  | import random | 
					
						
						|  | import numpy as np | 
					
						
						|  | import nltk | 
					
						
						|  | import json | 
					
						
						|  | import csv | 
					
						
						|  | import utils.util as util | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class NL2TL(): | 
					
						
						|  | def __init__(self,dirpath='outputdir/') -> None: | 
					
						
						|  | self.output_dir = dirpath | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | self.model_checkpoint = "t5-base" | 
					
						
						|  | self.prefix = "Transform the following sentence into Signal Temporal logic: " | 
					
						
						|  |  | 
					
						
						|  | self.max_input_length = 1024 | 
					
						
						|  | self.max_target_length = 128 | 
					
						
						|  | self.tokenizer = AutoTokenizer.from_pretrained(self.model_checkpoint, model_max_length=self.max_input_length) | 
					
						
						|  |  | 
					
						
						|  | self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | 
					
						
						|  | self.tl_model = AutoModelForSeq2SeqLM.from_pretrained(self.output_dir+"checkpoint-62500").to(self.device) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | import time | 
					
						
						|  | self.time_start = time.time() | 
					
						
						|  | self.inputs = [self.prefix + 'At some point (prop_1), and at some point (prop_2), and always do not (prop_4).'] | 
					
						
						|  | self.inputs = self.tokenizer(self.inputs, max_length=self.max_input_length, truncation=True, return_tensors="pt").to(self.device) | 
					
						
						|  | self.output = self.tl_model.generate(**self.inputs, num_beams=8, do_sample=True, max_length=self.max_target_length) | 
					
						
						|  | self.decoded_output = self.tokenizer.batch_decode(self.output, skip_special_tokens=True)[0] | 
					
						
						|  | print(self.decoded_output) | 
					
						
						|  | self.time_end = time.time() | 
					
						
						|  | print('Translation time: ', self.time_end-self.time_start) | 
					
						
						|  | print('\nNL2TL init\n') | 
					
						
						|  | self.splitJSONfromTXT=util.splitJSONfromTXT | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | pass | 
					
						
						|  | def translate(self,inputNLtxt:str=""): | 
					
						
						|  | inputNLtxt=inputNLtxt.replace("Task_","prop_") | 
					
						
						|  |  | 
					
						
						|  | sentence=inputNLtxt | 
					
						
						|  | self.inputs = [self.prefix + sentence] | 
					
						
						|  | self.inputs = self.tokenizer(self.inputs, max_length=self.max_input_length, truncation=True, return_tensors="pt").to(self.device) | 
					
						
						|  | self.output = self.tl_model.generate(**self.inputs, num_beams=8, do_sample=True, max_length=self.max_target_length) | 
					
						
						|  | self.decoded_output = self.tokenizer.batch_decode(self.output, skip_special_tokens=True)[0] | 
					
						
						|  | print('Input sentence: ', sentence) | 
					
						
						|  | print('Translated STL: ', self.decoded_output) | 
					
						
						|  | print('\n') | 
					
						
						|  |  | 
					
						
						|  | self.decoded_output=self.decoded_output.replace('prop_','Task_') | 
					
						
						|  | return self.decoded_output | 
					
						
						|  | def waiting(self): | 
					
						
						|  | retry=True | 
					
						
						|  | while retry: | 
					
						
						|  | inputNL=util.GPTinterface("continue next") | 
					
						
						|  | if inputNL!="q": | 
					
						
						|  | Json=self.splitJSONfromTXT(inputNL) | 
					
						
						|  | print(Json) | 
					
						
						|  | jsonTree=json.loads("{"+Json[-1]+"}") | 
					
						
						|  | input_NL=jsonTree["LTL_description"].replace("Task_","prop_") | 
					
						
						|  | output_TL=self.translate(input_NL) | 
					
						
						|  | output_TL=output_TL.replace('prop_','Task_') | 
					
						
						|  | print("\n",output_TL,"\n") | 
					
						
						|  | else: | 
					
						
						|  | retry =False | 
					
						
						|  | if __name__=="__main__": | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | interface=NL2TL() | 
					
						
						|  | interface.waiting() | 
					
						
						|  |  | 
					
						
						|  |  |