socialboost / app /modules /classify.py
ezequiellopez
integration_tests fixes
86e971c
raw
history blame
1.21 kB
import logging
from typing import List, Optional, Tuple

from transformers import pipeline
# Zero-shot text classifier used by classify(); it is built once, when this
# module is imported. The larger facebook/bart-large-mnli checkpoint is kept
# below as a commented-out alternative to the distilled model actually used.
#model = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
# device=0 asks transformers to run the pipeline on the first CUDA device.
# NOTE(review): presumably this fails on CPU-only hosts — confirm deployment
# always has a GPU, or make the device configurable.
model = pipeline("zero-shot-classification", model="valhalla/distilbart-mnli-12-9", device=0)
# Maps a predicted top label (the descriptive phrase the model scores) to the
# short category name returned by get_first_relevant_label.
# NOTE(review): keys are looked up with the labels the model returns, so they
# presumably must match the candidate labels callers pass to classify()
# character-for-character — verify against callers.
label_map = {
    "something else": "non-civic",
    "headlines, news channels, news articles, breaking news": "news",
    "politics, policy and politicians": "politics",
    # NOTE(review): "health are" looks like a typo for "health care"; left as-is
    # because changing the key would break lookups if callers pass this exact
    # text as a candidate label.
    "health are and public health": "health",
    # Religious content is deliberately grouped under the "news" category.
    "religious": "news" # CONSCIOUS DECISION
}
def map_scores(predicted_labels: List[dict], default_label: str):
    """Score each prediction by its top label's confidence.

    Args:
        predicted_labels: Pipeline outputs; each item's ``'labels'`` and
            ``'scores'`` lists are read at index 0 (the top-ranked label).
        default_label: Label treated as "not relevant".

    Returns:
        One value per item: ``0`` when the item's top label equals
        *default_label*, otherwise the item's top score.
    """
    scores = []
    for prediction in predicted_labels:
        top_label = prediction['labels'][0]
        if top_label == default_label:
            # The default (catch-all) label never counts as a relevant hit.
            scores.append(0)
        else:
            scores.append(prediction['scores'][0])
    return scores
def get_first_relevant_label(
    predicted_labels: List[dict],
    mapped_scores: List[float],
    default_label: str,
    label_mapping: Optional[dict] = None,
) -> Tuple[str, int]:
    """Return the mapped category and index of the first relevant prediction.

    A prediction is relevant when its entry in *mapped_scores* is non-zero
    (``map_scores`` yields 0 for items whose top label is the default label).

    Args:
        predicted_labels: Pipeline outputs; each item's ``'labels'`` list is
            read at index 0 (its top label).
        mapped_scores: One score per item in *predicted_labels*.
        default_label: Label whose mapped category is used as the fallback.
        label_mapping: Label -> category dict. Defaults to the module-level
            ``label_map``.

    Returns:
        ``(category, index)`` for the first non-zero score. If every score is
        zero, or *mapped_scores* is empty, returns
        ``(mapping[default_label], len(mapped_scores) - 1)``: the last index,
        or ``-1`` when there are no scores.

    Raises:
        KeyError: if a top label or *default_label* is missing from the
            mapping.
    """
    mapping = label_map if label_mapping is None else label_mapping
    for i, value in enumerate(mapped_scores):
        if value != 0:
            return mapping[predicted_labels[i]['labels'][0]], i
    # Use len(...) - 1 instead of the loop variable: the loop variable is
    # unbound when mapped_scores is empty (UnboundLocalError), while for a
    # non-empty list this is exactly the last index the loop reached.
    return mapping[default_label], len(mapped_scores) - 1
def classify(texts: List[str], labels: List[str], *, multi_label: bool = False):
    """Run zero-shot classification of *texts* against candidate *labels*.

    Args:
        texts: Input strings to classify.
        labels: Candidate labels the module-level pipeline scores each text
            against. Presumably these are the keys of ``label_map``, so that
            results can be mapped back to categories; confirm with callers.
        multi_label: Forwarded to the pipeline. Per the transformers docs,
            ``False`` (the previous fixed behavior) normalizes the label
            scores of each text against each other, while ``True`` scores
            each label independently.

    Returns:
        The raw pipeline output, unchanged. The sibling helpers read
        ``'labels'`` / ``'scores'`` from each item of it.
    """
    predicted_labels = model(texts, labels, multi_label=multi_label)
    # Log instead of print so library callers don't get unconditional stdout
    # output; enable DEBUG on this module's logger to inspect predictions.
    logging.getLogger(__name__).debug("zero-shot predictions: %s", predicted_labels)
    return predicted_labels