import pandas as pd
import tweepy
import re
import emoji
import spacy
import gensim
import json
import string

from spacy.tokenizer import Tokenizer
from gensim.parsing.preprocessing import STOPWORDS as SW
from wordcloud import STOPWORDS

from gensim.corpora import Dictionary
from gensim.models.coherencemodel import CoherenceModel
from pprint import pprint

import numpy as np
import tqdm

from gensim.parsing.preprocessing import preprocess_string, strip_punctuation, strip_numeric

import torch
from transformers import T5ForConditionalGeneration,T5Tokenizer
from googletrans import Translator

from bertopic import BERTopic
from umap import UMAP
from sklearn.feature_extraction.text import CountVectorizer

from operator import itemgetter

import gradio as gr

import os

# NOTE: a module-level ``global`` statement is a no-op; ``df`` is bound later by main().
global df

# SECURITY: a live bearer token was committed to this file. Prefer supplying it
# through the TWITTER_BEARER_TOKEN environment variable, and rotate the embedded
# fallback below — anyone with read access to the source can use it.
bearer_token = os.environ.get(
    'TWITTER_BEARER_TOKEN',
    'AAAAAAAAAAAAAAAAAAAAACEigwEAAAAACoP8KHJYLOKCL4OyB9LEPV00VB0%3DmyeDROUvw4uipHwvbPPfnTuY0M9ORrLuXrMvcByqZhwo3SUc4F',
)
client = tweepy.Client(bearer_token=bearer_token)

# Large English spaCy pipeline used for stop-word, lemma and POS information.
nlp = spacy.load('en_core_web_lg')
print('hi')

def scrape(keywords):
    """Fetch up to 100 recent English/Tagalog, non-retweet tweets for *keywords*.

    Args:
        keywords: a query string such as ``'(bgc OR bonifacio global city)'``,
            or a list/tuple of keywords that is OR-ed into such a string.

    Returns:
        The second element of the ``search_recent_tweets`` response (the
        expansion includes, which carry the requested place data).

    Side effects:
        Rebinds the module-level ``df`` to a DataFrame with a single ``tweet``
        column holding ``str()`` of each returned tweet, which is the column
        ``cleaning`` renames and processes.

    Raises:
        ValueError: if an empty list/tuple of keywords is given.
    """
    # Without this declaration the DataFrame was only a local and the scraped
    # tweets never reached main().
    global df

    # main() splits the dataset on commas; concatenating a list with the
    # language filter below would raise TypeError, so build the OR-query here.
    if isinstance(keywords, (list, tuple)):
        if not keywords:
            raise ValueError('scrape() needs at least one keyword')
        if len(keywords) > 1:
            keywords = '(' + ' OR '.join(keywords) + ')'
        else:
            keywords = keywords[0]

    query = keywords + ' (lang:en OR lang:tl) -is:retweet'

    response = client.search_recent_tweets(
        query=query,
        max_results=100,
        tweet_fields=['geo', 'id', 'lang', 'created_at'],
        expansions=['geo.place_id'],
        place_fields=['contained_within', 'country', 'country_code', 'full_name',
                      'geo', 'id', 'name', 'place_type'],
    )

    # response[0] holds the tweets; tweepy leaves it as None when nothing
    # matched, which previously raised TypeError on iteration.
    tweets = [str(tweet) for tweet in (response[0] or [])]
    df = pd.DataFrame(tweets, columns=['tweet'])

    return response[1]

def get_example(dataset):
    """Load a bundled example dataset from ``<dataset>.csv``.

    Args:
        dataset: path of the CSV file, without the ``.csv`` extension.

    Returns:
        pandas.DataFrame with the file's contents.
    """
    csv_path = dataset + '.csv'
    frame = pd.read_csv(csv_path)
    return frame

