# Commit d834d9d (tt-dart): "update readme"
import json
import numpy as np
import os,sys
# sys.path.append(os.path.join(os.path.dirname(__file__), '../../'))
from ... import utils as util
class DataPreprocess():
    """Prepare paired natural-language / LTL training data.

    Reads the cleaned parallel corpora (``Cleaned_LTL.txt`` /
    ``Cleaned_ENG.txt``) from ``data_path``, optionally sanity-checks them
    with ``util.LTLChecker``, and writes deterministic shuffled train/test
    splits as JSONL files.
    """

    def __init__(self, data_path="LTL_datasets/collect") -> None:
        """
        Args:
            data_path: directory containing the Cleaned_*.txt corpora and
                receiving the generated .jsonl split files.
        """
        self.data_path = data_path
        # Fraction of the corpus held out as the test split.
        self.train_valid_split = 0.1

    def _read_lines(self, filenames):
        """Concatenate the lines of every file in *filenames* under data_path.

        Lines keep their trailing newline exactly as read from disk.
        """
        lines = []
        for filename in filenames:
            # `with` closes the file on exit; no explicit close() needed.
            with open(os.path.join(self.data_path, filename)) as txt:
                lines += txt.readlines()
        return lines

    def txtdataReader(self):
        """Load the LTL and English corpora into parallel numpy arrays.

        Sets ``self.ltl`` and ``self.eng``; the two arrays are expected to be
        line-aligned (row i of one translates row i of the other).
        """
        LTL_list = ['Cleaned_LTL.txt']
        ENG_list = ['Cleaned_ENG.txt']
        self.ltl = np.array(self._read_lines(LTL_list))
        self.eng = np.array(self._read_lines(ENG_list))
        print(len(self.ltl))

    def JSONdataCreate(self):
        """Read the corpora and write the shuffled train/test JSONL files."""
        self.txtdataReader()
        self.JSONWriter()

    def _dump_rows(self, filename, rows):
        """Write one JSON object per shuffled row index in *rows* to filename."""
        with open(self.data_path + "/" + filename, "w") as f:
            for i in rows:
                json.dump({"natural": self.eng[self.idx[i]],
                           "raw_ltl": self.ltl[self.idx[i]],
                           "id": str(self.idx[i])}, f)
                f.write('\n')

    def JSONWriter(self):
        """Shuffle the corpus deterministically (seed 42) and emit the splits.

        The first (1 - train_valid_split) fraction of the shuffled order goes
        to the train file, the remainder to the test file.
        """
        np.random.seed(42)
        self.idx = np.arange(len(self.ltl))
        np.random.shuffle(self.idx)
        split = int(len(self.ltl) * (1 - self.train_valid_split))
        self._dump_rows("ltl_eng_train_mid_ascii_gptAuged.jsonl", range(split))
        self._dump_rows("ltl_eng_test_mid_ascii_gptAuged.jsonl",
                        range(split, len(self.ltl)))

    def dataCheck(self):
        """Split the corpora into cleaned/uncleaned files via util.LTLChecker.

        NOTE(review): output files are opened in append mode, so repeated runs
        accumulate lines; also the Cleaned_*.txt inputs read by txtdataReader
        are the same files appended to here — presumably this is run once on a
        raw corpus. Verify before re-running.
        """
        self.txtdataReader()
        checker = util.LTLChecker()
        with open(os.path.join(self.data_path, "Cleaned_LTL.txt"), "a") as passed_LTL, \
             open(os.path.join(self.data_path, "Cleaned_ENG.txt"), "a") as passed_ENG, \
             open(os.path.join(self.data_path, "UNCleaned_num.txt"), "a") as unpassed_row, \
             open(os.path.join(self.data_path, "UNCleaned_LTL.txt"), "a") as unpassed_LTL, \
             open(os.path.join(self.data_path, "UNCleaned_ENG.txt"), "a") as unpassed_ENG:
            # `row` instead of `id` — avoid shadowing the builtin.
            for row in range(len(self.ltl)):
                if checker.AP_CorrCheck(self.ltl[row], self.eng[row]):
                    passed_LTL.write(self.ltl[row])
                    passed_ENG.write(self.eng[row])
                else:
                    unpassed_row.write("{}\n".format(row))
                    unpassed_LTL.write(self.ltl[row])
                    unpassed_ENG.write(self.eng[row])
if __name__ == "__main__":
    # Build the shuffled train/test JSONL splits from the cleaned corpora.
    preprocessor = DataPreprocess()
    preprocessor.JSONdataCreate()