# exbert/server/data_processing/create_hdf5.py
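"""Precompute transformer attention data for a corpus and store it as HDF5.

Reads a pickle of unique sentences extracted from a corpus, runs each sentence
through a pretrained transformer, and writes the resulting attention data into
one HDF5 group per sentence.

Example invocation (file and directory names are illustrative):

    python create_hdf5.py -f sentences.pckl -o processed/ -m bert-base-cased --force
"""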
import numpy as np
import torch
import h5py
import pickle
import argparse
from pathlib import Path
from data_processing.sentence_extracting import extract_chars, extract_lines
from data_processing.corpus_data_wrapper import CorpusDataWrapper, to_idx
from transformer_details import from_pretrained
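
# Sentences shorter than this many characters are skipped as likely fragments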
MIN_SENTENCE_CHARLEN = 24


def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("-f", "--file", help="Path to .pckl file of unique sentences from a corpus.")
parser.add_argument("-o", "--outdir", help="Path of directory in which to store the analyzed sentences as a .hdf5")
parser.add_argument("-m", "--model", default="bert-base-cased", help="Which pretrained transformer model to use. See 'transformer_details.py' for supported models")
    parser.add_argument("--nomask", action='store_false', help="By default, attentions to special tokens like '[CLS]' and '[SEP]' are masked out. Pass this flag to include them.")
parser.add_argument("--force", action="store_true", help="If given, overwrite existing hdf5 files.")
args = parser.parse_args()
return args


def main(infile, outdir, force, model_name, mask_attentions):
outdir = Path(outdir)
outdir.mkdir(parents=True, exist_ok=True)
data_outfile = outdir / "data.hdf5"
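    # Open in append mode so the file can be created or extended across runs;
    # --force clears any previously written groups first.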
f = h5py.File(data_outfile, 'a')
if force: f.clear()
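    # The extractor exposes a spaCy pipeline (extractor.aligner.spacy_nlp),
    # used below to split each corpus chunk into sentences.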
extractor = from_pretrained(model_name)
# if "gpt" in model_name:
# mask_attentions = False
print_every = 50
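    # Stream the corpus in chunks (10,000 characters at a time) rather than
    # loading it all into memory; sentences are recovered per chunk below.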
long_strings = extract_chars(infile, 10000)
cutoff_sent = ""
i = 0
for strip in long_strings:
sentences = [sent.text for sent in extractor.aligner.spacy_nlp(strip).sents]
        # The first sentence of this chunk may be the tail of a sentence cut off
        # at the end of the previous chunk, so prepend the carried-over text.
        # This can occasionally produce an input spanning two sentences; that is ok.
        fixed_sentences = [cutoff_sent + sentences[0]] + sentences[1:-1]
        # Hold back the final (possibly incomplete) sentence for the next chunk.
        cutoff_sent = sentences[-1]
for s in fixed_sentences:
if len(s) < MIN_SENTENCE_CHARLEN: continue
if ((i + 1) % print_every) == 0: print(f"Starting sentence {i+1}: \n", s)
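            # Run the sentence through the model to collect attention data;
            # if anything goes wrong, log the error and skip the sentence.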
try:
out = extractor.att_from_sentence(s, mask_attentions=mask_attentions)
except Exception as e:
                print(f"Error {e} occurred at sentence {i}:\n{s}\n\nSkipping; not creating hdf5 group")
continue
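            # Each analyzed sentence becomes its own HDF5 group: array payloads
            # are stored as datasets, scalar metadata as group attributes.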
content = out.to_hdf5_content()
meta = out.to_hdf5_meta()
grp = f.create_group(to_idx(i))
for k,v in content.items(): grp.create_dataset(k, data=v)
for k, v in meta.items(): grp.attrs[k] = v
i += 1 # Increment to mark the next sentence
print("FINISHED CORPUS PROCESSING SUCCESSFULLY")


if __name__ == "__main__":
args = parse_args()
main(args.file, args.outdir, args.force, args.model, args.nomask)