# File size: 3,834 bytes (revision bacb17b)
# %%
from transformers import (AutoModelForSeq2SeqLM, 
                            AutoTokenizer, 
                            T5Tokenizer)
import torch
import pandas as pd
from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
from tqdm import tqdm

#import subprocess
import sys
import os
import argparse
# from IPython.core import error
import random
import numpy as np
import nltk
import json
import csv
import utils.util as util

# run under conda env minigpt4

class NL2TL():
    """Translate natural-language task descriptions into Signal Temporal Logic (STL).

    On construction this loads the "t5-base" tokenizer and a fine-tuned
    seq2seq checkpoint from ``<dirpath>/checkpoint-62500``. It moves the model
    to ``cuda:0`` when CUDA is available (CPU otherwise) and runs one timed
    warm-up translation.

    Callers name propositions ``Task_<n>``; the model is prompted with
    ``prop_<n>``. :meth:`translate` maps between the two on the way in and
    on the way out.

    Note: generation uses beam search with ``do_sample=True``, so repeated
    calls on the same sentence may return different formulas.
    """

    def __init__(self, dirpath='outputdir/') -> None:
        """Load the tokenizer and model, then run a warm-up translation.

        Args:
            dirpath: Directory that contains the ``checkpoint-62500`` model
                weights (downloadable from the project's GitHub page).
                A trailing path separator is optional.
        """
        import time

        self.output_dir = dirpath
        self.model_checkpoint = "t5-base"
        self.prefix = "Transform the following sentence into Signal Temporal logic: "

        self.max_input_length = 1024
        self.max_target_length = 128
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_checkpoint, model_max_length=self.max_input_length)

        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        # os.path.join works whether or not dirpath ends with a separator;
        # plain string concatenation silently produced a wrong path without one.
        checkpoint_dir = os.path.join(self.output_dir, "checkpoint-62500")
        self.tl_model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint_dir).to(self.device)

        # Warm-up: translate one sample sentence so load/device problems show
        # up at construction time, and report how long a translation takes.
        self.time_start = time.time()
        print(self._generate('At some point (prop_1), and at some point (prop_2), and always do not (prop_4).'))
        self.time_end = time.time()
        print('Translation time: ', self.time_end - self.time_start)
        print('\nNL2TL init\n')
        self.splitJSONfromTXT = util.splitJSONfromTXT

    @staticmethod
    def _to_model_props(text):
        """Rename caller-facing ``Task_`` placeholders to the model's ``prop_``."""
        return text.replace("Task_", "prop_")

    @staticmethod
    def _from_model_props(text):
        """Rename the model's ``prop_`` placeholders back to ``Task_``."""
        return text.replace('prop_', 'Task_')

    @staticmethod
    def _parse_ltl_description(fragments):
        """Return the "LTL_description" string from the last JSON fragment.

        Args:
            fragments: Result of ``splitJSONfromTXT``. The last element is
                wrapped in braces and parsed as a JSON object, so it is
                presumably an object body without its outer braces.
                TODO: confirm against ``util.splitJSONfromTXT``.

        Raises:
            ValueError: If ``fragments`` is empty, the last fragment is not
                valid JSON (``json.JSONDecodeError`` is a ValueError), or it
                has no string "LTL_description" field.
        """
        if not fragments:
            raise ValueError("no JSON object found in the reply")
        tree = json.loads("{" + fragments[-1] + "}")
        description = tree.get("LTL_description")
        if not isinstance(description, str):
            raise ValueError('reply JSON has no string "LTL_description" field')
        return description

    def _generate(self, sentence):
        """Run the model on ``self.prefix + sentence`` and return the decoded text.

        The tokenized inputs, raw output ids and decoded string are also
        stored on ``self.inputs``, ``self.output`` and ``self.decoded_output``.
        """
        self.inputs = self.tokenizer([self.prefix + sentence], max_length=self.max_input_length, truncation=True, return_tensors="pt").to(self.device)
        self.output = self.tl_model.generate(**self.inputs, num_beams=8, do_sample=True, max_length=self.max_target_length)
        self.decoded_output = self.tokenizer.batch_decode(self.output, skip_special_tokens=True)[0]
        return self.decoded_output

    def translate(self, inputNLtxt: str = "") -> str:
        """Translate one natural-language sentence into an STL formula.

        Args:
            inputNLtxt: Sentence whose propositions are written ``Task_<n>``
                (``prop_<n>`` is accepted too).

        Returns:
            The STL formula with propositions written ``Task_<n>``.
        """
        sentence = self._to_model_props(inputNLtxt)
        stl = self._generate(sentence)
        print('Input sentence: ', sentence)
        print('Translated STL: ', stl)
        print('\n')
        self.decoded_output = self._from_model_props(stl)
        return self.decoded_output

    def waiting(self):
        """Translate replies from ``util.GPTinterface`` until the reply is "q".

        Each reply is split into JSON fragments, and its "LTL_description"
        field is translated and printed. A malformed reply is reported and
        skipped instead of ending the session.
        """
        while True:
            reply = util.GPTinterface("continue next")
            if reply == "q":
                break
            fragments = self.splitJSONfromTXT(reply)
            print(fragments)
            try:
                description = self._parse_ltl_description(fragments)
            except ValueError as err:
                print("Could not extract LTL_description from reply:", err)
                continue
            output_TL = self.translate(description)
            print("\n", output_TL, "\n")
if __name__ == "__main__":
    # Batch mode (kept for reference) translates a fixed list of sample sentences:
    #
    #   examples = [
    #       'Stay at (prop_1) for 5 units in the future and stay at (prop_2) for 5 units in the future, and ensure that never (prop_3).',
    #       'First (prop_1), and then (prop_2), and ensure that never (prop_3).',
    #       'Start by (prop_1). Then, (prop_2). Lastly, (prop_3).',
    #       'Guarantee that you (prop_1) and (prop_2)',
    #       '( prop_1 ) and whenever ( prop_2 )',
    #       'Sooner or later (prop_1)',
    #       'Repeatedly (prop_1)',
    #       'At some point, (prop_1).',
    #       'Do prop_1 but not do prop_2',
    #       'Do prop_1, do prop_2, do prop_3',
    #   ]
    #   translator = NL2TL()
    #   for txt in examples:
    #       translator.translate(txt)

    # Interactive mode: translate GPT replies until "q" is entered.
    translator = NL2TL()
    translator.waiting()