Spaces:
Sleeping
Sleeping
import torch | |
from torch import nn | |
from transformers import BertForTokenClassification, RobertaForTokenClassification, AutoModelForTokenClassification | |
def build_model(args): | |
if args.corpus == "chemu": | |
return AutoModelForTokenClassification.from_pretrained(args.roberta_checkpoint, num_labels = 21, cache_dir = args.cache_dir, return_dict = False) | |
elif args.corpus == "chemdner": | |
return AutoModelForTokenClassification.from_pretrained(args.roberta_checkpoint, num_labels = 17, cache_dir = args.cache_dir, return_dict = False) | |
elif args.corpus == "chemdner-mol": | |
return AutoModelForTokenClassification.from_pretrained(args.roberta_checkpoint, num_labels = 3, cache_dir = args.cache_dir, return_dict = False) | |