avsolatorio committed on
Commit
96070b5
·
1 Parent(s): 2b1e4b7

Add wbgtopic

Browse files

Signed-off-by: avsolatorio <[email protected]>

Files changed (3) hide show
  1. app.py +8 -3
  2. requirements.txt +2 -0
  3. wbgtopic.py +98 -0
app.py CHANGED
@@ -1,7 +1,12 @@
1
  import gradio as gr
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
 
 
 
 
 
7
  demo.launch()
 
1
  import gradio as gr
2
+ import wbgtopic
3
 
4
# Build the topic classifier once at import time. Calling WBGDocTopic() with no
# arguments makes it load its default set of classifiers.
clf = wbgtopic.WBGDocTopic()


def fn(inputs):
    """Return the topic suggestions that `clf` produces for `inputs`."""
    # NOTE(review): this wrapper is not referenced below; the interface binds
    # `clf.suggest_topics` directly.
    return clf.suggest_topics(inputs)


# Text box in, JSON panel out.
demo = gr.Interface(
    fn=clf.suggest_topics,
    inputs="text",
    outputs=gr.JSON(label="Suggested topics"),
)
demo.launch()
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ nltk
2
+ transformers[torch]
wbgtopic.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import pipeline
2
+ from tqdm.auto import tqdm
3
+ import pandas as pd
4
+ from transformers import AutoTokenizer
5
+ import nltk
6
+
7
# Download the nltk data if not present. Both the "punkt_tab" and the "punkt"
# resources are requested, in that order, every time this module is imported.
nltk.download("punkt_tab")
nltk.download("punkt")
10
+
11
+
12
+ class WBGDocTopic:
13
+ """
14
+ A class to handle document topic suggestion using multiple pre-trained text classification models.
15
+
16
+ This class loads a set of text classification models from Hugging Face's model hub and
17
+ provides a method to suggest topics for input documents based on the aggregated classification
18
+ results from all the models.
19
+
20
+ Attributes:
21
+ -----------
22
+ classifiers : dict
23
+ A dictionary mapping model names to corresponding classification pipelines. It holds
24
+ instances of Hugging Face's `pipeline` used for text classification.
25
+
26
+ Methods:
27
+ --------
28
+ __init__(classifiers: dict = None)
29
+ Initializes the `WBGDocTopic` instance. If no classifiers are provided, it loads a default
30
+ set of classifiers by calling `load_classifiers`.
31
+
32
+ load_classifiers()
33
+ Loads a predefined set of document topic classifiers into the `classifiers` dictionary.
34
+ It uses `tqdm` to display progress as the classifiers are loaded.
35
+
36
+ suggest_topics(input_docs: str | list[str]) -> list
37
+ Suggests topics for the given document or list of documents. It runs each document
38
+ through all classifiers, averages their scores, and returns a list of dictionaries where each
39
+ dictionary contains the mean and standard deviation of the topic scores per document.
40
+
41
+ Parameters:
42
+ -----------
43
+ input_docs : str or list of str
44
+ A single document or a list of documents for which to suggest topics.
45
+
46
+ Returns:
47
+ --------
48
+ list
49
+ A list of dictionaries, where each dictionary represents the suggested topics for
50
+ each document, along with the mean and standard deviation of the topic classification scores.
51
+ """
52
+
53
+ def __init__(self, classifiers: dict = None, device: str = None):
54
+ self.classifiers = classifiers or {}
55
+ self.device = device
56
+
57
+ if classifiers is None:
58
+ self.load_classifiers()
59
+
60
+ def load_classifiers(self):
61
+ num_evals = 5
62
+ num_train = 5
63
+
64
+ tokenizer = AutoTokenizer.from_pretrained("avsolatorio/doc-topic-model_eval-04_train-03")
65
+
66
+ for i in tqdm(range(num_evals)):
67
+ for j in tqdm(range(num_train)):
68
+ if i == j:
69
+ continue
70
+
71
+ model_name = f"avsolatorio/doc-topic-model_eval-{i:02}_train-{j:02}"
72
+ classifier = pipeline("text-classification", model=model_name, tokenizer=tokenizer, top_k=None, device=self.device)
73
+
74
+ self.classifiers[model_name] = classifier
75
+
76
+ def suggest_topics(self, input_docs: str | list[str]):
77
+ if isinstance(input_docs, str):
78
+ input_docs = [input_docs]
79
+
80
+ doc_outs = {i: [] for i in range(len(input_docs))}
81
+ topics = []
82
+
83
+ for _, classifier in self.classifiers.items():
84
+ for doc_idx, doc in enumerate(classifier(input_docs)):
85
+ doc_outs[doc_idx].append(pd.DataFrame.from_records(doc, index="label"))
86
+
87
+ for doc_idx, outs in doc_outs.items():
88
+ all_scores = pd.concat(outs, axis=1)
89
+ mean_probs = all_scores.mean(axis=1).sort_values(ascending=False)
90
+ std_probs = all_scores.std(axis=1).loc[mean_probs.index]
91
+ output = pd.DataFrame({"score_mean": mean_probs, "score_std": std_probs})
92
+
93
+ output["doc_idx"] = doc_idx
94
+ output.reset_index(inplace=True)
95
+
96
+ topics.append(output.to_dict(orient="records"))
97
+
98
+ return topics