File size: 805 Bytes
3c77d98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
# Use a Hugging Face transformers pipeline to classify the input text.
from functools import lru_cache

import torch
from transformers import pipeline

# Maps each outcome name to [class index, pipeline label string].
# The tuple below is in class-index order, so position i yields
# [i, 'LABEL_i'] for the outcome at that position.
label_mapping = {
    outcome: [index, f'LABEL_{index}']
    for index, outcome in enumerate((
        'delete',
        'keep',
        'merge',
        'no consensus',
        'speedy keep',
        'speedy delete',
        'redirect',
        'withdrawn',
    ))
}

@lru_cache(maxsize=4)
def _load_classifier(model_name):
    """Build the text-classification pipeline for ``model_name`` once and reuse it.

    Constructing a pipeline loads the model, so the result is cached per
    model name instead of being rebuilt on every prediction.
    """
    return pipeline("text-classification", model=model_name, return_all_scores=True)


def predict_text(text, model_name):
    """Score ``text`` against every outcome listed in ``label_mapping``.

    Args:
        text: Input passed straight to the text-classification pipeline.
        model_name: Model identifier handed to ``transformers.pipeline``.
            It must be hashable, because it is used as the cache key.

    Returns:
        A dict keyed by the outcome names of ``label_mapping``. Each value is
        the score the pipeline reported for that outcome's ``LABEL_n`` tag, or
        ``0.0`` when the pipeline did not report that tag.
    """
    classifier = _load_classifier(model_name)
    results = classifier(text)

    # Invert the mapping once so each pipeline label resolves with a single
    # dict lookup instead of a scan over every outcome.
    label_to_outcome = {value[1]: key for key, value in label_mapping.items()}

    final_scores = {key: 0.0 for key in label_mapping}
    # results[0] holds the per-label entries for the input; each entry is a
    # dict with 'label' and 'score' keys.
    for result in results[0]:
        outcome = label_to_outcome.get(result['label'])
        # Labels not present in label_mapping are ignored, as before.
        if outcome is not None:
            final_scores[outcome] = result['score']

    return final_scores