File size: 2,069 Bytes
545ce24
 
 
 
 
 
 
4de9003
545ce24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
from typing import List, Dict
from .topic_list import TOPIC_LIST

class TopicAnalyzer:
    """Zero-shot topic tagger built on a ``transformers`` classification pipeline.

    The candidate topics come from ``TOPIC_LIST``, which is used here as a
    mapping of category -> mapping of subcategory -> iterable of topic labels.
    The model is loaded eagerly in ``__init__`` and reloaded whenever the
    target device changes.
    """

    # Topics scoring at or below this value are dropped from the result.
    SCORE_THRESHOLD = 0.1
    # Upper bound on the number of topics returned by generate_topics().
    MAX_TOPICS = 10

    def __init__(self):
        """Set the defaults and load the classifier on the CPU."""
        self.device = "cpu"
        self.model_name = "facebook/bart-large-mnli"
        self.tokenizer = None
        self.model = None
        self.classifier = None
        # Cap applied to the input text by slicing, i.e. counted in
        # characters, not tokens.
        self.max_length = 1024
        self.topic_hierarchy = TOPIC_LIST
        self.set_classifier()

    def set_device(self, device: str):
        """Move the classifier to ``device``, reloading it if the device changed.

        Args:
            device: Device identifier handed to the model and the pipeline.

        Raises:
            Exception: Whatever ``set_classifier`` raises. In that case the
                previous device is restored so the instance stays consistent
                and the same call can be retried.
        """
        if device == self.device:
            return

        previous_device = self.device
        self.device = device
        try:
            self.set_classifier()
        except Exception:
            # Without the rollback, self.device would name a device the
            # loaded classifier is not on, and a retry with the same device
            # would be skipped by the equality check above.
            self.device = previous_device
            raise

    def set_classifier(self):
        """Load tokenizer, model and zero-shot pipeline for ``self.device``.

        The three attributes are replaced together, and only once every
        loading step has succeeded, so a failure leaves the previous
        tokenizer/model/classifier untouched.

        Raises:
            Exception: Any loading error, re-raised after being printed.
        """
        try:
            tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                use_fast=True
            )
            model = AutoModelForSequenceClassification.from_pretrained(
                self.model_name
            ).to(self.device)

            # Set zero-shot pipeline
            classifier = pipeline(
                "zero-shot-classification",
                model=model,
                tokenizer=tokenizer,
                device=self.device
            )
        except Exception as e:
            print(f"Error initializing classifier: {str(e)}")
            raise

        self.tokenizer = tokenizer
        self.model = model
        self.classifier = classifier

    async def generate_topics(self, text: str, category: str, subcategory: str) -> List[Dict]:
        """Score ``text`` against every topic of ``category``.

        NOTE: the classifier is called synchronously, so the event loop is
        blocked for the duration of the inference.

        Args:
            text: Text to classify; only the first ``self.max_length``
                characters are sent to the classifier.
            category: Key into ``self.topic_hierarchy``.
            subcategory: Currently unused; topics from all subcategories of
                ``category`` are considered.

        Returns:
            Up to ``MAX_TOPICS`` dicts of the form
            ``{"topic": <label>, "score": <score>}`` with a score above
            ``SCORE_THRESHOLD``, sorted by descending score. An empty list is
            returned for empty text, for a category without topics, and on
            any error (which is printed rather than raised).
        """
        try:
            # Nothing to classify; skip the classifier call entirely.
            if not text:
                return []

            # Collect the labels of every subcategory, dropping repeats while
            # keeping first-seen order so a topic listed under several
            # subcategories is scored and reported only once.
            unique_topics = {}
            for subcat_topics in self.topic_hierarchy[category].values():
                unique_topics.update(dict.fromkeys(subcat_topics))
            all_topics = list(unique_topics)

            if not all_topics:
                return []

            result = self.classifier(
                text[:self.max_length],
                all_topics,
                multi_label=True
            )

            topics = [
                {"topic": topic, "score": score}
                for topic, score in zip(result["labels"], result["scores"])
                if score > self.SCORE_THRESHOLD
            ]

            topics.sort(key=lambda entry: entry["score"], reverse=True)
            return topics[:self.MAX_TOPICS]

        except Exception as e:
            # Deliberate best-effort boundary: callers get an empty result
            # instead of an exception (e.g. for an unknown category).
            print(f"Error generating topics: {str(e)}")
            return []