# google_ads_space / pages / 2_Topic_Cluster.py
# zayed-upal
# Google Ads format download added, topic name rename option added
# 4c25316
import pandas as pd
import streamlit as st
from Functionalities import NLP_Helper
from Functionalities.TopicClustering import TopicClustering
from streamlit_extras.dataframe_explorer import dataframe_explorer
class TopicClusterView:
    """Streamlit view for the "Topic Clustering" page.

    The view collects a CSV file, a text column and model choices from the
    user, runs ``TopicClustering`` on them, and then offers the clustered
    data for inspection, topic renaming, download and visualisation.  The
    fitted ``TopicClustering`` object is kept in
    ``st.session_state.topic_cluster`` so that it survives Streamlit reruns.
    """

    def __init__(self):
        """Reset per-run attributes, seed the session state and draw the page header."""
        # Initial position of the neighbourhood-size slider in visualize_clusters().
        self.n_neighbors = 10
        self.topic_cluster = None
        self.representation_model = None
        self.sentence_model = None
        self.text_col = None
        self.text_df = None
        self.text_file = None
        # Only seed the key on the first run; later runs keep the fitted model.
        if 'topic_cluster' not in st.session_state:
            st.session_state.topic_cluster = None
        st.set_page_config(page_title='Topic Clustering', layout="wide")
        st.header("Topic Clustering")

    @staticmethod
    def _is_fitted(topic_cluster) -> bool:
        """Return True when *topic_cluster* exists and carries a topic model."""
        return topic_cluster is not None and topic_cluster.topic_model is not None

    @staticmethod
    def _rerun() -> None:
        """Restart the script run, preferring ``st.rerun`` where it is available."""
        # ``st.experimental_rerun`` is the older, deprecated spelling; fall back
        # to it only on Streamlit versions that do not provide ``st.rerun`` yet.
        rerun = getattr(st, "rerun", None) or st.experimental_rerun
        rerun()

    def input_params(self) -> None:
        """
        Takes csv file input, name of text col, select option for sentence model and representation model
        :return: None. The choices are stored on ``self`` and a "Cluster" button is rendered
                 whose callback is :meth:`run_clustering`.
        """
        self.text_file = st.file_uploader(label="Upload the CSV file containing the texts to cluster")
        if self.text_file is not None:
            try:
                self.text_df = pd.read_csv(self.text_file)
            except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as err:
                # Report an unreadable upload instead of aborting the page with a traceback.
                self.text_df = None
                st.error(f"Could not read **{self.text_file.name}** as CSV: {err}")
            else:
                self.text_col = st.selectbox(
                    label=f"Choose the column to use for topic clustering in **{self.text_file.name}**",
                    options=self.text_df.columns
                )
        self.sentence_model = st.selectbox(
            label="Choose the text embedding model",
            options=NLP_Helper.TRANSFORMERS,
            help="; ".join(NLP_Helper.TRANSFORMERS_INFO)
        )
        self.representation_model = st.selectbox(
            label="Choose the representation model",
            options=NLP_Helper.BERTOPIC_REPRESENTATIONS,
        )
        # Clustering needs a non-empty table and a chosen column; keep the
        # button disabled until both are available.
        ready = self.text_df is not None and not self.text_df.empty and self.text_col is not None
        st.button("Cluster", on_click=self.run_clustering, disabled=not ready)

    def run_clustering(self) -> None:
        """Fit a ``TopicClustering`` on the selected data and store it in the session state.

        Does nothing when no data frame or text column has been selected.
        """
        # Defensive guard: the button is disabled without data, but the
        # callback must never build a TopicClustering from missing inputs.
        if self.text_df is None or self.text_col is None:
            return
        self.topic_cluster = TopicClustering(keyword_df=self.text_df, text_col=self.text_col,
                                             representation_model=self.representation_model,
                                             sentence_model=self.sentence_model)
        self.topic_cluster.topic_cluster_bert()
        st.session_state.topic_cluster = self.topic_cluster

    def show_and_download_df(self):
        """Show the clustered data with topic renaming and CSV download controls.

        Renders nothing until a fitted topic cluster is present in the session state.
        """
        topic_cluster = st.session_state.topic_cluster
        if not self._is_fitted(topic_cluster):
            return

        filtered_df = dataframe_explorer(topic_cluster.keyword_df)
        st.dataframe(filtered_df)

        with st.expander("Rename Topics"):
            for idx, topic_name in enumerate(topic_cluster.topic_names):
                cur_topic_col, new_topic_col = st.columns(2)
                cur_topic_col.write(topic_name)
                # All inputs share one label, so give each an explicit key to
                # keep the widgets distinct.
                topic_cluster.topic_name_mapping[topic_name] = new_topic_col.text_input(
                    "New topic name", topic_name, key=f"rename-topic-{idx}-{topic_name}")
            if st.button("Update Topic Names"):
                topic_cluster.update_topic_names()
                # Redraw immediately so the table and inputs show the new names.
                self._rerun()

        st.download_button(
            "Press to Download as CSV",
            topic_cluster.keyword_df.to_csv(index=False).encode('utf-8'),
            "Clustered.csv",
            "text/csv",
            key='download-csv'
        )

        with st.expander("Download as CSV for Bulk upload in Google Ads"):
            campaign_name = st.text_input("Campaign Name", "Demo Campaign")
            # Build the export frame once and reuse it for preview and download.
            google_ads_df = topic_cluster.get_df_in_google_ads_format(campaign_name)
            st.dataframe(google_ads_df)
            st.download_button(
                "Download as CSV for Bulk upload in Google Ads",
                google_ads_df.to_csv(index=False).encode('utf-8'),
                f"{campaign_name}_keywords_upload.csv",
                "text/csv",
                key='download-google-csv'
            )

    def visualize_clusters(self):
        """Render the neighbourhood-size slider and, on request, the document cluster plot.

        Renders nothing until a fitted topic cluster is present in the session state.
        """
        topic_cluster = st.session_state.topic_cluster
        if not self._is_fitted(topic_cluster):
            return

        # Start the slider at the configured default; without ``value`` it
        # would start at ``min_value`` and discard the default set in __init__.
        self.n_neighbors = st.slider(label='Size of the local neighborhood', min_value=2, max_value=100,
                                     value=self.n_neighbors, step=1)
        if st.button("Visualize Topic Clusters"):
            fig = topic_cluster.visualize_documents(n_neighbors=self.n_neighbors)
            fig.update_layout(title=None)
            st.plotly_chart(fig, use_container_width=True, theme=None)

    def visualize_topic_distribution(self):
        """Render the topic distribution plot once a fitted topic cluster is available."""
        topic_cluster = st.session_state.topic_cluster
        if not self._is_fitted(topic_cluster):
            return

        fig = topic_cluster.visualize_topic_distribution()
        st.plotly_chart(fig, use_container_width=True, theme=None)
if __name__ == '__main__':
    # Page entry point: build the view, then lay its sections out in tabs.
    topic_cluster_view = TopicClusterView()

    # NOTE: a third 'Topic Distribution' tab, backed by
    # TopicClusterView.visualize_topic_distribution(), is currently not shown.
    tab_labels = ['Clustering Process', 'Cluster Visualization']
    clustering_tab, visualization_tab = st.tabs(tab_labels)

    # Tab 1: inputs, then the clustered table and download options.
    with clustering_tab:
        topic_cluster_view.input_params()
        topic_cluster_view.show_and_download_df()

    # Tab 2: interactive plot of the clustered documents.
    with visualization_tab:
        topic_cluster_view.visualize_clusters()