File size: 5,951 Bytes
1ee5c89
 
 
 
 
 
4c25316
1ee5c89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4c25316
 
 
 
 
 
 
 
 
 
 
 
 
1ee5c89
4c25316
1ee5c89
 
 
 
 
 
4c25316
 
 
 
 
 
 
 
 
 
 
 
1ee5c89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4c25316
1ee5c89
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import pandas as pd
import streamlit as st
from Functionalities import NLP_Helper
from Functionalities.TopicClustering import TopicClustering
from streamlit_extras.dataframe_explorer import dataframe_explorer


class TopicClusterView:
    """
    Streamlit page that clusters the texts of an uploaded CSV file into topics.

    The ``TopicClustering`` object produced by :meth:`run_clustering` is stored in
    ``st.session_state.topic_cluster`` so that it survives Streamlit's script reruns. The
    ``show_*`` / ``visualize_*`` methods render nothing until that object exists and its
    ``topic_model`` attribute is set.
    """

    def __init__(self):
        """
        Initialises the per-run attributes, makes sure the session-state slot exists and
        renders the page configuration and header.
        """
        self.n_neighbors = 10  # initial position of the neighbourhood-size slider
        self.topic_cluster = None
        self.representation_model = None
        self.sentence_model = None
        self.text_col = None
        self.text_df = None
        self.text_file = None

        # Create the slot only once; an existing clustering result must survive reruns.
        if 'topic_cluster' not in st.session_state:
            st.session_state.topic_cluster = None

        st.set_page_config(page_title='Topic Clustering', layout="wide")
        st.header("Topic Clustering")

    @staticmethod
    def _is_fitted(topic_cluster) -> bool:
        """
        Tells whether a clustering result is ready to be displayed.

        :param topic_cluster: the object kept in the session state, or None
        :return: True when the object is not None and its ``topic_model`` attribute is not None
        """
        return topic_cluster is not None and topic_cluster.topic_model is not None

    def input_params(self) -> None:
        """
        Takes csv file input, name of text col, select option for sentence model and representation model.

        Nothing but the uploader is rendered until a file is provided. If the uploaded file
        cannot be parsed as CSV an error message is shown instead of the remaining widgets.
        :return:
        """
        self.text_file = st.file_uploader(label="Upload the CSV file containing the texts to cluster")
        if not self.text_file:
            return

        try:
            self.text_df = pd.read_csv(self.text_file)
        except (pd.errors.ParserError, pd.errors.EmptyDataError, UnicodeDecodeError) as err:
            # Report unreadable uploads in the page instead of crashing the script run.
            st.error(f"Could not read **{self.text_file.name}** as a CSV file: {err}")
            return

        self.text_col = st.selectbox(
            label=f"Choose the column to use for topic clustering in **{self.text_file.name}**",
            options=self.text_df.columns
        )

        self.sentence_model = st.selectbox(
            label="Choose the text embedding model",
            options=NLP_Helper.TRANSFORMERS,
            help="; ".join(NLP_Helper.TRANSFORMERS_INFO)
        )

        self.representation_model = st.selectbox(
            label="Choose the representation model",
            options=NLP_Helper.BERTOPIC_REPRESENTATIONS,
        )

        # The callback is bound to this instance, so it sees the selections made above.
        st.button("Cluster", on_click=self.run_clustering)

    def run_clustering(self) -> None:
        """
        Builds a ``TopicClustering`` from the current selections, runs ``topic_cluster_bert()``
        on it and stores the result in ``st.session_state.topic_cluster``.
        :return:
        """
        self.topic_cluster = TopicClustering(keyword_df=self.text_df, text_col=self.text_col,
                                             representation_model=self.representation_model,
                                             sentence_model=self.sentence_model)

        self.topic_cluster.topic_cluster_bert()
        st.session_state.topic_cluster = self.topic_cluster

    def show_and_download_df(self) -> None:
        """
        Shows the clustered dataframe with a filter explorer, lets the user rename topics and
        offers the result as CSV download (plain and in the Google Ads bulk-upload format).
        :return:
        """
        topic_cluster = st.session_state.topic_cluster
        if not self._is_fitted(topic_cluster):
            return

        filtered_df = dataframe_explorer(topic_cluster.keyword_df)
        st.dataframe(filtered_df)

        with st.expander("Rename Topics"):
            for topic_name in topic_cluster.topic_names:
                cur_topic_col, new_topic_col = st.columns(2)
                cur_topic_col.write(topic_name)
                with new_topic_col:
                    # The input defaults to the current name, so untouched topics map to themselves.
                    topic_cluster.topic_name_mapping[topic_name] = st.text_input("New topic name", topic_name)

            if st.button("Update Topic Names"):
                topic_cluster.update_topic_names()
                # ``st.experimental_rerun`` is deprecated in favour of ``st.rerun``; fall back to
                # the old name only on Streamlit versions that do not provide the new one.
                rerun = getattr(st, "rerun", None) or st.experimental_rerun
                rerun()

        st.download_button(
            "Press to Download as CSV",
            topic_cluster.keyword_df.to_csv(index=False).encode('utf-8'),
            "Clustered.csv",
            "text/csv",
            key='download-csv'
        )

        with st.expander("Download as CSV for Bulk upload in Google Ads"):
            campaign_name = st.text_input("Campaign Name", "Demo Campaign")
            # Build the export once and reuse it for both the preview and the download.
            google_ads_df = topic_cluster.get_df_in_google_ads_format(campaign_name)
            st.dataframe(google_ads_df)
            st.download_button(
                "Download as CSV for Bulk upload in Google Ads",
                google_ads_df.to_csv(index=False).encode('utf-8'),
                f"{campaign_name}_keywords_upload.csv",
                "text/csv",
                key='download-google-csv'
            )

    def visualize_clusters(self) -> None:
        """
        Renders a slider for the neighbourhood size and, on button click, plots the figure
        returned by ``visualize_documents`` of the stored clustering result.
        :return:
        """
        topic_cluster = st.session_state.topic_cluster
        if not self._is_fitted(topic_cluster):
            return

        # Start the slider at the default from __init__; without ``value`` it would start at min_value.
        self.n_neighbors = st.slider(label='Size of the local neighborhood', min_value=2, max_value=100,
                                     value=self.n_neighbors, step=1)
        if st.button("Visualize Topic Clusters"):
            fig = topic_cluster.visualize_documents(n_neighbors=self.n_neighbors)
            fig.update_layout(title=None)
            st.plotly_chart(fig, use_container_width=True, theme=None)

    def visualize_topic_distribution(self) -> None:
        """
        Plots the figure returned by ``visualize_topic_distribution`` of the stored clustering result.
        :return:
        """
        topic_cluster = st.session_state.topic_cluster
        if not self._is_fitted(topic_cluster):
            return

        fig = topic_cluster.visualize_topic_distribution()
        st.plotly_chart(fig, use_container_width=True, theme=None)


if __name__ == '__main__':
    # Script entry point: build the page, then lay its sections out in two tabs.
    # NOTE: a third 'Topic Distribution' tab (rendering visualize_topic_distribution)
    # is currently disabled.
    page = TopicClusterView()
    clustering_tab, visualization_tab = st.tabs(['Clustering Process', 'Cluster Visualization'])

    with clustering_tab:
        page.input_params()
        page.show_and_download_df()

    with visualization_tab:
        page.visualize_clusters()