File size: 2,787 Bytes
86e971c
3b86501
 
c8df78e
 
 
 
 
 
 
 
3b86501
8ad2ef4
c8df78e
3b86501
 
 
 
 
 
 
 
c8df78e
3b86501
 
 
 
 
 
 
 
 
 
 
 
 
c8df78e
3b86501
 
c8df78e
 
 
 
 
 
 
 
 
 
 
 
8ad2ef4
c8df78e
 
 
 
 
 
 
 
8ad2ef4
c8df78e
 
 
 
 
8ad2ef4
c8df78e
 
 
 
8ad2ef4
c8df78e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
from transformers import pipeline
from typing import List

# Probe for a usable GPU once at import time. `device` ends up as 0 (first
# CUDA device) when torch reports CUDA as available, otherwise None; the value
# is handed to pipeline() below unchanged.
device = None
try:
    import torch

    cuda_available = torch.cuda.is_available()
    print(f"Is CUDA available: {cuda_available}")
    if cuda_available:
        print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
        device = 0
except Exception:
    # Best-effort probe: a missing torch install or a broken CUDA runtime must
    # not prevent the module from loading. Catching Exception (not a bare
    # except) lets KeyboardInterrupt / SystemExit propagate.
    device = None

if device is None:
    print("No GPU available, running on CPU")

# Shared zero-shot classification pipeline, built once when the module is
# imported. The previously used checkpoint, kept for reference, was
# "facebook/bart-large-mnli".
model = pipeline(
    "zero-shot-classification",
    model="valhalla/distilbart-mnli-12-9",
    device=device,
)

# Maps a top label found in a prediction (the key) to the short category name
# returned by get_first_relevant_label (the value).
# presumably the keys are also the candidate labels passed to classify() —
# verify against callers.
label_map = {
    "something else": "non-civic",
    "headlines, news channels, news articles, breaking news": "news",
    "politics, policy and politicians": "politics",
    # NOTE(review): "health are" looks like a typo for "health care". The key is
    # a runtime string that must match the labels supplied by callers, so it is
    # left unchanged here — confirm before fixing it everywhere at once.
    "health are and public health": "health",
    "religious": "news" # CONSCIOUS DECISION: religious content is reported as "news"
}
# Fallback label: classify() assigns it to "NON-VALID" texts, and it is the key
# whose mapped category is "non-civic".
default_label = "something else"

def map_scores(predicted_labels: List[dict], default_label: str):
    """Collect the top score of every prediction, zeroing default-label hits.

    Args:
        predicted_labels: Prediction dicts; each has a 'labels' list and a
            'scores' list, and only the first entry of each is read.
        default_label: Label whose predictions count as "not relevant".

    Returns:
        A list with one entry per prediction: the first score when the first
        label differs from ``default_label``, otherwise ``0``.
    """
    top_scores = []
    for prediction in predicted_labels:
        if prediction['labels'][0] == default_label:
            # The default label won: treat the prediction as irrelevant.
            top_scores.append(0)
        else:
            top_scores.append(prediction['scores'][0])
    return top_scores

def get_first_relevant_label(predicted_labels, mapped_scores: List[float], default_label: str):
    """Return the category of the first relevant prediction and its index.

    Args:
        predicted_labels: Prediction dicts; the first entry of each 'labels'
            list is looked up in the module-level ``label_map``.
        mapped_scores: One score per prediction, where ``0`` marks a
            prediction that is not relevant.
        default_label: ``label_map`` key used when no prediction is relevant.

    Returns:
        A ``(category, index)`` tuple. When some score is non-zero, this is
        the mapped category of the first such prediction and its position.
        Otherwise it is ``label_map[default_label]`` paired with the index of
        the last score, or ``-1`` when ``mapped_scores`` is empty.

    Raises:
        KeyError: If the selected label is not a key of ``label_map``.
    """
    for index, score in enumerate(mapped_scores):
        if score != 0:
            return label_map[predicted_labels[index]['labels'][0]], index
    # No relevant prediction. The index is computed from the length instead of
    # reusing the loop variable, which is unbound when the list is empty.
    return label_map[default_label], len(mapped_scores) - 1


def classify(texts: List[str], labels: List[str]):
    """Run zero-shot classification on ``texts``, skipping placeholder entries.

    Texts equal to the sentinel string "NON-VALID" are not sent to the model;
    they receive a fixed result carrying ``default_label`` with score 1.0.
    All other texts are classified in one batched call to the module-level
    ``model`` pipeline.

    Args:
        texts: Input strings, in the order the results should come back.
        labels: Candidate labels handed to the pipeline.

    Returns:
        A list with one result dict per input text, in input order. Sentinel
        entries have the keys 'sequence', 'labels' and 'scores'; the other
        entries are whatever the pipeline returned for that text.
    """
    # One slot per input so every result lands at its original position.
    results = [None] * len(texts)

    # Texts that need the model, and the slot each of them belongs to.
    model_texts = []
    model_indices = []

    for index, text in enumerate(texts):
        if text == "NON-VALID":
            print("NON-VALID TEXT!!", text)
            # Sentinel input: assign the default label directly, with full
            # confidence, without invoking the model.
            results[index] = {
                "sequence": text,
                "labels": [default_label],
                "scores": [1.0],
            }
        else:
            model_texts.append(text)
            model_indices.append(index)

    if model_texts:
        predicted_labels = model(model_texts, labels, multi_label=False, batch_size=32)
        # The pipeline is documented to return "a dict or a list of dict";
        # normalise a bare dict so the zip below never iterates over its keys.
        if isinstance(predicted_labels, dict):
            predicted_labels = [predicted_labels]

        # Write each prediction into the slot of the text it came from.
        for idx, pred in zip(model_indices, predicted_labels):
            results[idx] = pred

    print([(r['labels'][0], r['sequence']) for r in results])
    return results