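"""Precompute transformer attentions for a corpus and store them in HDF5.

Reads a corpus of extracted sentences, runs each sentence through a pretrained
transformer (see `transformer_details.py` for supported models), and writes the
resulting attentions and metadata to `<outdir>/data.hdf5`, one group per sentence.

Example invocation (the script name here is illustrative):

    python analyze_corpus.py -f sentences.pckl -o out/ -m bert-base-cased
"""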
import argparse
from pathlib import Path

import h5py

from data_processing.sentence_extracting import extract_chars
from data_processing.corpus_data_wrapper import to_idx
from transformer_details import from_pretrained

MIN_SENTENCE_CHARLEN = 24  # Sentences shorter than this (in characters) are skipped

def parse_args():
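    """Parse the command-line arguments for corpus analysis."""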
    parser = argparse.ArgumentParser()
    parser.add_argument("-f", "--file", help="Path to .pckl file of unique sentences from a corpus.")
    parser.add_argument("-o", "--outdir", help="Path of directory in which to store the analyzed sentences as a .hdf5")
    parser.add_argument("-m", "--model", default="bert-base-cased", help="Which pretrained transformer model to use. See 'transformer_details.py' for supported models")
    parser.add_argument("--nomask", action='store_false', help="By default, attentions to special tokens like '[CLS]' and '[SEP]' are ignored. If this flag is given, include those attentions.")
    parser.add_argument("--force", action="store_true", help="If given, overwrite existing hdf5 files.")

    args = parser.parse_args()
    return args

def main(infile, outdir, force, model_name, mask_attentions):
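    """Extract attentions for every sentence in `infile` and store them in `outdir`/data.hdf5.

    Each processed sentence becomes one HDF5 group keyed by its index. `force`
    clears an existing output file, and `mask_attentions` controls whether
    attentions to special tokens are masked out.
    """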
    outdir = Path(outdir)
    outdir.mkdir(parents=True, exist_ok=True)
    data_outfile = outdir / "data.hdf5"
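    # Open the output HDF5 in append mode; with --force, clear any existing groups first.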
    f = h5py.File(data_outfile, 'a')
    if force: f.clear()

    extractor = from_pretrained(model_name)

    # if "gpt" in model_name:
    #     mask_attentions = False

    print_every = 50
    long_strings = extract_chars(infile, 10000)
    cutoff_sent = ""
    i = 0
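    # Process the corpus in ~10,000-character chunks. The last sentence of each
    # chunk may be cut off mid-sentence, so it is carried over and prepended to
    # the first sentence of the next chunk instead of being processed directly.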
    for strip in long_strings:
        sentences = [sent.text for sent in extractor.aligner.spacy_nlp(strip).sents]
        fixed_sentences = [cutoff_sent + sentences[0]] + sentences[1:-1]

        # Prepending the carried-over fragment can make an input two sentences long. This is ok.
        cutoff_sent = sentences[-1]
        for s in fixed_sentences:
            if len(s) < MIN_SENTENCE_CHARLEN: continue
            if ((i + 1) % print_every) == 0: print(f"Starting sentence {i+1}: \n", s)

            try:
                out = extractor.att_from_sentence(s, mask_attentions=mask_attentions)

            except Exception as e:
                print(f"Error {e} occurred at sentence {i}:\n{s}\n\nSkipping; not creating an hdf5 group")
                continue

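            # Store the sentence as its own HDF5 group: datasets hold the
            # attention content, attributes hold the metadata.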
            content = out.to_hdf5_content()
            meta = out.to_hdf5_meta()
            grp = f.create_group(to_idx(i))
            for k,v in content.items(): grp.create_dataset(k, data=v)
            for k, v in meta.items(): grp.attrs[k] = v

            i += 1 # Increment to mark the next sentence

    f.close()
    print("FINISHED CORPUS PROCESSING SUCCESSFULLY")

if __name__ == "__main__":
    args = parse_args()

    main(args.file, args.outdir, args.force, args.model, args.nomask)