def give_emoji_free_text(text):
    """
    Removes emoji's from tweets

    Every whitespace-separated word that contains at least one character found
    in ``emoji.EMOJI_DATA`` is dropped; the surviving words are re-joined with
    single spaces.

    Accepts:
        Text (tweets)
    Returns:
        Text (emoji free tweets)
    """
    # A set gives O(1) membership tests; every element is a single character,
    # so "word contains an emoji" is exactly "word shares a character with it".
    emoji_chars = {char for char in text if char in emoji.EMOJI_DATA}
    return ' '.join(word for word in text.split() if emoji_chars.isdisjoint(word))

def url_free_text(text):
    '''
    Return *text* with every URL removed.

    A URL is anything starting with ``http`` up to the next whitespace
    character; it is replaced by the empty string (surrounding spaces stay).
    '''
    url_pattern = r'http\S+'
    return re.sub(url_pattern, '', text)

def get_lemmas(text):
    '''Lemmatize a processed tweet.

    Runs the module-level spaCy pipeline ``nlp`` over *text* and returns the
    lemma of every token that is neither a stop word, nor punctuation, nor
    tagged with the part of speech ``PRON``.
    '''
    return [
        token.lemma_
        for token in nlp(text)
        if not (token.is_stop or token.is_punct or token.pos_ == 'PRON')
    ]

# Tokenizer function
def tokenize(text):
    """
    Parses a string into a list of semantic units (words)

    Cleaning steps, applied in order:
      1. remove URLs (``http`` up to the next whitespace);
      2. remove every character that is not an ASCII letter, a digit or
         whitespace (this also removes all ASCII punctuation);
      3. remove words that contain a digit;
      4. lowercase and split on whitespace.

    Args:
        text (str): The string that the function will tokenize.
    Returns:
        list: tokens parsed out
    """
    # Each step must feed the next. Previously every re.sub re-read ``text``,
    # so only the last substitution had any effect.
    tokens = re.sub(r"http\S+", "", text)  # Remove URLs
    # Whitespace is kept (not just ' ') so words separated by newlines or
    # tabs are not glued together before the final split().
    tokens = re.sub(r"[^a-zA-Z0-9\s]", "", tokens)
    tokens = re.sub(r"\w*\d\w*", "", tokens)  # Remove words containing numbers

    return tokens.lower().split()
    

def _load_stopwords(path='stopwords-tl.json'):
    """Return wordcloud's STOPWORDS plus the Tagalog list in *path* and extra colloquial Tagalog words."""
    # JSON is UTF-8 by specification; don't depend on the platform's locale.
    with open(path, encoding='utf-8') as fh:
        tl_stopwords = json.load(fh)

    stopwords = set(STOPWORDS)
    stopwords.update(tl_stopwords)
    stopwords.update(['na', 'sa', 'ko', 'ako', 'ng', 'mga', 'ba', 'ka', 'yung', 'lang', 'di', 'mo', 'kasi'])
    return stopwords


def cleaning(df):
    """Run the text-preprocessing pipeline on *df* in place (returns None).

    Renames a ``tweet`` column to ``original_tweets`` (pandas ignores the
    rename if that column is absent) and adds the columns
    ``emoji_free_tweets``, ``url_free_tweets``, ``tokens``,
    ``tokens_back_to_text``, ``lemmas``, ``lemmas_back_to_text`` and
    ``lemma_tokens``.

    Stop words removed at the token stage are spaCy's defaults, a handful of
    custom noise tokens, wordcloud's STOPWORDS and the Tagalog lists.
    """
    df.rename(columns={'tweet': 'original_tweets'}, inplace=True)

    df['emoji_free_tweets'] = df['original_tweets'].apply(give_emoji_free_text)
    df['url_free_tweets'] = df['emoji_free_tweets'].apply(url_free_text)

    custom_stopwords = ['hi', '\n', '\n\n', '&', ' ', '.', '-', 'got', "it's", 'it’s', "i'm", 'i’m', 'im',
                        'want', 'like', '$', '@']

    # union() returns a new set, so updating it does not touch spaCy's defaults.
    stop_words = nlp.Defaults.stop_words.union(custom_stopwords)
    stop_words.update(_load_stopwords())

    # A bare Tokenizer built only from the vocab, with no prefix/suffix/infix rules.
    tokenizer = Tokenizer(nlp.vocab)

    tokens = []
    for doc in tokenizer.pipe(df['url_free_tweets'], batch_size=500):
        lowered = (token.text.lower() for token in doc)
        tokens.append([word for word in lowered if word not in stop_words])
    df['tokens'] = tokens

    df['tokens_back_to_text'] = [' '.join(map(str, l)) for l in df['tokens']]

    df['lemmas'] = df['tokens_back_to_text'].apply(get_lemmas)
    df['lemmas_back_to_text'] = [' '.join(map(str, l)) for l in df['lemmas']]

    # Final regex-based clean-up and split of the lemmatised text.
    df['lemma_tokens'] = df['lemmas_back_to_text'].apply(tokenize)

