File size: 4,603 Bytes
1ee5c89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import pandas as pd
import streamlit as st
from Functionalities import NLP_Helper
from Functionalities.TopicClustering import TopicClustering
from streamlit_extras.dataframe_explorer import dataframe_explorer

class TopicClusterView:
    """Streamlit page that collects clustering inputs and displays the results.

    The ``TopicClustering`` object produced by :meth:`run_clustering` is kept in
    ``st.session_state.topic_cluster``; the display methods read it from there
    rather than from the instance, and render nothing until it is available.
    """

    def __init__(self) -> None:
        """Configure the page, render its header and initialise the view state."""
        # Keep the page configuration ahead of every other Streamlit call.
        st.set_page_config(page_title='Topic Clustering', layout="wide")
        st.header("Topic Clustering")

        self.n_neighbors = 10  # initial slider position in visualize_clusters()
        self.topic_cluster = None
        self.representation_model = None
        self.sentence_model = None
        self.text_col = None
        self.text_df = None
        self.text_file = None

        # Create the session slot once; an existing value is left untouched.
        if 'topic_cluster' not in st.session_state:
            st.session_state.topic_cluster = None

    @staticmethod
    def _has_fitted_model(cluster) -> bool:
        """Return True when *cluster* exists and its ``topic_model`` is set."""
        return cluster is not None and cluster.topic_model is not None

    def input_params(self) -> None:
        """
        Takes csv file input, name of text col, select option for sentence model and representation model.

        Only the uploader is rendered until a file has been uploaded. A file that
        cannot be parsed as CSV is reported with ``st.error`` and the remaining
        widgets are skipped. Clicking the "Cluster" button calls
        :meth:`run_clustering`.

        :return: None
        """
        self.text_file = st.file_uploader(label="Upload the CSV file containing the texts to cluster")
        if not self.text_file:
            return

        try:
            self.text_df = pd.read_csv(self.text_file)
        except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as exc:
            # Report unreadable uploads in the page instead of surfacing a traceback.
            st.error(f"Could not read **{self.text_file.name}** as CSV: {exc}")
            return

        self.text_col = st.selectbox(
            label=f"Choose the column to use for topic clustering in **{self.text_file.name}**",
            options=self.text_df.columns
        )

        self.sentence_model = st.selectbox(
            label="Choose the text embedding model",
            options=NLP_Helper.TRANSFORMERS,
            help="; ".join(NLP_Helper.TRANSFORMERS_INFO)
        )

        self.representation_model = st.selectbox(
            label="Choose the representation model",
            options=NLP_Helper.BERTOPIC_REPRESENTATIONS,
        )

        st.button("Cluster", on_click=self.run_clustering)

    def run_clustering(self) -> None:
        """
        Build a ``TopicClustering`` from the collected inputs, run its
        ``topic_cluster_bert()`` step and store the object in the session state.

        :return: None
        """
        self.topic_cluster = TopicClustering(keyword_df=self.text_df, text_col=self.text_col,
                                             representation_model=self.representation_model,
                                             sentence_model=self.sentence_model)

        self.topic_cluster.topic_cluster_bert()
        st.session_state.topic_cluster = self.topic_cluster

    def show_and_download_df(self) -> None:
        """
        Show the clustered data frame with a filter widget and a CSV download button.

        Renders nothing until a cluster with a topic model is stored in the session.

        :return: None
        """
        cluster = st.session_state.topic_cluster
        if not self._has_fitted_model(cluster):
            return

        filtered_df = dataframe_explorer(cluster.keyword_df)
        st.dataframe(filtered_df)
        # NOTE(review): the download exports the full frame, not the filtered view
        # shown above - confirm that this is intended.
        st.download_button(
            "Press to Download",
            cluster.keyword_df.to_csv(index=False).encode('utf-8'),
            "Clustered.csv",
            "text/csv",
            key='download-csv'
        )

    def visualize_clusters(self) -> None:
        """
        Render the neighborhood-size slider and, on button press, the document plot.

        Renders nothing until a cluster with a topic model is stored in the session.

        :return: None
        """
        cluster = st.session_state.topic_cluster
        if not self._has_fitted_model(cluster):
            return

        # Start the slider at the default chosen in __init__ instead of at min_value.
        self.n_neighbors = st.slider(label='Size of the local neighborhood', min_value=2, max_value=100,
                                     value=self.n_neighbors, step=1)
        if st.button("Visualize Topic Clusters"):
            fig = cluster.visualize_documents(n_neighbors=self.n_neighbors)
            fig.update_layout(title=None)
            st.plotly_chart(fig, use_container_width=True, theme=None)

    def visualize_topic_distribution(self) -> None:
        """
        Plot the figure returned by the stored cluster's ``visualize_topic_distribution()``.

        Renders nothing until a cluster with a topic model is stored in the session.

        :return: None
        """
        cluster = st.session_state.topic_cluster
        if not self._has_fitted_model(cluster):
            return

        fig = cluster.visualize_topic_distribution()
        st.plotly_chart(fig, use_container_width=True, theme=None)

if __name__ == '__main__':
    # Build the page, then lay its sections out in two tabs.
    page = TopicClusterView()
    clustering_tab, visualization_tab = st.tabs(['Clustering Process', 'Cluster Visualization'])

    # Inputs and the resulting table share the first tab.
    with clustering_tab:
        page.input_params()
        page.show_and_download_df()

    # The document plot lives in the second tab.
    with visualization_tab:
        page.visualize_clusters()

    # A third 'Topic Distribution' tab calling page.visualize_topic_distribution()
    # is currently disabled.