Model Description

The model classifies whether two appraisals are aligned or not, and it is trained on the ALOE dataset.

Input: two appraisals (see the forward function in the SNN class below)

Output: cosine similarity between the two appraisal embeddings, in [-1, 1] (higher means more aligned)

Model architecture: Siamese network with an all-mpnet-base-v2 encoder

Developed by: Jiamin Yang

Model Performance

| F1   | Recall | Precision |
|------|--------|-----------|
| 0.46 | 0.45   | 0.46      |
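
These scores imply a binary aligned/not-aligned decision derived from the cosine similarity, but the threshold used for evaluation is not published. A minimal sketch of the decision step, assuming a hypothetical cut-off of 0.5:

import torch

THRESHOLD = 0.5  # assumed cut-off; the evaluation threshold is not published

def to_alignment_label(cosine_scores: torch.Tensor) -> torch.Tensor:
    # 1 = aligned, 0 = not aligned
    return (cosine_scores > THRESHOLD).long()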

Getting Started

import torch
from torch import nn
from transformers import AutoModel, AutoTokenizer

class SNN(nn.Module):
    # Siamese network: one shared encoder scores two texts with cosine similarity.
    def __init__(self, model_name):
        super().__init__()
        self.model = AutoModel.from_pretrained(model_name).to("cuda").train()
        self.cos = torch.nn.CosineSimilarity(dim=1, eps=1e-4)

    def mean_pooling(self, token_embeddings, attention_mask):
        # Average the token embeddings, masking out padding positions.
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

    def forward(self, input_ids_a, attention_a, input_ids_b, attention_b):
        # Encode each sentence and take its mean-pooled representation.
        encoding1 = self.model(input_ids_a, attention_mask=attention_a)[0]  # all token embeddings
        encoding2 = self.model(input_ids_b, attention_mask=attention_b)[0]

        meanPooled1 = self.mean_pooling(encoding1, attention_a)
        meanPooled2 = self.mean_pooling(encoding2, attention_b)

        # Cosine similarity between the two sentence embeddings.
        pred = self.cos(meanPooled1, meanPooled2)
        return pred

checkpoint_path = 'your_path_to/empathy-appraisal-alignment.pt'

tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-mpnet-base-v2')
model = SNN('sentence-transformers/all-mpnet-base-v2').to('cuda')
checkpoint = torch.load(checkpoint_path, map_location='cuda')
state_dict = checkpoint['model_state_dict']

# Depending on your torch/transformers versions, the checkpoint may contain a
# 'position_ids' buffer that newer versions no longer expect; drop it if present.
state_dict.pop('model.embeddings.position_ids', None)

model.load_state_dict(state_dict)

# Use the model: score the alignment between a target's and an observer's appraisal.
target = ["I'm so sad that my cat died yesterday."]
observer = ["It's ok to feel sad."]

target_encodings = tokenizer(target, padding=True, truncation=True)
target_input_ids = torch.LongTensor(target_encodings['input_ids']).to('cuda')
target_attention_mask = torch.LongTensor(target_encodings['attention_mask']).to('cuda')
observer_encodings = tokenizer(observer, padding=True, truncation=True)
observer_input_ids = torch.LongTensor(observer_encodings['input_ids']).to('cuda')
observer_attention_mask = torch.LongTensor(observer_encodings['attention_mask']).to('cuda')

model.eval()
with torch.no_grad():  # inference only; no gradient tracking needed
    output = model(target_input_ids, target_attention_mask, observer_input_ids, observer_attention_mask)
print(output)  # [0.5755]
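
If you score pairs repeatedly, the steps above can be wrapped in a small helper. This is a hypothetical convenience function, not part of the released code; the name score_alignment and its device argument are introduced here for illustration:

def score_alignment(model, tokenizer, target_texts, observer_texts, device='cuda'):
    # Tokenize both sides with padding so batch shapes match, then score the
    # pairs with the Siamese model under no_grad.
    a = tokenizer(target_texts, padding=True, truncation=True, return_tensors='pt').to(device)
    b = tokenizer(observer_texts, padding=True, truncation=True, return_tensors='pt').to(device)
    model.eval()
    with torch.no_grad():
        return model(a['input_ids'], a['attention_mask'], b['input_ids'], b['attention_mask'])

scores = score_alignment(model, tokenizer, target, observer)
print(scores)  # one cosine score per pair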
