# NOTE: page header captured with this file ("Spaces: Sleeping"); not part of the program.
from transformers import pipeline | |
from typing import List | |
# Pick the inference device: CUDA device index 0 when PyTorch is importable and
# reports a usable GPU, otherwise ``None`` (the value handed to ``pipeline``).
device = None
try:
    import torch

    cuda_available = torch.cuda.is_available()
    print(f"Is CUDA available: {cuda_available}")
    if cuda_available:
        print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
        device = 0
except Exception as exc:
    # Best-effort probe: a missing torch install or a broken CUDA runtime must
    # not stop the app, but the reason is reported instead of being swallowed.
    # ``Exception`` (not a bare ``except``) lets Ctrl-C / SystemExit propagate.
    print(f"GPU detection failed ({exc!r})")
    device = None

if device is None:
    print("No GPU available, running on CPU")
# Zero-shot classification pipeline shared by the functions below.
# Alternative checkpoint tried previously: "facebook/bart-large-mnli".
model = pipeline(
    "zero-shot-classification",
    model="valhalla/distilbart-mnli-12-9",
    device=device,
)
# Catch-all label; predictions whose top label is this one are treated as
# "not relevant" by the scoring helpers.
default_label = "something else"

# Maps a classifier label (presumably the candidate labels callers pass to
# ``classify`` -- confirm against the caller) to the coarse category reported
# for it.
label_map = {
    default_label: "non-civic",
    "headlines, news channels, news articles, breaking news": "news",
    "politics, policy and politicians": "politics",
    # NOTE(review): "health are" looks like a typo for "health care". Left
    # untouched because the key is a runtime string that must match the labels
    # used elsewhere -- confirm before changing.
    "health are and public health": "health",
    "religious": "news",  # CONSCIOUS DECISION
}
def map_scores(predicted_labels: List[dict], default_label: str):
    """Return the top score of each prediction, zeroing out default-label hits.

    Args:
        predicted_labels: Predictions, each a dict with parallel ``'labels'``
            and ``'scores'`` sequences; only the first entry of each is read.
        default_label: Label whose top-ranked predictions count as score ``0``.

    Returns:
        A list with one entry per prediction: ``scores[0]`` when the top label
        differs from ``default_label``, otherwise ``0``.
    """
    top_scores = []
    for prediction in predicted_labels:
        if prediction['labels'][0] == default_label:
            top_scores.append(0)
        else:
            top_scores.append(prediction['scores'][0])
    return top_scores
def get_first_relevant_label(predicted_labels, mapped_scores: List[float], default_label: str):
    """Return the mapped category of the first prediction with a non-zero score.

    Args:
        predicted_labels: Predictions, each a dict whose ``'labels'`` sequence
            holds the top label first.
        mapped_scores: One score per prediction; ``0`` marks a prediction to
            skip.
        default_label: Key looked up in ``label_map`` when nothing is relevant.

    Returns:
        A ``(category, index)`` tuple. ``category`` is the ``label_map`` entry
        for the top label of the first prediction whose score is non-zero and
        ``index`` is that prediction's position. When every score is zero the
        result is ``(label_map[default_label], last_index)``; for empty input
        it is ``(label_map[default_label], -1)``.

    Raises:
        KeyError: If the selected label (or ``default_label``) is missing from
            ``label_map``.
    """
    for i, value in enumerate(mapped_scores):
        if value != 0:
            return label_map[predicted_labels[i]['labels'][0]], i
    # ``len(...) - 1`` equals the last loop index for non-empty input and gives
    # -1 for an empty list, where the loop variable would otherwise be unbound.
    return label_map[default_label], len(mapped_scores) - 1
def classify(texts: List[str], labels: List[str]):
    """Classify ``texts`` against ``labels``, bypassing the model for placeholders.

    Texts equal to the sentinel ``"NON-VALID"`` are not sent to the model; they
    get a synthetic result labelled ``default_label`` with score ``1.0``. All
    other texts are classified in one batched call to ``model`` and the
    predictions are placed back at the positions of their source texts.

    Args:
        texts: Input strings, possibly containing ``"NON-VALID"`` placeholders.
        labels: Candidate labels forwarded to the model.

    Returns:
        A list with one result per input text, in input order. Synthetic
        results are dicts with ``'sequence'``, ``'labels'`` and ``'scores'``
        keys; model results are returned as produced by ``model``.

    Raises:
        ValueError: If the model returns a different number of predictions
            than the number of texts it was given.
    """
    results = []
    # Texts that need the model, and the slot each prediction belongs in.
    model_texts = []
    model_indices = []

    for index, text in enumerate(texts):
        if text == "NON-VALID":
            print("NON-VALID TEXT!!", text)
            # Sentinel input: assign the default label directly with full score.
            results.append({
                "sequence": text,
                "labels": [default_label],
                "scores": [1.0],
            })
        else:
            # Reserve the slot now; it is filled once the model has run.
            results.append(None)
            model_texts.append(text)
            model_indices.append(index)

    if model_texts:
        predicted_labels = model(model_texts, labels, multi_label=False, batch_size=32)
        # The pipeline is documented to return "a dict or a list of dict";
        # normalise a lone dict so it is not iterated key by key.
        if isinstance(predicted_labels, dict):
            predicted_labels = [predicted_labels]
        predicted_labels = list(predicted_labels)
        if len(predicted_labels) != len(model_texts):
            raise ValueError(
                f"model returned {len(predicted_labels)} predictions "
                f"for {len(model_texts)} texts"
            )
        # Direct slot assignment keeps input order without positional inserts.
        for idx, pred in zip(model_indices, predicted_labels):
            results[idx] = pred

    print([(r['labels'][0], r['sequence']) for r in results])
    return results