|
import json
import os
import sys
from contextlib import ExitStack

import numpy as np

from ... import utils as util
|
class DataPreprocess():
    """Build JSONL train/test splits from line-aligned LTL and English text files.

    Line ``k`` of the LTL file(s) is paired with line ``k`` of the English
    file(s). All files live under ``data_path``.

    Attributes set by the methods:
        ltl, eng: ``numpy`` string arrays of the raw lines (trailing newline
            kept, exactly as returned by ``readlines``).
        idx: the shuffled index permutation used for the last split.
    """

    # Input files read by ``txtdataReader`` (relative to ``data_path``).
    LTL_FILES = ('Cleaned_LTL.txt',)
    ENG_FILES = ('Cleaned_ENG.txt',)

    # Output files written by ``JSONWriter`` (relative to ``data_path``).
    TRAIN_FILE = "ltl_eng_train_mid_ascii_gptAuged.jsonl"
    TEST_FILE = "ltl_eng_test_mid_ascii_gptAuged.jsonl"

    def __init__(self, data_path="LTL_datasets/collect", train_valid_split=0.1) -> None:
        """Store the dataset directory and the held-out fraction.

        Args:
            data_path: directory containing the input text files; outputs are
                written to the same directory.
            train_valid_split: fraction of pairs written to the test file;
                the remaining ``1 - train_valid_split`` go to the train file.

        Raises:
            ValueError: if ``train_valid_split`` is outside ``[0, 1]``.
        """
        if not 0.0 <= train_valid_split <= 1.0:
            raise ValueError(
                "train_valid_split must be within [0, 1], got {}".format(train_valid_split)
            )
        self.data_path = data_path
        self.train_valid_split = train_valid_split

    def _read_lines(self, filenames):
        """Return the concatenated lines of ``filenames`` as a numpy array."""
        lines = []
        for filename in filenames:
            with open(os.path.join(self.data_path, filename), encoding="utf-8") as txt:
                lines += txt.readlines()
        return np.array(lines)

    def _write_jsonl(self, filename, indices):
        """Write one JSON object per index in ``indices`` to ``filename``."""
        with open(os.path.join(self.data_path, filename), "w", encoding="utf-8") as f:
            for i in indices:
                # "id" is the position of the pair in the (unshuffled) input.
                json.dump(
                    {"natural": self.eng[i], "raw_ltl": self.ltl[i], "id": str(i)},
                    f,
                )
                f.write('\n')

    def txtdataReader(self):
        """Load the LTL and English lines into ``self.ltl`` / ``self.eng``.

        Prints the number of loaded pairs.

        Raises:
            ValueError: if the two sides do not have the same number of lines,
                since pairing is purely positional and a mismatch would
                silently misalign or truncate the data.
        """
        self.ltl = self._read_lines(self.LTL_FILES)
        self.eng = self._read_lines(self.ENG_FILES)

        if len(self.ltl) != len(self.eng):
            raise ValueError(
                "LTL/ENG line count mismatch: {} vs {}".format(len(self.ltl), len(self.eng))
            )

        print(len(self.ltl))

    def JSONdataCreate(self):
        """Read the text files and write the train/test JSONL files."""
        self.txtdataReader()
        self.JSONWriter()

    def JSONWriter(self):
        """Shuffle the loaded pairs and write the train and test JSONL files.

        Requires ``txtdataReader`` to have been called first. The shuffle is
        reproducible: the global numpy RNG is seeded with 42 before shuffling.
        """
        np.random.seed(42)

        self.idx = np.arange(len(self.ltl))
        np.random.shuffle(self.idx)

        # First `split_at` shuffled pairs go to train, the rest to test.
        split_at = int(len(self.ltl) * (1 - self.train_valid_split))
        self._write_jsonl(self.TRAIN_FILE, self.idx[:split_at])
        self._write_jsonl(self.TEST_FILE, self.idx[split_at:])

    def dataCheck(self):
        """Sort loaded pairs into "Cleaned" / "UNCleaned" files via ``LTLChecker``.

        Pairs accepted by ``checker.AP_CorrCheck(ltl, eng)`` are appended to
        ``Cleaned_LTL.txt`` / ``Cleaned_ENG.txt``; rejected pairs are appended
        to ``UNCleaned_LTL.txt`` / ``UNCleaned_ENG.txt`` and their row number
        to ``UNCleaned_num.txt``. The semantics of ``AP_CorrCheck`` are
        defined in the project's ``utils`` module (not visible here).
        """
        self.txtdataReader()
        checker = util.LTLChecker()

        # NOTE(review): txtdataReader currently reads Cleaned_*.txt, and the
        # passing rows are appended to those same files, so every run
        # duplicates them. Confirm which input files were intended.
        with ExitStack() as stack:

            def open_append(name):
                return stack.enter_context(
                    open(os.path.join(self.data_path, name), "a", encoding="utf-8")
                )

            passed_ltl = open_append("Cleaned_LTL.txt")
            passed_eng = open_append("Cleaned_ENG.txt")
            unpassed_row = open_append("UNCleaned_num.txt")
            unpassed_ltl = open_append("UNCleaned_LTL.txt")
            unpassed_eng = open_append("UNCleaned_ENG.txt")

            # Lines still carry their original trailing newline, so they are
            # written back verbatim.
            for row, (ltl_line, eng_line) in enumerate(zip(self.ltl, self.eng)):
                if checker.AP_CorrCheck(ltl_line, eng_line):
                    passed_ltl.write(ltl_line)
                    passed_eng.write(eng_line)
                else:
                    unpassed_row.write("{}\n".format(row))
                    unpassed_ltl.write(ltl_line)
                    unpassed_eng.write(eng_line)
|
|
|
if __name__ == "__main__":
    # Script entry point: build the JSONL splits from the default data path.
    preprocessor = DataPreprocess()
    preprocessor.JSONdataCreate()
|
|