tt-dart's picture
update readme
d834d9d
# %%
from transformers import (AutoModelForSeq2SeqLM,
AutoTokenizer,
T5Tokenizer)
import torch
import pandas as pd
from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
from tqdm import tqdm
#import subprocess
import sys
import os
import argparse
# from IPython.core import error
import random
import numpy as np
import nltk
import json
import csv
import utils.util as util
# run under conda env minigpt4
class NL2TL:
    """Translate natural-language task descriptions into temporal-logic formulas.

    The class loads the ``t5-base`` tokenizer and a fine-tuned seq2seq
    checkpoint stored in ``<dirpath>/checkpoint-62500``, then exposes
    :meth:`translate` for single sentences and :meth:`waiting` for a loop
    driven by ``util.GPTinterface``.

    Callers use ``Task_<n>`` placeholders; they are rewritten to
    ``prop_<n>`` before being sent to the model and restored afterwards.
    """

    # Name of the sub-directory (inside ``dirpath``) that holds the fine-tuned
    # model weights. The weights have to be downloaded separately.
    _CHECKPOINT_NAME = "checkpoint-62500"

    def __init__(self, dirpath='outputdir/') -> None:
        """Load tokenizer and model, then run one timed sample translation.

        Args:
            dirpath: Directory containing the ``checkpoint-62500`` folder with
                the model weights. A trailing path separator is optional.
        """
        # ``time`` is not among the file-level imports, so it is imported here.
        import time

        self.output_dir = dirpath
        self.model_checkpoint = "t5-base"
        self.prefix = "Transform the following sentence into Signal Temporal logic: "
        self.max_input_length = 1024
        self.max_target_length = 128
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_checkpoint, model_max_length=self.max_input_length
        )
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.tl_model = AutoModelForSeq2SeqLM.from_pretrained(
            self._checkpoint_path(self.output_dir)
        ).to(self.device)

        # Run one sample sentence through the model and report how long it took.
        self.time_start = time.time()
        print(self._generate(
            'At some point (prop_1), and at some point (prop_2), and always do not (prop_4).'
        ))
        self.time_end = time.time()
        print('Translation time: ', self.time_end - self.time_start)
        print('\nNL2TL init\n')
        self.splitJSONfromTXT = util.splitJSONfromTXT

    @staticmethod
    def _checkpoint_path(dirpath: str) -> str:
        """Return the checkpoint directory inside *dirpath*.

        Uses ``os.path.join`` so that *dirpath* works with or without a
        trailing separator (plain string concatenation would turn
        ``'outputdir'`` into ``'outputdircheckpoint-62500'``).
        """
        return os.path.join(dirpath, NL2TL._CHECKPOINT_NAME)

    @staticmethod
    def _to_model_placeholders(text: str) -> str:
        """Rewrite caller-side ``Task_`` placeholders to the model's ``prop_``."""
        return text.replace("Task_", "prop_")

    @staticmethod
    def _to_task_placeholders(text: str) -> str:
        """Rewrite the model's ``prop_`` placeholders back to ``Task_``."""
        return text.replace('prop_', 'Task_')

    def _generate(self, sentence: str) -> str:
        """Tokenize ``prefix + sentence``, run generation and decode the result.

        Stores the intermediate results on ``self.inputs``, ``self.output``
        and ``self.decoded_output`` and returns the decoded string.
        """
        self.inputs = self.tokenizer(
            [self.prefix + sentence],
            max_length=self.max_input_length,
            truncation=True,
            return_tensors="pt",
        ).to(self.device)
        # NOTE(review): do_sample=True makes the output non-deterministic
        # between calls; kept as in the original configuration.
        self.output = self.tl_model.generate(
            **self.inputs,
            num_beams=8,
            do_sample=True,
            max_length=self.max_target_length,
        )
        self.decoded_output = self.tokenizer.batch_decode(
            self.output, skip_special_tokens=True
        )[0]
        return self.decoded_output

    def translate(self, inputNLtxt: str = "") -> str:
        """Translate one natural-language sentence and return the formula.

        Args:
            inputNLtxt: Sentence to translate; ``Task_`` placeholders are
                accepted and converted to ``prop_`` for the model.

        Returns:
            The decoded model output with ``prop_`` replaced by ``Task_``.
            The same value is stored in ``self.decoded_output``.
        """
        sentence = self._to_model_placeholders(inputNLtxt)
        translated = self._generate(sentence)
        print('Input sentence: ', sentence)
        print('Translated STL: ', translated)
        print('\n')
        self.decoded_output = self._to_task_placeholders(translated)
        return self.decoded_output

    def waiting(self) -> None:
        """Repeatedly fetch text via ``util.GPTinterface`` and translate it.

        Each response is split with ``self.splitJSONfromTXT``; the last
        returned piece is wrapped in braces, parsed as JSON, and its
        ``"LTL_description"`` entry is passed to :meth:`translate`. The loop
        ends when the response is exactly ``"q"``.
        """
        while True:
            inputNL = util.GPTinterface("continue next")
            if inputNL == "q":
                break
            Json = self.splitJSONfromTXT(inputNL)
            print(Json)
            # presumably splitJSONfromTXT yields JSON bodies without the outer
            # braces - confirm against utils.util.
            jsonTree = json.loads("{" + Json[-1] + "}")
            # translate() already converts Task_ <-> prop_ in both directions,
            # so no extra replacement is needed here.
            output_TL = self.translate(jsonTree["LTL_description"])
            print("\n", output_TL, "\n")
if __name__ == "__main__":
    # Script entry point: build the translator and hand control to its
    # waiting() loop.
    #
    # For a quick batch check without that loop, the sentences below can be
    # passed one by one to NL2TL.translate():
    #   'Stay at (prop_1) for 5 units in the future and stay at (prop_2) for 5 units in the future, and ensure that never (prop_3).'
    #   'First (prop_1), and then (prop_2), and ensure that never (prop_3).'
    #   'Start by (prop_1). Then, (prop_2). Lastly, (prop_3).'
    #   'Guarantee that you (prop_1) and (prop_2)'
    #   '( prop_1 ) and whenever ( prop_2 )'
    #   'Sooner or later (prop_1)'
    #   'Repeatedly (prop_1)'
    #   'At some point, (prop_1).'
    #   'Do prop_1 but not do prop_2'
    #   'Do prop_1, do prop_2, do prop_3'
    translator = NL2TL()
    translator.waiting()