# --- Web page residue captured together with the file; not part of the program ---
# Spaces:
# Sleeping
# Sleeping
# File size: 2,787 Bytes
# 86e971c 3b86501 c8df78e 3b86501 8ad2ef4 c8df78e 3b86501 c8df78e 3b86501 c8df78e 3b86501 c8df78e 8ad2ef4 c8df78e 8ad2ef4 c8df78e 8ad2ef4 c8df78e 8ad2ef4 c8df78e |
# 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 |
from transformers import pipeline
from typing import List
# Select the device passed to the pipeline: GPU index 0 when CUDA is usable,
# otherwise None (reported below as running on CPU).
device = None
try:
    import torch

    cuda_available = torch.cuda.is_available()
    print(f"Is CUDA available: {cuda_available}")
    if cuda_available:
        print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
        device = 0
except Exception:
    # Best-effort probe: a missing or broken torch/CUDA install must not stop
    # the module from loading. Unlike a bare ``except:``, this still lets
    # KeyboardInterrupt and SystemExit propagate.
    device = None

if device is None:
    print("No GPU available, running on CPU")
# Zero-shot classification pipeline used by classify() below; it is built once
# at import time on the device selected above.
# Alternative checkpoint, kept for reference: facebook/bart-large-mnli
model = pipeline(
    "zero-shot-classification",
    model="valhalla/distilbart-mnli-12-9",
    device=device,
)
# Maps each label a prediction can carry onto the coarse category name that
# get_first_relevant_label() returns for it.
label_map = {
    "something else": "non-civic",
    "headlines, news channels, news articles, breaking news": "news",
    "politics, policy and politicians": "politics",
    # NOTE(review): "health are" looks like a typo for "health care". The key is
    # looked up verbatim from predicted labels, so it is left unchanged here;
    # correct it together with wherever the candidate label list is built.
    "health are and public health": "health",
    "religious": "news",  # CONSCIOUS DECISION
}

# Catch-all label: classify() assigns it to "NON-VALID" texts, and label_map
# translates it to "non-civic".
default_label = "something else"
def map_scores(predicted_labels: List[dict], default_label: str):
    """Return one score per prediction, with default-label predictions zeroed.

    Args:
        predicted_labels: Prediction dicts, each with ``labels`` and ``scores``
            sequences. Only the first entry of each is read (presumably the
            top-ranked label and its score; confirm against the model output).
        default_label: Label whose predictions are reported as ``0``.

    Returns:
        A list aligned with ``predicted_labels``: the first score of each
        prediction, or ``0`` when its first label equals ``default_label``.
    """
    top_scores = []
    for prediction in predicted_labels:
        if prediction['labels'][0] == default_label:
            # A default-label hit counts as "nothing relevant found".
            top_scores.append(0)
        else:
            top_scores.append(prediction['scores'][0])
    return top_scores
def get_first_relevant_label(predicted_labels, mapped_scores: List[float], default_label: str):
    """Return the mapped category of the first prediction with a non-zero score.

    Args:
        predicted_labels: Predictions aligned index-for-index with
            ``mapped_scores``; each needs a ``labels`` sequence whose first
            entry is a key of ``label_map``.
        mapped_scores: One score per prediction, where ``0`` marks an entry to
            skip.
        default_label: Key of ``label_map`` used when no entry is relevant.

    Returns:
        A ``(category, index)`` tuple. ``index`` is the position of the first
        non-zero score. When every score is zero, the category is
        ``label_map[default_label]`` and ``index`` is the last position; for an
        empty ``mapped_scores`` the index is ``-1``.

    Raises:
        KeyError: If the selected label (or ``default_label``) is not a key of
            ``label_map``.
    """
    # Starts at -1 so the fallback return is well-defined for empty input;
    # previously the loop variable was unbound there and the function crashed.
    last_index = -1
    for last_index, score in enumerate(mapped_scores):
        if score != 0:
            return label_map[predicted_labels[last_index]['labels'][0]], last_index
    # All scores were zero, or there were none at all.
    return label_map[default_label], last_index
# Marker text that bypasses the model and is labelled with ``default_label``.
_NON_VALID_TEXT = "NON-VALID"


def classify(texts: List[str], labels: List[str]) -> List[dict]:
    """Classify ``texts`` against ``labels``, preserving the input order.

    Texts equal to ``"NON-VALID"`` are not sent to the model; they receive
    ``default_label`` with a score of ``1.0``. All other texts are classified
    in a single batched call to the module-level ``model`` pipeline.

    Args:
        texts: Input texts.
        labels: Candidate labels forwarded to the model.

    Returns:
        One dict per input text, in input order, each with ``sequence``,
        ``labels`` and ``scores`` keys.

    Raises:
        ValueError: If the model returns a different number of predictions
            than the number of texts it was given.
    """
    results = []
    # Texts that need the model, and the slot in ``results`` each one fills.
    pending_texts = []
    pending_slots = []

    for index, text in enumerate(texts):
        if text == _NON_VALID_TEXT:
            print("NON-VALID TEXT!!", text)
            # Short-circuit: assign the default label with full confidence.
            results.append({
                "sequence": text,
                "labels": [default_label],
                "scores": [1.0],
            })
        else:
            # Reserve the slot now so the prediction can be written by index
            # later instead of being spliced in with list.insert.
            results.append(None)
            pending_texts.append(text)
            pending_slots.append(index)

    if pending_texts:
        predictions = model(pending_texts, labels, multi_label=False, batch_size=32)
        # Defensive: a bare dict may come back for a single sequence with some
        # transformers releases (TODO confirm); normalise it to a list.
        if isinstance(predictions, dict):
            predictions = [predictions]
        predictions = list(predictions)
        if len(predictions) != len(pending_texts):
            raise ValueError(
                f"model returned {len(predictions)} predictions "
                f"for {len(pending_texts)} texts"
            )
        for slot, prediction in zip(pending_slots, predictions):
            results[slot] = prediction

    print([(r['labels'][0], r['sequence']) for r in results])
    return results