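"""Analyze a corpus with a pretrained transformer and store the per-sentence
attention data in an HDF5 file ("data.hdf5") inside the chosen output directory.

Example invocation (script name and paths are illustrative, not prescribed):

    python analyze_corpus.py -f unique_sentences.pckl -o out/ -m bert-base-cased
"""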
import numpy as np
import torch
import h5py
import pickle
import argparse
from pathlib import Path

from data_processing.sentence_extracting import extract_chars, extract_lines
from data_processing.corpus_data_wrapper import CorpusDataWrapper, to_idx
from transformer_details import from_pretrained

MIN_SENTENCE_CHARLEN = 24


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("-f", "--file", help="Path to .pckl file of unique sentences from a corpus.")
    parser.add_argument("-o", "--outdir", help="Path of directory in which to store the analyzed sentences as a .hdf5")
    parser.add_argument("-m", "--model", default="bert-base-cased",
                        help="Which pretrained transformer model to use. See 'transformer_details.py' for supported models")
    parser.add_argument("--nomask", action="store_false",
                        help="By default, ignore attentions to special tokens like '[CLS]' and '[SEP]'. If given, include these attentions")
    parser.add_argument("--force", action="store_true", help="If given, overwrite existing hdf5 files.")
    args = parser.parse_args()
    return args


def main(infile, outdir, force, model_name, mask_attentions):
    outdir = Path(outdir)
    outdir.mkdir(parents=True, exist_ok=True)
    data_outfile = outdir / "data.hdf5"

    # Open in append mode so an existing file can be extended; --force wipes it first.
    f = h5py.File(data_outfile, 'a')
    if force: f.clear()

    extractor = from_pretrained(model_name)
    # if "gpt" in model_name:
    #     mask_attentions = False

    print_every = 50
    long_strings = extract_chars(infile, 10000)  # Stream the corpus in large character chunks
    cutoff_sent = ""
    i = 0
    for strip in long_strings:
        sentences = [sent.text for sent in extractor.aligner.spacy_nlp(strip).sents]
        fixed_sentences = [cutoff_sent + sentences[0]] + sentences[1:-1]
        # This leads to the possibility that there will be an input that is two sentences long. This is ok.
        cutoff_sent = sentences[-1]
        for s in fixed_sentences:
            if len(s) < MIN_SENTENCE_CHARLEN: continue
            if ((i + 1) % print_every) == 0: print(f"Starting sentence {i+1}: \n", s)

            try:
                out = extractor.att_from_sentence(s, mask_attentions=mask_attentions)
            except Exception as e:
                print(f"Error {e} occurred at sentence {i}:\n{s}\n\n Skipping, not creating hdf5 grp")
                continue

            # One HDF5 group per sentence: datasets hold the content, attributes hold the metadata.
            content = out.to_hdf5_content()
            meta = out.to_hdf5_meta()
            grp = f.create_group(to_idx(i))
            for k, v in content.items(): grp.create_dataset(k, data=v)
            for k, v in meta.items(): grp.attrs[k] = v
            i += 1  # Increment to mark the next sentence

    print("FINISHED CORPUS PROCESSING SUCCESSFULLY")


if __name__ == "__main__":
    args = parse_args()
    main(args.file, args.outdir, args.force, args.model, args.nomask)