Spaces:
Runtime error
Runtime error
import pandas as pd | |
import streamlit as st | |
from Functionalities import NLP_Helper | |
from Functionalities.TopicClustering import TopicClustering | |
from streamlit_extras.dataframe_explorer import dataframe_explorer | |
class TopicClusterView:
    """Streamlit page that clusters the texts of an uploaded CSV into topics.

    The clustering result (a ``TopicClustering`` instance) is kept in
    ``st.session_state.topic_cluster`` so it survives Streamlit reruns; the
    attributes on ``self`` only hold the selections made during the current
    script run.
    """

    def __init__(self):
        # set_page_config must precede every other Streamlit command.
        st.set_page_config(page_title='Topic Clustering', layout="wide")

        self.n_neighbors = 10  # initial value of the neighborhood-size slider
        self.topic_cluster = None
        self.representation_model = None
        self.sentence_model = None
        self.text_col = None
        self.text_df = None
        self.text_file = None

        # Initialise only once so a previously computed clustering is kept across reruns.
        if 'topic_cluster' not in st.session_state:
            st.session_state.topic_cluster = None

        st.header("Topic Clustering")

    @staticmethod
    def _model_ready(topic_cluster) -> bool:
        """Return True if *topic_cluster* is set and its ``topic_model`` is not None."""
        return topic_cluster is not None and getattr(topic_cluster, "topic_model", None) is not None

    def _has_fitted_model(self) -> bool:
        """Return True if session state holds a clustering whose ``topic_model`` is set."""
        return self._model_ready(st.session_state.get("topic_cluster"))

    def input_params(self) -> None:
        """Render the CSV uploader and, once a file is uploaded, the clustering options.

        Stores the uploaded file, the parsed dataframe, the chosen text column,
        the text embedding (sentence) model and the representation model on
        ``self``, then shows a "Cluster" button whose ``on_click`` callback is
        :meth:`run_clustering`. A file that pandas cannot parse is reported with
        ``st.error`` instead of crashing the page.
        """
        self.text_file = st.file_uploader(label="Upload the CSV file containing the texts to cluster")
        if not self.text_file:
            return

        try:
            self.text_df = pd.read_csv(self.text_file)
        except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as err:
            st.error(f"Could not read **{self.text_file.name}** as CSV: {err}")
            return

        self.text_col = st.selectbox(
            label=f"Choose the column to use for topic clustering in **{self.text_file.name}**",
            options=self.text_df.columns
        )
        self.sentence_model = st.selectbox(
            label="Choose the text embedding model",
            options=NLP_Helper.TRANSFORMERS,
            help="; ".join(NLP_Helper.TRANSFORMERS_INFO)
        )
        self.representation_model = st.selectbox(
            label="Choose the representation model",
            options=NLP_Helper.BERTOPIC_REPRESENTATIONS,
        )
        # The callback is bound to this instance, so it sees the selections made above.
        st.button("Cluster", on_click=self.run_clustering)

    def run_clustering(self) -> None:
        """Run topic clustering on the uploaded data and store the result in session state.

        Builds a ``TopicClustering`` from ``text_df``/``text_col`` and the chosen
        models, calls ``topic_cluster_bert()`` on it, and saves it both on
        ``self`` and in ``st.session_state.topic_cluster``. Does nothing when no
        dataframe/column has been selected yet.
        """
        if self.text_df is None or self.text_col is None:
            return
        topic_cluster = TopicClustering(keyword_df=self.text_df, text_col=self.text_col,
                                        representation_model=self.representation_model,
                                        sentence_model=self.sentence_model)
        topic_cluster.topic_cluster_bert()
        self.topic_cluster = topic_cluster
        st.session_state.topic_cluster = topic_cluster

    def show_and_download_df(self):
        """Show the clustered data together with rename and download controls.

        Renders nothing until a clustering with a ``topic_model`` is in session
        state. Otherwise shows ``keyword_df`` through ``dataframe_explorer``, a
        "Rename Topics" expander that writes the new names into
        ``topic_name_mapping`` and, on "Update Topic Names", calls
        ``update_topic_names()`` and reruns the script, a plain CSV download, and
        a Google Ads bulk-upload CSV for a user-supplied campaign name.
        """
        if not self._has_fitted_model():
            return
        topic_cluster = st.session_state.topic_cluster

        st.dataframe(dataframe_explorer(topic_cluster.keyword_df))

        with st.expander("Rename Topics"):
            for topic_name in topic_cluster.topic_names:
                cur_topic_col, new_topic_col = st.columns(2)
                cur_topic_col.write(topic_name)
                topic_cluster.topic_name_mapping[topic_name] = \
                    new_topic_col.text_input("New topic name", topic_name)
            if st.button("Update Topic Names"):
                topic_cluster.update_topic_names()
                # st.experimental_rerun is deprecated (removed in newer releases)
                # in favour of st.rerun; use whichever this Streamlit provides.
                rerun = getattr(st, "rerun", None) or st.experimental_rerun
                rerun()

        st.download_button(
            "Press to Download as CSV",
            topic_cluster.keyword_df.to_csv(index=False).encode('utf-8'),
            "Clustered.csv",
            "text/csv",
            key='download-csv'
        )

        with st.expander("Download as CSV for Bulk upload in Google Ads"):
            campaign_name = st.text_input("Campaign Name", "Demo Campaign")
            # Build the export once and reuse it for the preview and the download.
            google_ads_df = topic_cluster.get_df_in_google_ads_format(campaign_name)
            st.dataframe(google_ads_df)
            st.download_button(
                "Download as CSV for Bulk upload in Google Ads",
                google_ads_df.to_csv(index=False).encode('utf-8'),
                f"{campaign_name}_keywords_upload.csv",
                "text/csv",
                key='download-google-csv'
            )

    def visualize_clusters(self):
        """Plot the clustered documents on demand.

        Shows a neighborhood-size slider (initialised from ``self.n_neighbors``)
        whose value is forwarded to ``visualize_documents``; the plot is drawn
        when "Visualize Topic Clusters" is pressed. Renders nothing until a
        clustering with a ``topic_model`` is in session state.
        """
        if not self._has_fitted_model():
            return
        min_neighbors, max_neighbors = 2, 100
        # Clamp the default so the slider never receives an out-of-range value.
        default_neighbors = min(max(self.n_neighbors, min_neighbors), max_neighbors)
        self.n_neighbors = st.slider(label='Size of the local neighborhood',
                                     min_value=min_neighbors, max_value=max_neighbors,
                                     value=default_neighbors, step=1)
        if st.button("Visualize Topic Clusters"):
            fig = st.session_state.topic_cluster.visualize_documents(n_neighbors=self.n_neighbors)
            fig.update_layout(title=None)
            st.plotly_chart(fig, use_container_width=True, theme=None)

    def visualize_topic_distribution(self):
        """Plot the topic distribution once a clustering with a ``topic_model`` exists."""
        if not self._has_fitted_model():
            return
        fig = st.session_state.topic_cluster.visualize_topic_distribution()
        st.plotly_chart(fig, use_container_width=True, theme=None)
if __name__ == '__main__':
    # Build the page (page config, header, session-state setup) for this run.
    view = TopicClusterView()

    # A third "Topic Distribution" tab, backed by
    # view.visualize_topic_distribution(), is currently disabled.
    process_tab, visualization_tab = st.tabs(['Clustering Process', 'Cluster Visualization'])

    with process_tab:
        view.input_params()
        view.show_and_download_df()

    with visualization_tab:
        view.visualize_clusters()