def split_corpus(corpus, n):
    """Lazily yield consecutive slices of *corpus*, each of length *n*.

    The last slice is shorter when ``len(corpus)`` is not a multiple of *n*.
    """
    yield from (corpus[offset:offset + n] for offset in range(0, len(corpus), n))

def compute_coherence_values_base_lda(dictionary, corpus, texts, limit, coherence, start=2, step=1):
    """Train one LDA model per topic count and score each with a coherence measure.

    Args:
        dictionary: gensim ``Dictionary`` used as the model's ``id2word`` and by
            the coherence model.
        corpus: bag-of-words corpus the models are trained on.
        texts: tokenised documents used by the coherence model.
        limit: exclusive upper bound on the number of topics.
        coherence: coherence measure name passed to ``CoherenceModel``.
        start: first number of topics to try.
        step: increment between topic counts.

    Returns:
        (model_list, coherence_values), both aligned with
        ``range(start, limit, step)``.
    """
    coherence_values = []
    model_list = []
    for num_topics in range(start, limit, step):
        model = gensim.models.ldamodel.LdaModel(corpus=corpus,
                                                num_topics=num_topics,
                                                random_state=100,
                                                chunksize=200,
                                                passes=10,
                                                per_word_topics=True,
                                                # Use the argument, not the module-level id2word.
                                                id2word=dictionary)
        model_list.append(model)
        coherencemodel = CoherenceModel(model=model, texts=texts, dictionary=dictionary, coherence=coherence)
        coherence_values.append(coherencemodel.get_coherence())

    return model_list, coherence_values

