File size: 1,839 Bytes
f584171
 
87e622a
f584171
d98a217
 
 
 
f584171
d98a217
 
f584171
 
 
 
d98a217
 
 
 
f584171
d98a217
f584171
 
 
 
 
 
17962cc
 
f584171
 
87e622a
 
 
d98a217
 
87e622a
 
 
 
 
d98a217
 
f584171
 
 
 
d98a217
 
 
f584171
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import torch
from transformers import BertForSequenceClassification, AutoTokenizer
from transformers import PegasusForConditionalGeneration, PegasusTokenizer

# Local copy of the emotion classifier; the commented hub ID
# 'Djacon/rubert-tiny2-russian-emotion-detection' was the original source
# of this checkpoint (presumably — confirm before swapping paths).
path_emo = './models/emotion_detection/'
tokenizer_emo = AutoTokenizer.from_pretrained(path_emo)
model_emo = BertForSequenceClassification.from_pretrained(path_emo)

# Human-readable names for the classifier outputs, in class-index order.
LABELS = [
    'Joy',
    'Interest',
    'Surprise',
    'Sadness',
    'Anger',
    'Disgust',
    'Fear',
    'Guilt',
    'Neutral',
]


# Probabilistic prediction of emotion in a text
@torch.no_grad()
def predict_emotions(text: str) -> str:
    """Classify the emotions expressed in ``text``.

    The text is tokenized (truncated to 512 tokens), run through the
    emotion classifier, and the class logits are turned into probabilities
    with a softmax.

    Args:
        text: Input text to analyse.

    Returns:
        One ``"<Label>: <percent>"`` line per emotion, joined by newlines and
        ordered from most to least likely. Percentages are rounded to three
        decimal places.
    """
    inputs = tokenizer_emo(text, max_length=512, truncation=True,
                           return_tensors='pt')
    # Inputs must live on the same device as the model.
    inputs = inputs.to(model_emo.device)

    logits = model_emo(**inputs).logits
    # Only one text is encoded, so take the first row of the batch and
    # convert it to a Python list once (not once per label).
    probs = torch.nn.functional.softmax(logits, dim=1)[0].tolist()

    # Class i of the model is reported as LABELS[i]; assumes LABELS follows
    # the model's class order — confirm against the model config.
    emotions = {LABELS[i]: round(100 * p, 3) for i, p in enumerate(probs)}
    ranked = sorted(emotions.items(), key=lambda item: -item[1])
    return '\n'.join(f'{label}: {score}' for label, score in ranked)


# Local Pegasus checkpoint used for abstractive summarization.
path_sum = './models/summarizer/'
tokenizer_sum = PegasusTokenizer.from_pretrained(path_sum)
model_sum = PegasusForConditionalGeneration.from_pretrained(path_sum)


def predict_summarization(text: str) -> str:
    """Summarize ``text`` with the Pegasus model.

    Args:
        text: Input document; the tokenizer is called with
            ``truncation=True`` so overly long inputs are cut down.

    Returns:
        The decoded summary, with special tokens removed.
    """
    batch = tokenizer_sum([text], truncation=True, padding="longest",
                          return_tensors="pt")
    # Keep the inputs on the model's device (as predict_emotions does);
    # otherwise generation fails once the model is moved off the CPU.
    batch = batch.to(model_sum.device)
    summary_ids = model_sum.generate(**batch)
    summaries = tokenizer_sum.batch_decode(summary_ids,
                                           skip_special_tokens=True)
    # A single text was encoded, so there is exactly one summary.
    return summaries[0]


def test():
    """Run each model once on a tiny sample and announce success.

    Called at import time below, so importing this module also exercises
    both pipelines.
    """
    sample = 'I am so happy now!'
    checks = (
        (predict_emotions,
         '\n>>> Emotion Detection successfully initialized! <<<\n'),
        (predict_summarization,
         '\n>>> Pegasus successfully initialized! <<<\n'),
    )
    for run_model, message in checks:
        run_model(sample)
        print(message)


test()