# File size: 747 bytes
# Revision: 1f516b6
import torch
from torch import nn


from transformers import BertForTokenClassification, RobertaForTokenClassification, AutoModelForTokenClassification


# Number of token-classification labels used for each supported corpus.
NUM_LABELS_BY_CORPUS = {
    "chemu": 21,
    "chemdner": 17,
    "chemdner-mol": 3,
}


def build_model(args):
    """Load a token-classification model sized for the selected corpus.

    Args:
        args: Namespace-like object providing ``corpus`` (one of the keys of
            ``NUM_LABELS_BY_CORPUS``), ``roberta_checkpoint`` (the model
            name or path handed to ``from_pretrained``) and ``cache_dir``.

    Returns:
        The model produced by
        ``AutoModelForTokenClassification.from_pretrained`` with the
        corpus-specific ``num_labels`` and ``return_dict=False``.

    Raises:
        ValueError: If ``args.corpus`` is not a supported corpus.
    """
    try:
        num_labels = NUM_LABELS_BY_CORPUS[args.corpus]
    except KeyError:
        # Fail loudly here instead of returning None and crashing later
        # at whatever call site first touches the model.
        raise ValueError(
            f"Unsupported corpus {args.corpus!r}; expected one of "
            f"{sorted(NUM_LABELS_BY_CORPUS)}"
        ) from None
    return AutoModelForTokenClassification.from_pretrained(
        args.roberta_checkpoint,
        num_labels=num_labels,
        cache_dir=args.cache_dir,
        return_dict=False,
    )