File size: 7,409 Bytes
cea575d
 
 
 
 
 
 
1cdf555
 
 
 
 
6abecd2
af0a76f
b4e33fb
1cdf555
cea575d
 
864ec1b
1cdf555
 
cea575d
 
 
 
 
 
 
 
1cdf555
 
cea575d
1cdf555
cea575d
1cdf555
b4e33fb
cea575d
1cdf555
cea575d
 
 
 
b4e33fb
af0a76f
 
1cdf555
 
cea575d
1cdf555
 
 
 
 
 
 
 
cea575d
 
 
 
 
 
 
1cdf555
 
 
 
 
 
cea575d
 
 
1cdf555
 
cea575d
1cdf555
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1a6e249
1cdf555
 
 
 
af0a76f
cea575d
 
 
 
af0a76f
 
 
 
 
cea575d
af0a76f
 
cea575d
 
 
 
 
 
 
af0a76f
 
 
 
 
 
 
 
 
 
 
 
cea575d
 
 
 
 
af0a76f
 
 
 
 
 
604d17a
cea575d
af0a76f
 
 
604d17a
cea575d
af0a76f
 
 
 
 
 
 
1cdf555
 
 
cea575d
1cdf555
 
 
af0a76f
 
1cdf555
 
 
 
 
 
 
 
 
 
 
cea575d
 
 
 
 
 
 
 
 
1cdf555
af0a76f
1cdf555
cea575d
 
 
 
 
 
 
 
1cdf555
cea575d
 
 
 
 
 
 
 
 
1cdf555
cea575d
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
import argparse
import json as js
import os
import re
from pathlib import Path
from typing import List, Tuple

import fasttext
import gradio as gr
import joblib
import omikuji
from huggingface_hub import snapshot_download
from prepare_everything import download_model



# Fetch the fastText language-identification model used by predict_lang().
download_model(
    "https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.bin",
    Path("lid.176.bin"))

# Download the per-language Omikuji topic models from Hugging Face.
model_names = [
    "omikuji-bonsai-parliament-spacy-de-all_topics-input_long",
    "omikuji-bonsai-parliament-spacy-fr-all_topics-input_long",
    "omikuji-bonsai-parliament-spacy-it-all_topics-input_long",
]

for repo_id in model_names:
    os.makedirs(repo_id, exist_ok=True)
    model_dir = snapshot_download(repo_id=f"kapllan/{repo_id}", local_dir=f"kapllan/{repo_id}")

lang_model = fasttext.load_model("lid.176.bin")

# Label <-> id mappings for the topic models.
with open(Path("label2id.json"), "r") as f:
    label2id = js.load(f)

id2label = {str(value): key for key, value in label2id.items()}

# Main-topic -> sub-topics mapping consumed by restructure_topics().
with open(Path("topics_hierarchy.json"), "r") as f:
    topics_hierarchy = js.load(f)


def map_language(language: str) -> str:
    """Map an ISO 639-1 code to its English language name.

    Unknown codes are returned unchanged.
    """
    language_mapping = {"de": "German", "it": "Italian", "fr": "French"}
    # dict.get with a default replaces the manual membership check.
    return language_mapping.get(language, language)


def find_model(language: str):
    """Load the vectorizer and Omikuji topic model for a supported language.

    Returns (None, None) when *language* is not one of "de", "fr", "it".
    """
    if language not in ["de", "fr", "it"]:
        return None, None
    base = f"./kapllan/omikuji-bonsai-parliament-spacy-{language}-all_topics-input_long"
    vectorizer = joblib.load(f"{base}/vectorizer")
    model = omikuji.Model.load(f"{base}/omikuji-model")
    return vectorizer, model


def predict_lang(text: str) -> str:
    """Detect the language of *text* with fastText and return its ISO code."""
    # fastText cannot process input that contains newlines, so strip them.
    cleaned = text.replace("\n", "")
    # k=1: keep only the single best-matching language label.
    predictions = lang_model.predict(cleaned, k=1)
    top_label = predictions[0][0]
    # Labels come back as "__label__<code>"; keep just the code part.
    return top_label.replace("__label__", "")


def predict_topic(text: str) -> Tuple[List[Tuple[str, float]], str]:
    """Predict topics for *text* with the language-matched Omikuji model.

    Detects the language, loads the corresponding vectorizer/model pair, and
    returns a list of (label, score-rounded-to-2-decimals) pairs together with
    the mapped language name. The list is empty for unsupported languages or
    when the document vectorizes to an all-zero feature vector.

    Note: the original annotation ``[List[str], str]`` was not a valid type;
    the function returns a tuple.
    """
    results = []
    language = predict_lang(text)
    vectorizer, model = find_model(language)
    language = map_language(language)
    if vectorizer is not None:
        vector = vectorizer.transform([text])
        for row in vector:
            if row.nnz == 0:  # all-zero vector: nothing to predict on
                continue
            # Sparse row -> (feature index, value) pairs expected by omikuji.
            feature_values = [(col, row[0, col]) for col in row.nonzero()[1]]
            for subj_id, score in model.predict(feature_values, top_k=1000):
                results.append((id2label[str(subj_id)], round(score, 2)))
    return results, language


def get_row_color(type: str):
    """Return the inline CSS background for a topic row type.

    "Main Topic" rows get dark grey, "Sub Topic" rows light grey. Previously
    an unrecognized type fell through and returned None, which callers would
    embed as the literal string 'style="None"'; return "" (no styling) instead.
    """
    # NOTE: parameter name shadows the builtin `type`; kept for call compatibility.
    lowered = type.lower()
    if "main" in lowered:
        return "background-color: darkgrey;"
    if "sub" in lowered:
        return "background-color: lightgrey;"
    return ""


