# NOTE: removed web-scrape artifacts from the file header
# (Hugging Face "Spaces: Sleeping" status banner — not part of the source).
from typing import Dict, List

from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline

from .topic_list import TOPIC_LIST
class TopicAnalyzer:
    """Zero-shot topic classifier built on ``facebook/bart-large-mnli``.

    Lazily owns a tokenizer, model, and ``zero-shot-classification``
    pipeline; ``set_device`` rebuilds all three on the requested device.
    """

    def __init__(self):
        # Default to CPU; callers move to an accelerator via set_device().
        self.device = "cpu"
        self.model_name = "facebook/bart-large-mnli"
        self.tokenizer = None
        self.model = None
        self.classifier = None
        # Cap applied to the input text before classification.
        # NOTE(review): this is used as a *character* slice in
        # generate_topics(), but BART's 1024 limit is in *tokens* —
        # confirm whether token-level truncation was intended.
        self.max_length = 1024
        # Mapping of category -> {subcategory -> [topic labels]}.
        self.topic_hierarchy = TOPIC_LIST
        self.set_classifier()

    def set_device(self, device: str):
        """Move the classifier to *device*, rebuilding only if it changed."""
        if device != self.device:
            self.device = device
            self.set_classifier()

    def set_classifier(self):
        """(Re)load the tokenizer, model, and zero-shot pipeline.

        Loads everything onto ``self.device``. Any load failure is printed
        and then re-raised to the caller.
        """
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                use_fast=True,
            )
            self.model = AutoModelForSequenceClassification.from_pretrained(
                self.model_name
            ).to(self.device)
            # Set zero-shot pipeline
            self.classifier = pipeline(
                "zero-shot-classification",
                model=self.model,
                tokenizer=self.tokenizer,
                device=self.device,
            )
        except Exception as e:
            print(f"Error initializing classifier: {str(e)}")
            raise

    async def generate_topics(self, text: str, category: str, subcategory: str) -> List[Dict]:
        """Return up to 10 scored topics for *text* within *category*.

        Args:
            text: Document to classify; truncated to ``self.max_length`` chars.
            category: Top-level key into ``self.topic_hierarchy``; labels from
                every subcategory under it are used as candidates.
            subcategory: Currently unused — kept for interface stability.

        Returns:
            At most 10 ``{"topic": str, "score": float}`` dicts with
            score > 0.1, sorted by score descending; ``[]`` on any error
            (errors are printed, not raised — best-effort by design).
        """
        try:
            # Flatten all candidate labels under the requested category.
            all_topics = []
            for subcat in self.topic_hierarchy[category].values():
                all_topics.extend(subcat)
            result = self.classifier(
                text[:self.max_length],
                all_topics,
                multi_label=True,
            )
            # Keep only reasonably confident labels.
            topics = [
                {"topic": topic, "score": score}
                for topic, score in zip(result["labels"], result["scores"])
                if score > 0.1
            ]
            return sorted(topics, key=lambda x: x["score"], reverse=True)[:10]
        except Exception as e:
            print(f"Error generating topics: {str(e)}")
            return []