File size: 3,114 Bytes
9df4cc0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
from seqeval.metrics import classification_report
from datasets import load_dataset, load_from_disk
from tqdm import tqdm
import datasets
import torch
from torch.utils.data import DataLoader
from functools import partial
import re
import sys
import numpy as np
from fingpt.FinGPT_Benchmark.utils import *
from pathlib import Path
sys.path.append('../')

# Maps the short entity labels used in the BIO tags to the entity type words
# that appear in the textual answers ("<entity> is a <type>").
ent_dict = {
    'PER': 'person',
    'ORG': 'organization',
    'LOC': 'location',
}
# Reverse lookup: entity type word -> short label.
ent_dict_rev = {v: k for k, v in ent_dict.items()}


def cvt_text_to_pred(tokens, text):
    """Convert a textual entity listing into per-token BIO tags.

    ``text`` is lower-cased, stripped of surrounding whitespace and periods,
    and split on commas. Each fragment is expected to look like
    ``"<entity> is a <type>"`` or ``"<entity> is an <type>"``. The entity is
    split on whitespace and searched for in ``tokens``; the first occurrence
    that is still entirely tagged ``'O'`` receives ``B-<label>`` followed by
    ``I-<label>`` tags.

    Args:
        tokens: List of (already lower-cased) tokens of the sentence.
        text: Comma-separated entity statements.

    Returns:
        A list of tags with exactly ``len(tokens)`` elements.

    Notes:
        * Fragments that do not match the expected pattern, or whose entity
          part is empty, are printed and otherwise ignored.
        * An entity type that is not a value of ``ent_dict`` is tagged with
          the label ``'O'`` (i.e. ``'B-O'`` / ``'I-O'``), as before.
    """
    preds = ['O'] * len(tokens)

    # Strip whitespace first so that a trailing newline/space does not shield
    # the final period from being removed.
    for pred_txt in text.lower().strip().strip('.').split(','):
        pred_match = re.match(r'^(.*) is an? (.*)$', pred_txt)
        if pred_match is None:
            print(pred_txt)
            continue

        entity = pred_match.group(1).strip()
        entity_type = pred_match.group(2).strip()
        entity_pred = ent_dict_rev.get(entity_type, 'O')
        entity_tokens = entity.split()

        n = len(entity_tokens)
        if n == 0:
            # An empty entity would make the slice assignment below *insert*
            # a tag, leaving preds longer than tokens.
            print(pred_txt)
            continue

        for i in range(len(tokens) - n + 1):
            # Only tag a span that matches and has not been claimed yet, so a
            # repeated entity is assigned to its next free occurrence.
            if tokens[i:i+n] == entity_tokens and preds[i:i+n] == ['O'] * n:
                preds[i:i+n] = ['B-' + entity_pred] + ['I-' + entity_pred] * (n-1)
                break

    return preds


def map_output(feature):
    """Build gold and predicted tag sequences for one dataset row.

    The row's ``'input'`` text is lower-cased and whitespace-tokenised; the
    ``'output'`` text is converted into the ``'label'`` tags and the
    ``'out_text'`` text into the ``'pred'`` tags, both via
    ``cvt_text_to_pred`` over the same tokens.
    """
    words = feature['input'].lower().split()
    # (result column, source column) pairs, processed in this order.
    columns = (('label', 'output'), ('pred', 'out_text'))
    return {
        target: cvt_text_to_pred(words, feature[source])
        for target, source in columns
    }


def _extract_answer(generated):
    """Return the answer segment of a decoded generation.

    The segment is the text following the first ``"Answer: "`` marker (up to
    the next marker, if there is one). When the marker is missing an empty
    string is returned instead of raising ``IndexError``.
    """
    pieces = generated.split("Answer: ")
    return pieces[1] if len(pieces) > 1 else ""


def test_ner(args, model, tokenizer):
    """Generate answers for the FinGPT NER test split and print a seqeval report.

    Args:
        args: Object providing ``max_length`` and ``batch_size``; it is also
            forwarded to ``test_mapping``.
        model: Model exposing ``generate`` and ``device``.
        tokenizer: Tokenizer used to encode the prompts and decode the outputs.

    Returns:
        The evaluated split as a pandas DataFrame, including the ``out_text``,
        ``label`` and ``pred`` columns added here.

    Side effects:
        Prints the DataFrame and the classification report, and writes the
        DataFrame to ``tmp.csv`` in the current working directory.
    """
    dataset = load_from_disk(Path(__file__).parent.parent / 'data/fingpt-ner')['test']
    # test_mapping is not defined in this file (presumably provided by the
    # star import from the benchmark utils); the collate function below
    # relies on every row having a "prompt" field afterwards.
    dataset = dataset.map(partial(test_mapping, args), load_from_cache_file=False)

    def collate_fn(batch):
        # NOTE(review): max_length is passed without truncation=True, so it
        # presumably does not truncate long prompts -- confirm the intent.
        inputs = tokenizer(
            [f["prompt"] for f in batch], return_tensors='pt',
            padding=True, max_length=args.max_length,
            return_token_type_ids=False
        )
        return inputs

    dataloader = DataLoader(dataset, batch_size=args.batch_size, collate_fn=collate_fn, shuffle=False)

    out_text_list = []
    # Log roughly five samples per run; clamp to 1 so that fewer than five
    # batches do not lead to a modulo-by-zero below.
    log_interval = max(1, len(dataloader) // 5)

    for idx, inputs in enumerate(tqdm(dataloader)):
        inputs = {key: value.to(model.device) for key, value in inputs.items()}
        res = model.generate(**inputs, max_length=args.max_length, eos_token_id=tokenizer.eos_token_id)
        res_sentences = [tokenizer.decode(i, skip_special_tokens=True) for i in res]
        if (idx + 1) % log_interval == 0:
            tqdm.write(f'{idx}: {res_sentences[0]}')

        # Keep one entry per row even when the marker is missing, so the new
        # column stays aligned with the dataset.
        missing = sum("Answer: " not in o for o in res_sentences)
        if missing:
            tqdm.write(f'{idx}: {missing} output(s) without an "Answer: " marker; treated as empty')
        out_text_list += [_extract_answer(o) for o in res_sentences]
        torch.cuda.empty_cache()

    dataset = dataset.add_column("out_text", out_text_list)
    dataset = dataset.map(map_output, load_from_cache_file=False)
    dataset = dataset.to_pandas()

    print(dataset)
    dataset.to_csv('tmp.csv')

    # Convert each cell's tag sequence to a plain list for seqeval.
    label = [d.tolist() for d in dataset['label']]
    pred = [d.tolist() for d in dataset['pred']]

    print(classification_report(label, pred, digits=4))

    return dataset