def generate_html_table(topics: List[Tuple[str, str, float]]):
    """Render (type, topic, score) rows as an HTML table, bolding main topics."""
    rows = ["<tr><th>Type</th><th>Topic</th><th>Score</th></tr>"]
    for type, topic, score in topics:
        color = get_row_color(type)
        # Main-topic rows are emphasized in every column.
        if "main" in type.lower():
            topic = f"<strong>{topic}</strong>"
            type = f"<strong>{type}</strong>"
            score = f"<strong>{score}</strong>"
        rows.append(
            f'<tr style="{color}"><td>{type}</td><td>{topic}</td><td>{score}</td></tr>'
        )
    return (
        '<table style="width:100%; border: 1px solid black; border-collapse: collapse;">'
        + "".join(rows)
        + "</table>"
    )


def restructure_topics(topics: List[Tuple[str, float]]) -> List[Tuple[str, str, float]]:
    """Group flat (topic, score) predictions into Main Topic / Sub Topic rows.

    Relies on the module-level ``topics_hierarchy`` dict (main topic -> list of
    sub-topics). A main topic is emitted only when at least one of its
    sub-topics was also predicted; each main-topic row is immediately followed
    by its de-duplicated sub-topic rows, preserving prediction order.
    """
    # Normalise predicted labels to lower case so they match the hierarchy keys.
    topics = [(str(x[0]).lower(), x[1]) for x in topics]
    topics_as_dict = {}
    # Pass 1: every predicted label that is itself a main topic opens a bucket.
    for predicted_topic, score in topics:
        if str(predicted_topic).lower() in topics_hierarchy.keys():
            topics_as_dict[str(predicted_topic).lower()] = []

    # Pass 2: attach each predicted sub-topic to its (also predicted) main topic.
    for predicted_topic, score in topics:
        for main_topic, sub_topics in topics_hierarchy.items():
            if (
                    main_topic in topics_as_dict.keys()
                    and predicted_topic != main_topic
                    and predicted_topic in sub_topics
            ):
                topics_as_dict[main_topic].append(predicted_topic)

    topics_restructured = []
    for predicted_main_topic, predicted_sub_topics in topics_as_dict.items():
        # Skip main topics with no predicted sub-topics.
        if len(predicted_sub_topics) > 0:
            # Look up the main topic's own score from the flat prediction list.
            score = [t for t in topics if t[0] == predicted_main_topic][0][1]
            # Strip the German "hauptthema: " (main topic) label prefix for display.
            predicted_main_topic = predicted_main_topic.replace("hauptthema: ", "")
            topics_restructured.append(("Main Topic", predicted_main_topic, score))
            predicted_sub_topics_with_scores = []
            for pst in predicted_sub_topics:
                score = [t for t in topics if t[0] == pst][0][1]
                # Strip the German "unterthema: " (sub-topic) label prefix for display.
                pst = pst.replace("unterthema: ", "")
                entry = ("Sub Topic", pst, score)
                # De-duplicate sub-topic rows while preserving order.
                if entry not in predicted_sub_topics_with_scores:
                    predicted_sub_topics_with_scores.append(entry)
            for x in predicted_sub_topics_with_scores:
                topics_restructured.append(x)
    return topics_restructured


def topic_modeling(text: str, threshold: float) -> Tuple[str, str]:
    """Full pipeline: language detection, topic prediction, threshold filtering,
    hierarchy grouping, and HTML rendering.

    Returns the rendered HTML table and the detected language name. The
    original annotation ``[List[str], str]`` was not a valid type; the function
    actually returns (html_string, language).
    """
    sorted_topics, language = predict_topic(text)
    # Keep scored topics only for supported languages; otherwise show nothing.
    if len(sorted_topics) > 0 and language in ["German", "French", "Italian"]:
        sorted_topics = [t for t in sorted_topics if t[1] >= threshold]
    else:
        sorted_topics = []
    sorted_topics = restructure_topics(sorted_topics)
    sorted_topics = generate_html_table(sorted_topics)
    return sorted_topics, language


# Build the Gradio UI: document input + threshold on the left, HTML result on the right.
with gr.Blocks() as iface:
    gr.Markdown("# Topic Modeling")
    gr.Markdown("Enter a document and get each topic along with its score.")

    with gr.Row():
        with gr.Column():
            # Document to classify.
            input_text = gr.Textbox(lines=10, placeholder="Enter a document")
            submit_button = gr.Button("Submit")
            # Minimum score a topic must reach to be shown in the table.
            threshold_slider = gr.Slider(
                minimum=0.0, maximum=1.0, step=0.01, label="Score Threshold", value=0.0
            )
            # Read-only display of the fastText-detected language.
            language_text = gr.Textbox(
                lines=1,
                placeholder="Detected language will be shown here...",
                interactive=False,
                label="Detected Language",
            )
        with gr.Column():
            # Rendered HTML table of grouped topics.
            output_data = gr.HTML()

    # Wire the button to the full topic_modeling pipeline.
    submit_button.click(
        topic_modeling,
        inputs=[input_text, threshold_slider],
        outputs=[output_data, language_text],
    )

if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "-ipa",
        "--ip_address",
        default=None,
        type=str,
        help="Specify the IP address of your computer.",
    )

    args = parser.parse_args()
    # Launch the app
    if args.ip_address is None:
        # Blocks.launch() returns (app, local_url, share_url); the original
        # 2-name unpack raised ValueError. The share URL is the last element.
        _, _, public_url = iface.launch(share=True)
        print(f"The app runs here: {public_url}")

    else:
        iface.launch(server_name=args.ip_address, server_port=8080, show_error=True)