File size: 1,206 Bytes
86e971c
3b86501
 
86e971c
 
 
3b86501
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import logging
from typing import List

from transformers import pipeline

# Alternative (larger) zero-shot model, kept for reference:
#model = pipeline("zero-shot-classification", model="valhalla/distilbart-mnli-12-9", device=0) is used instead.
#model = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
# Shared zero-shot classifier used by classify(); loaded once, at import time.
# NOTE(review): device=0 selects the first CUDA device, so importing this module
# requires a GPU; device=-1 would run on CPU.
model = pipeline("zero-shot-classification", model="valhalla/distilbart-mnli-12-9", device=0)



# Maps each zero-shot candidate label (the exact text offered to the model) to
# the coarse category returned by get_first_relevant_label().
# The keys are looked up with the labels the model returns, so they must match
# the candidate labels given to classify() character-for-character.
# Presumably callers pass list(label_map) as the candidates and "something else"
# as the default label -- TODO confirm against callers.
# NOTE(review): "health are and public health" looks like a typo for
# "health care and public health"; fixing it changes both the model prompt and
# the lookup key, so coordinate with any caller that spells it out.
label_map = {
    "something else": "non-civic",
    "headlines, news channels, news articles, breaking news": "news",
    "politics, policy and politicians": "politics",
    "health are and public health": "health",
    "religious": "news" # CONSCIOUS DECISION
}

def map_scores(predicted_labels: List[dict], default_label: str):
    """Score each prediction by its top label, zeroing out the default label.

    Args:
        predicted_labels: zero-shot pipeline outputs, one dict per text; each
            dict's ``'labels'[0]`` / ``'scores'[0]`` are treated as its top
            prediction.
        default_label: the label that means "nothing relevant".

    Returns:
        One entry per prediction: its top score, or ``0`` when its top label
        is ``default_label``.
    """
    scores = []
    for prediction in predicted_labels:
        top_label = prediction['labels'][0]
        if top_label == default_label:
            # The default label won, so this text is not relevant; its
            # 'scores' entry is deliberately not read.
            scores.append(0)
        else:
            scores.append(prediction['scores'][0])
    return scores

def get_first_relevant_label(predicted_labels, mapped_scores: List[float], default_label: str, *, mapping=None):
    """Return the category of the first relevant prediction and its index.

    A prediction counts as relevant when its entry in ``mapped_scores`` is
    non-zero (``map_scores`` writes 0 for predictions whose top label is the
    default label).

    Args:
        predicted_labels: zero-shot pipeline outputs, indexed in lockstep with
            ``mapped_scores``; each dict's ``'labels'[0]`` is taken as its top
            label.
        mapped_scores: one score per prediction; 0 means "not relevant".
        default_label: label used for the fallback category.
        mapping: label -> category lookup. Defaults to the module-level
            ``label_map``.

    Returns:
        ``(category, index)``. If a relevant prediction exists, ``index`` is its
        position. Otherwise ``category`` is ``mapping[default_label]`` and
        ``index`` is the last position scanned, ``len(mapped_scores) - 1``
        (``-1`` for empty input).

    Raises:
        KeyError: if a label is missing from ``mapping``.
    """
    if mapping is None:
        mapping = label_map
    for i, value in enumerate(mapped_scores):
        if value != 0:
            return mapping[predicted_labels[i]['labels'][0]], i
    # Nothing relevant, or empty input. Derive the index from the length rather
    # than the loop variable, which is unbound when mapped_scores is empty.
    return mapping[default_label], len(mapped_scores) - 1


def classify(texts: List[str], labels: List[str]):
    """Run zero-shot classification of ``texts`` against candidate ``labels``.

    Uses the module-level ``model`` pipeline with ``multi_label=False``, so the
    candidate labels compete with each other for each text.

    Args:
        texts: texts to classify.
        labels: candidate labels offered to the model.

    Returns:
        The raw pipeline output, returned unchanged. Presumably (per the
        transformers docs) this is one dict per text with ``'sequence'``,
        ``'labels'`` and ``'scores'`` keys for a list input, or a single dict
        if a bare string is passed -- confirm for the installed version.
    """
    predicted_labels = model(texts, labels, multi_label=False)
    # Emit at DEBUG level instead of printing the full output on every call.
    logging.getLogger(__name__).debug("Zero-shot predictions: %s", predicted_labels)
    return predicted_labels