"""NL2TL: translate natural-language task descriptions into Signal Temporal
Logic (STL) with a fine-tuned T5 sequence-to-sequence model."""

import json
import time

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

import utils.util as util
|
class NL2TL:
    def __init__(self, dirpath='outputdir/') -> None:
        self.output_dir = dirpath

        # T5 base checkpoint and the task prefix the model was fine-tuned with.
        self.model_checkpoint = "t5-base"
        self.prefix = "Transform the following sentence into Signal Temporal logic: "

        self.max_input_length = 1024
        self.max_target_length = 128
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_checkpoint, model_max_length=self.max_input_length)

        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        # Load the fine-tuned seq2seq weights from the training output directory.
        self.tl_model = AutoModelForSeq2SeqLM.from_pretrained(
            self.output_dir + "checkpoint-62500").to(self.device)

        # Warm-up translation: run one example through the full pipeline so the
        # first real query does not pay the one-time model/CUDA startup cost.
        time_start = time.time()
        inputs = [self.prefix + 'At some point (prop_1), and at some point (prop_2), and always do not (prop_4).']
        inputs = self.tokenizer(inputs, max_length=self.max_input_length,
                                truncation=True, return_tensors="pt").to(self.device)
        output = self.tl_model.generate(**inputs, num_beams=8, do_sample=True,
                                        max_length=self.max_target_length)
        decoded_output = self.tokenizer.batch_decode(output, skip_special_tokens=True)[0]
        print(decoded_output)
        print('Translation time: ', time.time() - time_start)
        print('\nNL2TL init\n')

        self.splitJSONfromTXT = util.splitJSONfromTXT
|
    def translate(self, inputNLtxt: str = ""):
        # The model was fine-tuned on "prop_i" atomic propositions, so map the
        # planner-side "Task_i" names into that vocabulary first.
        sentence = inputNLtxt.replace("Task_", "prop_")

        inputs = [self.prefix + sentence]
        inputs = self.tokenizer(inputs, max_length=self.max_input_length,
                                truncation=True, return_tensors="pt").to(self.device)
        output = self.tl_model.generate(**inputs, num_beams=8, do_sample=True,
                                        max_length=self.max_target_length)
        decoded_output = self.tokenizer.batch_decode(output, skip_special_tokens=True)[0]
        print('Input sentence: ', sentence)
        print('Translated STL: ', decoded_output)
        print('\n')

        # Map proposition names back to the planner's "Task_i" scheme.
        return decoded_output.replace('prop_', 'Task_')
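
    # Hypothetical usage sketch (illustrative only; the output shown is not from
    # the trained checkpoint):
    #   nl2tl = NL2TL()
    #   stl = nl2tl.translate("At some point (Task_1), and always do not (Task_2).")
    #   # stl might resemble "F (Task_1) & G (~ (Task_2))"; the exact formula
    #   # depends on the fine-tuned model's output vocabulary.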
|
    def waiting(self):
        # Interactive loop: fetch text from the GPT interface, extract the embedded
        # JSON, translate its "LTL_description" field, and print the STL. "q" quits.
        retry = True
        while retry:
            inputNL = util.GPTinterface("continue next")
            if inputNL != "q":
                Json = self.splitJSONfromTXT(inputNL)
                print(Json)
                jsonTree = json.loads("{" + Json[-1] + "}")
                # translate() handles the Task_/prop_ renaming in both directions.
                output_TL = self.translate(jsonTree["LTL_description"])
                print("\n", output_TL, "\n")
            else:
                retry = False
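
    # Assumed input format for waiting() (illustrative only): the GPT interface
    # text is expected to embed a JSON object whose "LTL_description" field
    # carries the sentence to translate, e.g.
    #   ... {"LTL_description": "At some point (Task_1), and at some point (Task_2)."} ...
    # splitJSONfromTXT appears to return the JSON body without its surrounding
    # braces, which is why they are re-added before json.loads above.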
|
if __name__ == "__main__":
    interface = NL2TL()
    interface.waiting()