File size: 1,723 Bytes
b8d16b2
59135d9
b8d16b2
 
 
 
59135d9
 
 
b8d16b2
 
 
 
 
59135d9
 
b8d16b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59135d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b8d16b2
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import csv
import random

import spacy
import srsly
import tqdm
import yaml

# Pipeline parameters (e.g. `max_docs`) are read from params.yaml, resolved
# relative to the current working directory. The context manager makes sure
# the file handle is closed as soon as parsing finishes.
with open("params.yaml") as params_file:
    params = yaml.safe_load(params_file)

# Loaded once at import time so that every document is processed by the same
# pipeline instance.
nlp = spacy.load("en_core_web_trf")

INPUT_FILE = "data/processed/wellcome_grant_descriptions.csv"
OUTPUT_FILE = "data/processed/entities.jsonl"
# A document is kept only if it contains at least one of these entity labels...
INCLUDE_ENTS = {"GPE", "LOC"}
# ...and none of these.
EXCLUDE_ENTS = {"PERSON"}


def process_documents(input_file: str, output_file: str) -> None:
    """Extract named entities from a CSV of texts and write a random sample.

    Reads the first column of every data row in ``input_file`` (the first row
    is treated as a header and skipped), runs the module-level ``nlp``
    pipeline on each text, and keeps the documents whose entity labels include
    at least one of ``INCLUDE_ENTS`` and none of ``EXCLUDE_ENTS``. The kept
    documents are shuffled, truncated to ``params["max_docs"]`` and written to
    ``output_file`` as JSON Lines.

    Args:
        input_file: Path to the CSV file; the text is taken from column 0.
        output_file: Path of the JSONL file to write.
    """
    print(f"Reading data from {input_file}...")

    # newline="" is what the csv module asks for so that line breaks embedded
    # in quoted fields are parsed correctly.
    with open(input_file, "r", newline="") as f:
        reader = csv.reader(f)
        # Skip the header row; the default avoids StopIteration on an empty file.
        next(reader, None)
        # csv.reader yields [] for blank lines, which have no column 0.
        data = [row[0] for row in reader if row]

    print(f"Processing {len(data)} documents...")

    entities = []

    for text in tqdm.tqdm(data):
        doc = nlp(text)

        # Every entity found in the document, with character offsets.
        ents = [
            {
                "text": ent.text,
                "label": ent.label_,
                "start": ent.start_char,
                "end": ent.end_char,
            }
            for ent in doc.ents
        ]

        found_labels = {ent["label"] for ent in ents}

        # A document with no entities has an empty label set, which is
        # disjoint from INCLUDE_ENTS, so it is skipped here as well.
        if found_labels.isdisjoint(INCLUDE_ENTS) or not found_labels.isdisjoint(
            EXCLUDE_ENTS
        ):
            continue

        entities.append(
            {
                "text": doc.text,
                "ents": ents,
            }
        )

    print(f"Randomly selecting {params['max_docs']} documents...")

    # Shuffle in place, then keep the first max_docs items as the sample.
    random.shuffle(entities)
    entities = entities[: params["max_docs"]]

    print(f"Writing {len(entities)} documents to {output_file}...")

    srsly.write_jsonl(output_file, entities)


if __name__ == "__main__":
    # Script entry point: run the pipeline on the default input/output paths.
    process_documents(input_file=INPUT_FILE, output_file=OUTPUT_FILE)