def base_lda():
    """Choose the number of LDA topics by 5-fold cross-validated coherence.

    Reads the module-level ``df['lemma_tokens']`` and sets the module-level
    ``id2word``, ``corpus``, ``corpus_og``, ``coherence`` and ``num_topics``.
    For each fold, models with 2..9 topics are trained on the documents outside
    that fold. Their coherence values are averaged over the folds, and the topic
    count with the best average becomes ``num_topics``. For 'c_v' that is the
    maximum; for other measures it is the value closest to zero.
    """
    global id2word, corpus, corpus_og, coherence, num_topics

    id2word = Dictionary(df['lemma_tokens'])
    # Drop tokens appearing in fewer than 2 documents or in more than 99% of them.
    id2word.filter_extremes(no_below=2, no_above=.99)

    corpus = [id2word.doc2bow(d) for d in df['lemma_tokens']]
    corpus_og = [id2word.doc2bow(d) for d in df['lemma_tokens']]

    coherence = 'c_v'
    start, limit, step = 2, 10, 1
    n_folds = 5

    # Contiguous fold boundaries; fold k covers corpus[bounds[k]:bounds[k + 1]].
    n_docs = len(corpus)
    bounds = [n_docs * k // n_folds for k in range(n_folds + 1)]

    coherence_totals = [0.0] * len(range(start, limit, step))
    for lo, hi in zip(bounds, bounds[1:]):
        # Slicing builds a new list, so the shared ``corpus`` is never mutated
        # (the old code aliased it and removed documents on every pass).
        training_corpus = corpus[:lo] + corpus[hi:]
        _, coherence_values = compute_coherence_values_base_lda(dictionary=id2word, corpus=training_corpus,
                                                                texts=df['lemma_tokens'],
                                                                start=start,
                                                                limit=limit,
                                                                step=step,
                                                                coherence=coherence)
        for j, value in enumerate(coherence_values):
            coherence_totals[j] += value

    coherence_averages = [total / n_folds for total in coherence_totals]

    if coherence == 'c_v':
        k_max = max(coherence_averages)
    else:
        k_max = min(coherence_averages, key=abs)

    num_topics = start + coherence_averages.index(k_max) * step
    
def compute_coherence_values2(corpus, dictionary, k, a, b):
    """Train one LDA model with the given priors and return its c_v coherence.

    Args:
        corpus: bag-of-words corpus to train on.
        dictionary: gensim ``Dictionary`` used as ``id2word`` and for coherence.
        k: number of topics.
        a: document-topic prior, passed as ``alpha``.
        b: topic-word prior, passed as ``eta``.

    Returns:
        float: c_v coherence computed against the module-level
        ``df['lemma_tokens']``.
    """
    # Use the passed-in dictionary and k rather than module-level globals.
    lda_model = gensim.models.ldamodel.LdaModel(corpus=corpus,
        id2word=dictionary,
        num_topics=k,
        random_state=100,
        chunksize=200,
        passes=10,
        alpha=a,
        eta=b,
        per_word_topics=True)
    coherence_model_lda = CoherenceModel(model=lda_model,
        texts=df['lemma_tokens'],
        dictionary=dictionary,
        coherence='c_v')

    return coherence_model_lda.get_coherence()

def hyperparameter_optimization():
    """Grid-search the LDA ``alpha``/``eta`` priors and train the final model.

    Uses the module-level ``corpus_og``, ``id2word``, ``num_topics`` and ``df``.
    Every (validation set, alpha, beta) combination is scored with c_v
    coherence, and all results are written to ``lda_tuning_results_new.csv``.
    The best combination on the 75% corpus is used to train the module-level
    ``lda_model_final`` on the full corpus.

    Returns:
        float: c_v coherence of ``lda_model_final``.
    """
    global lda_model_final

    alpha = [0.05, 0.1, 0.5, 1, 5, 10]
    beta = [0.05, 0.1, 0.5, 1, 5, 10]

    num_of_docs = len(corpus_og)
    corpus_sets = [gensim.utils.ClippedCorpus(corpus_og, int(num_of_docs * 0.75)),
                   corpus_og]
    corpus_title = ['75% Corpus', '100% Corpus']
    model_results = {'Validation_Set': [],
                     'Alpha': [],
                     'Beta': [],
                     'Coherence': []
                     }

    # One tick per trained model; the context manager closes the bar even on error.
    total_runs = len(corpus_sets) * len(alpha) * len(beta)
    with tqdm.tqdm(total=total_runs) as pbar:
        for corpus_set, title in zip(corpus_sets, corpus_title):
            for a in alpha:
                for b in beta:
                    cv = compute_coherence_values2(corpus=corpus_set,
                                                   dictionary=id2word,
                                                   k=num_topics,
                                                   a=a,
                                                   b=b)
                    model_results['Validation_Set'].append(title)
                    model_results['Alpha'].append(a)
                    model_results['Beta'].append(b)
                    model_results['Coherence'].append(cv)
                    pbar.update(1)

    results_df = pd.DataFrame(model_results)
    results_df.to_csv('lda_tuning_results_new.csv', index=False)

    # Select from the in-memory results instead of re-parsing the CSV, which
    # is not guaranteed to round-trip floats exactly.
    params_df = results_df[results_df.Validation_Set == '75% Corpus'].reset_index()
    # +inf scores are treated as worst so they cannot win idxmax.
    params_df = params_df.replace(np.inf, -np.inf)
    max_params = params_df.loc[params_df['Coherence'].idxmax()]

    lda_model_final = gensim.models.ldamodel.LdaModel(corpus=corpus_og,
        id2word=id2word,
        num_topics=num_topics,
        random_state=100,
        chunksize=200,
        passes=10,
        alpha=max_params['Alpha'],
        eta=max_params['Beta'],
        per_word_topics=True)

    coherence_model_lda = CoherenceModel(model=lda_model_final, texts=df['lemma_tokens'], dictionary=id2word,
                                         coherence='c_v')
    return coherence_model_lda.get_coherence()

def assignMaxTopic(l):
    """Return the topic id of the highest-probability (topic_id, probability) pair.

    Ties resolve to the earliest pair; an empty sequence raises ValueError.
    """
    best_pair = max(l, key=lambda pair: pair[1])
    return best_pair[0]

def assignTopic(l):
    """Return the topic ids (first element of each pair) from (topic_id, probability) pairs."""
    # The list used to be built and then dropped, so the function returned None.
    return [pair[0] for pair in l]

def topic_assignment(df):
    """Attach LDA topic distributions and dominant topics to *df* in place.

    Adds two columns to *df*:
      * ``topic``: the first element of ``lda_model_final``'s output for the
        document (its (topic_id, probability) pairs), sorted;
      * ``max_topic``: the id of the highest-probability pair, or ``None`` for
        documents with no pairs.

    Also sets the module-level ``topic_clusters``. For each id in
    ``range(num_topics)``, it holds the list of ``original_tweets`` whose
    dominant topic is that id.

    Assumes the rows of *df* are aligned with the module-level ``corpus_og``,
    since both are built from ``df['lemma_tokens']``.
    """
    global topic_clusters

    # Build the transformed corpus once instead of once per document.
    transformed = lda_model_final[corpus_og]
    df['topic'] = [sorted(transformed[i][0]) for i in range(len(df['original_tweets']))]

    # Written onto the caller's frame (not onto a filtered copy, as before) so
    # that reprsentative_tweets() can read ``max_topic`` from the global df.
    df['max_topic'] = [assignMaxTopic(row) if len(row) > 0 else None for row in df['topic']]

    topic_clusters = [
        df.loc[df['max_topic'] == i, 'original_tweets'].tolist()
        for i in range(num_topics)
    ]
    
def get_topic_value(row, i):
    """Return the probability of topic *i* from a list of (topic_id, probability) pairs.

    A single-pair row returns that pair's probability regardless of *i*.
    Otherwise the pair whose topic id equals *i* is looked up. If none
    matches, a message is printed and None is returned.
    """
    if len(row) == 1:
        return row[0][1]

    # Look up by topic id, not by list position: rows omit low-probability
    # topics, so position i is generally not topic i.
    for topic_id, probability in row:
        if topic_id == i:
            return probability

    print(f'Topic {i} not found in {row}')
    return None

def reprsentative_tweets():
    """Collect up to five representative tweets per LDA topic.

    For each topic index ``i`` (one per entry of the module-level
    ``topic_clusters``), the function takes the rows of the module-level ``df``
    whose ``max_topic`` is ``i`` and ranks them by their probability for topic
    ``i``, highest first. It then keeps the first five distinct
    ``original_tweets``.

    Sets and returns the module-level ``top_tweets`` (a list of lists of str).
    """
    global top_tweets
    top_tweets = []
    for i in range(len(topic_clusters)):
        # .copy() so the new column is written to an independent frame rather
        # than a slice of df (avoids pandas' chained-assignment pitfall).
        tweets = df.loc[df['max_topic'] == i].copy()
        tweets['topic'] = tweets['topic'].apply(lambda row: get_topic_value(row, i))
        tweets_sorted = tweets.sort_values('topic', ascending=False)
        # drop_duplicates keeps the first (highest-ranked) occurrence. The old
        # code discarded this result and then deduplicated with set(), which
        # threw away the ranking.
        rep_tweets = tweets_sorted.drop_duplicates(subset=['original_tweets'])['original_tweets'].tolist()
        top_tweets.append(rep_tweets[:5])
    return top_tweets

def topic_summarization(topic_groups):
    """Generate one headline per group of tweets.

    Each group is joined with spaces, translated to English with googletrans,
    and summarised with the ``Michau/t5-base-en-generate-headline`` T5 model
    using beam search (3 beams, at most 64 tokens).

    Args:
        topic_groups: sequence of lists of tweet strings, one list per topic.

    Returns:
        list[str]: one ``"Topic <i> <headline>"`` entry per group.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = T5ForConditionalGeneration.from_pretrained("Michau/t5-base-en-generate-headline")
    tokenizer = T5Tokenizer.from_pretrained("Michau/t5-base-en-generate-headline")
    model = model.to(device)
    translator = Translator()

    headlines = []
    for i, group in enumerate(topic_groups):
        text = translator.translate(" ".join(group), dest='en').text

        encoding = tokenizer.encode_plus(text, return_tensors="pt")
        input_ids = encoding["input_ids"].to(device)
        attention_masks = encoding["attention_mask"].to(device)

        beam_outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_masks,
            max_length=64,
            num_beams=3,
            early_stopping=True,
        )

        # skip_special_tokens drops the <pad> / </s> markers from the headline.
        result = tokenizer.decode(beam_outputs[0], skip_special_tokens=True)
        # append(), not +=: list += str extended the list one character at a time.
        headlines.append("Topic " + str(i) + " " + result)

    return headlines

def compute_coherence_value_bertopic(topic_model, coherence='c_v'):
    """Score a fitted BERTopic model with a gensim coherence measure.

    Uses the module-level ``df['lemma_tokens']``, ``corpus`` and ``id2word``.
    The outlier topic ``-1`` is skipped. Topic words missing from ``id2word``
    (e.g. multi-word n-grams from the vectorizer) are dropped, because the
    coherence model can only score tokens in its dictionary.

    Args:
        topic_model: a fitted ``BERTopic`` instance.
        coherence: coherence measure name passed to ``CoherenceModel``.

    Returns:
        float: the coherence score.
    """
    # Topic ids come from the model itself. The old code read an undefined
    # global ``topics`` and so always failed with NameError.
    topic_words = []
    for topic_id, word_scores in sorted(topic_model.get_topics().items()):
        if topic_id == -1:
            continue
        words = [word for word, _ in word_scores if word in id2word.token2id]
        if words:
            topic_words.append(words)

    coherence_model = CoherenceModel(topics=topic_words,
                                     texts=df['lemma_tokens'],
                                     corpus=corpus,
                                     dictionary=id2word,
                                     coherence=coherence)
    return coherence_model.get_coherence()

def base_bertopic():
    """Fit a baseline BERTopic model and print its coherence.

    Adds ``df['lemma_tokens_string']`` (space-joined ``lemma_tokens``) to the
    module-level ``df``. Also sets the module-level ``id2word``, ``corpus`` and
    ``umap_model``; ``optimized_bertopic`` reuses ``umap_model``.
    """
    global id2word, corpus, umap_model

    df['lemma_tokens_string'] = df['lemma_tokens'].apply(' '.join)

    id2word = Dictionary(df['lemma_tokens'])
    corpus = [id2word.doc2bow(d) for d in df['lemma_tokens']]

    umap_model = UMAP(n_neighbors=15,
                      n_components=5,
                      min_dist=0.0,
                      metric='cosine',
                      random_state=100)

    base_topic_model = BERTopic(umap_model=umap_model, language="english", calculate_probabilities=True)
    base_topic_model.fit_transform(df['lemma_tokens_string'])

    # Scoring is best-effort: report the failure and carry on. Unlike a bare
    # ``except:``, this does not swallow KeyboardInterrupt/SystemExit.
    try:
        print(compute_coherence_value_bertopic(base_topic_model))
    except Exception as exc:
        print('Unable to generate meaningful topics (Base BERTopic model)')
        print(f'  reason: {exc!r}')

def optimized_bertopic():
    """Fit a tuned BERTopic model and collect each topic's representative tweets.

    Reuses the module-level ``umap_model`` and ``df['lemma_tokens_string']``
    that ``base_bertopic`` creates. It prints the model's coherence (best
    effort) and sets the module-level ``top_tweets``. For every
    non-outlier topic, ``top_tweets`` holds the ``original_tweets`` of the
    model's representative documents.
    """
    global top_tweets

    # BERTopic ignores ``n_gram_range`` when a ``vectorizer_model`` is passed,
    # so the (1, 3) range must be configured on the vectorizer itself.
    vectorizer_model = CountVectorizer(max_features=1_000, stop_words="english", ngram_range=(1, 3))
    optimized_topic_model = BERTopic(umap_model=umap_model,
            language="multilingual",
            n_gram_range=(1, 3),
            vectorizer_model=vectorizer_model,
            calculate_probabilities=True)

    optimized_topic_model.fit_transform(df['lemma_tokens_string'])

    try:
        print(compute_coherence_value_bertopic(optimized_topic_model))
    except Exception as exc:
        print('Unable to generate meaningful topics, base BERTopic model if possible')
        print(f'  reason: {exc!r}')

    rep_docs = optimized_topic_model.representative_docs_

    # Map each preprocessed document back to the first row that produced it.
    # This is one pass over df instead of a full-frame isin() scan per document.
    # That scan also silently returned row 0 when nothing matched.
    row_for_doc = {}
    for idx, doc in df['lemma_tokens_string'].items():
        row_for_doc.setdefault(doc, idx)

    top_tweets = []
    for topic, topic_docs in rep_docs.items():
        if topic == -1:  # outlier topic
            continue
        tweets = [df.loc[row_for_doc[doc], 'original_tweets'] for doc in topic_docs if doc in row_for_doc]
        top_tweets.append(tweets)

# NOTE: a module-level ``global`` statement is a no-op; main() keeps its own local list.
global examples


def main(dataset, model):
    """Gradio entry point: load or scrape tweets, model topics, and write headlines.

    Args:
        dataset: comma-separated keywords, e.g. ``"bgc,bonifacio global city"``.
            Values from the built-in example list are loaded with
            ``get_example``; anything else is scraped live via ``scrape``.
        model: ``'LDA'`` selects the gensim LDA pipeline; any other value
            selects BERTopic.

    Returns:
        (place_data, headlines): ``str`` of the scrape's place data (``'test'``
        for example datasets), and the topic headlines joined by newlines.
    """
    # top_tweets is global because optimized_bertopic() publishes its result
    # there. Previously the LDA-branch assignment made the name local to main,
    # so the BERTopic branch raised UnboundLocalError.
    global df, top_tweets

    examples = ["katip,katipunan",
                "bgc,bonifacio global city",
                "pobla,poblacion",
                "cubao",
                "taft"
                ]
    keyword_list = dataset.split(',')
    if len(keyword_list) > 1:
        keywords = '(' + ' OR '.join(keyword_list) + ')'
    else:
        keywords = keyword_list[0]

    if dataset in examples:
        df = get_example(keywords)
        place_data = 'test'
    else:
        print(dataset)
        # Pass the assembled query string: scrape() appends the language
        # filter to it with ``+``, which fails for a list.
        place_data = str(scrape(keywords))

    print(df)
    cleaning(df)
    print(df)

    if model == 'LDA':
        base_lda()
        hyperparameter_optimization()  # returns the final model's coherence; unused here
        topic_assignment(df)
        top_tweets = reprsentative_tweets()
    else:
        base_bertopic()
        optimized_bertopic()  # sets the module-level top_tweets

    headlines = topic_summarization(top_tweets)
    headlines = '\n'.join(str(h) for h in headlines)

    return place_data, headlines


# Gradio UI: pick a bundled keyword set and a topic model; main() returns two
# text outputs (place data and headlines).
iface = gr.Interface(
    fn=main,
    inputs=[
        gr.Dropdown(["katip,katipunan",
                     "bgc,bonifacio global city",
                     "cubao",
                     "taft",
                     "pobla,poblacion"],
                    label="Dataset"),
        gr.Dropdown(["LDA",
                     "BERTopic"],
                    label="Model"),
    ],
    outputs=["text",
             "text"],
)

# Launch only when executed as a script, so importing this module (tests,
# Gradio's reload mode) does not start a web server as a side effect.
if __name__ == "__main__":
    iface.launch()