Model Description

Given a sentence, the model classifies the appraisal it expresses into one of nine labels. It was trained on the ALOE dataset.

Input: a sentence

Labels: No Label, Pleasantness, Anticipated Effort, Certainty, Objective Experience, Self-Other Agency, Situational Control, Advice, Trope

Output: logits, one per label, in the order listed above

Model architecture: OpenPrompt + RoBERTa (roberta-large)

Developed by: Jiamin Yang
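
The model returns one logit per label, so reading a prediction only requires the label order above. The sketch below is framework-free and assumes nothing beyond that order. It maps a list of nine logits to the top label name and uses a softmax to turn the logits into a probability distribution. The helper name `decode_logits` is hypothetical and is not part of OpenPrompt or this model's release.

```python
import math
from typing import List, Tuple

# Label order as listed in this card; index i of the logits corresponds to LABELS[i].
LABELS = [
    "No Label", "Pleasantness", "Anticipated Effort", "Certainty",
    "Objective Experience", "Self-Other Agency", "Situational Control",
    "Advice", "Trope",
]


def decode_logits(logits: List[float]) -> Tuple[str, List[float]]:
    """Hypothetical helper: return the top label and softmax probabilities."""
    if len(logits) != len(LABELS):
        raise ValueError(f"expected {len(LABELS)} logits, got {len(logits)}")
    m = max(logits)  # subtract the max before exponentiating for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    best = max(range(len(logits)), key=lambda i: logits[i])  # argmax index
    return LABELS[best], probs


if __name__ == "__main__":
    label, probs = decode_logits([0.0] * 8 + [2.0])
    print(label, [round(p, 3) for p in probs])
```

With the Getting Started code, the helper could be applied to one row of the logits tensor, e.g. `decode_logits(logits[0].tolist())`. The softmax does not change which label is ranked first; it only rescales the scores so that they sum to 1.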

Model Performance

Overall performance

Macro-F1   Recall   Precision
0.56       0.57     0.58

Per-label performance

Label                   Recall   Precision
No Label                0.34     0.64
Pleasantness            0.69     0.54
Anticipated Effort      0.46     0.46
Certainty               0.58     0.47
Objective Experience    0.58     0.69
Self-Other Agency       0.62     0.55
Situational Control     0.31     0.55
Advice                  0.72     0.66
Trope                   0.80     0.67
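
The overall row can be recomputed from the rounded per-label figures. Take the unweighted (macro) mean of the per-label recalls and precisions, and the mean of the per-label F1 scores. This reproduces 0.56 / 0.57 / 0.58, which is consistent with the overall numbers being macro averages over the nine labels. The sketch below assumes that definition. It is not the author's evaluation script, and the names `PER_LABEL`, `f1` and `macro_scores` are illustrative.

```python
from typing import Dict, Tuple

# (recall, precision) per label, copied from the per-label table (rounded to 2 d.p.).
PER_LABEL: Dict[str, Tuple[float, float]] = {
    "No Label": (0.34, 0.64),
    "Pleasantness": (0.69, 0.54),
    "Anticipated Effort": (0.46, 0.46),
    "Certainty": (0.58, 0.47),
    "Objective Experience": (0.58, 0.69),
    "Self-Other Agency": (0.62, 0.55),
    "Situational Control": (0.31, 0.55),
    "Advice": (0.72, 0.66),
    "Trope": (0.80, 0.67),
}


def f1(recall: float, precision: float) -> float:
    """Harmonic mean of precision and recall (0 when both are 0)."""
    if recall + precision == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


def macro_scores(per_label: Dict[str, Tuple[float, float]]) -> Tuple[float, float, float]:
    """Unweighted means over labels: (macro-F1, macro-recall, macro-precision)."""
    n = len(per_label)
    macro_recall = sum(r for r, _ in per_label.values()) / n
    macro_precision = sum(p for _, p in per_label.values()) / n
    macro_f1 = sum(f1(r, p) for r, p in per_label.values()) / n
    return macro_f1, macro_recall, macro_precision


if __name__ == "__main__":
    f, r, p = macro_scores(PER_LABEL)
    print(f"Macro-F1 {f:.2f}  Recall {r:.2f}  Precision {p:.2f}")
```

Because each label counts equally, weak classes weigh on the overall scores as much as strong ones. Situational Control (recall 0.31) and No Label (recall 0.34) pull macro recall down to 0.57 even though Trope and Advice both exceed 0.7.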

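Getting Started

The example below loads the checkpoint into an OpenPrompt PromptForClassification model built on roberta-large and classifies two sentences. It requires the openprompt package and a CUDA-capable GPU.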

import torch
from openprompt.plms import load_plm
from openprompt.prompts import ManualTemplate
from openprompt.prompts import ManualVerbalizer
from openprompt import PromptForClassification
from openprompt.data_utils import InputExample
from openprompt import PromptDataLoader

checkpoint_file = 'your_path_to/empathy-appraisal-span.pt'

plm, tokenizer, model_config, WrapperClass = load_plm('roberta', 'roberta-large')
template_text = 'The sentence {"placeholder":"text_a"} has the label {"mask"}.'
template = ManualTemplate(tokenizer=tokenizer, text=template_text)

# One label word per class; this order fixes the order of the output logits
num_classes = 9
label_words = [['No Label'], ['Pleasantness'], ['Anticipated Effort'], ['Certainty'], ['Objective Experience'], ['Self-Other Agency'], ['Situational Control'], ['Advice'], ['Trope']]
verbalizer = ManualVerbalizer(tokenizer, num_classes=num_classes, label_words=label_words)
prompt_model = PromptForClassification(plm=plm, template=template, verbalizer=verbalizer, freeze_plm=False).to('cuda')

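# Load onto the CPU first; load_state_dict then copies the weights into the GPU model
checkpoint = torch.load(checkpoint_file, map_location='cpu')
state_dict = checkpoint['model_state_dict']

# The checkpoint stores RoBERTa's position_ids buffer. Depending on the installed
# transformers version, the model may not expect it, so drop it only in that case.
position_ids_key = 'prompt_model.plm.roberta.embeddings.position_ids'
if position_ids_key not in prompt_model.state_dict():
    state_dict.pop(position_ids_key, None)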

prompt_model.load_state_dict(state_dict)

# use the model
dataset = [
    InputExample(
        guid = 0,
        text_a = "I am sorry for your loss",
    ),
    InputExample(
        guid = 1,
        text_a = "It's not your fault",
    ),
]

data_loader = PromptDataLoader(dataset=dataset, 
                template=template, 
                tokenizer=tokenizer,
                tokenizer_wrapper_class=WrapperClass,
                max_seq_length=512,
                batch_size=2,
                shuffle=False,
                teacher_forcing=False,
                predict_eos_token=False,
                truncate_method='head')
prompt_model.eval()
with torch.no_grad():
    for batch in data_loader:
        logits = prompt_model(batch.to('cuda'))
        preds = torch.argmax(logits, dim = -1)
        print(preds.tolist())  # [8, 5], i.e. Trope and Self-Other Agency
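
To see roughly what RoBERTa reads for each input, the framework-free sketch below approximates how the ManualTemplate in the example fills `template_text`. The sentence replaces the `text_a` placeholder, and the `mask` slot becomes RoBERTa's `<mask>` token. The verbalizer then turns the model's scores at that position into one logit per label word. This is a simplified illustration under those assumptions: OpenPrompt's real wrapping also tokenizes, truncates and adds special tokens. `fill_template` is a hypothetical helper, not an OpenPrompt function.

```python
# Template string copied from the example above; MASK_TOKEN assumes RoBERTa's tokenizer.
TEMPLATE_TEXT = 'The sentence {"placeholder":"text_a"} has the label {"mask"}.'
MASK_TOKEN = "<mask>"


def fill_template(text_a: str, template: str = TEMPLATE_TEXT, mask_token: str = MASK_TOKEN) -> str:
    """Hypothetical helper: substitute the sentence and mask token into the template text."""
    return (
        template.replace('{"placeholder":"text_a"}', text_a)
        .replace('{"mask"}', mask_token)
    )


if __name__ == "__main__":
    for sentence in ["I am sorry for your loss", "It's not your fault"]:
        print(fill_template(sentence))
```

Because the label is predicted at the mask position, the wording of the template is part of the model's input. A different template_text would likely not match what the checkpoint was trained with.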

Dataset used to train Blablablab/empathy-appraisal-span