import streamlit as st
from transformers import T5ForConditionalGeneration, T5Tokenizer, pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
import nltk
from youtube_transcript_api import YouTubeTranscriptApi
import torch
from textblob import TextBlob
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize

# Download NLTK data
# NLTK 3.9+ looks up the tokenizer and tagger under the new resource names
# 'punkt_tab' and 'averaged_perceptron_tagger_eng'; the legacy names are kept
# so that older NLTK releases continue to work.
for _nltk_resource in (
    'punkt',
    'punkt_tab',
    'averaged_perceptron_tagger',
    'averaged_perceptron_tagger_eng',
    'stopwords',
):
    nltk.download(_nltk_resource)

# Load models and tokenizers
# Streamlit re-executes this script on every widget interaction, so the heavy
# objects are built inside st.cache_resource functions and reused across reruns
# instead of being re-instantiated each time.
# NOTE(review): cached resources are shared between sessions; confirm the
# tokenizers/pipeline are safe to use concurrently for this deployment.
summary_model_name = 'utrobinmv/t5_summary_en_ru_zh_base_2048'


@st.cache_resource
def _load_summary_components(model_name):
    """Load the summarization model and its tokenizer once per server process."""
    model = T5ForConditionalGeneration.from_pretrained(model_name)
    tokenizer = T5Tokenizer.from_pretrained(model_name)
    return model, tokenizer


@st.cache_resource
def _load_tag_components():
    """Load the tag-generation tokenizer and model once per server process."""
    tokenizer = AutoTokenizer.from_pretrained("fabiochiu/t5-base-tag-generation")
    model = AutoModelForSeq2SeqLM.from_pretrained("fabiochiu/t5-base-tag-generation")
    return tokenizer, model


@st.cache_resource
def _load_captioner():
    """Build the image-captioning pipeline once per server process."""
    return pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")


summary_model, summary_tokenizer = _load_summary_components(summary_model_name)
tag_tokenizer, tag_model = _load_tag_components()
captioner = _load_captioner()

# Summarization helper
def summarize_text(text, prefix):
    """Summarize ``text`` with the module-level summarization model.

    Args:
        text: The text to summarize.
        prefix: Task prefix placed in front of ``text`` before tokenization
            (callers in this file pass 'summary: ' or 'summary brief: ').

    Returns:
        The first decoded output sequence, with special tokens removed.
    """
    encoded = summary_tokenizer(prefix + text, return_tensors="pt")
    output_ids = summary_model.generate(**encoded)
    decoded = summary_tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    return decoded[0]

# Function to generate tags
def generate_tags(text):
    """Generate a list of unique tags for ``text`` with the tag model.

    The input is truncated to 256 tokens before generation. The model output
    is a comma-separated string; it is split into individual tags, each tag is
    stripped of surrounding whitespace, and empty entries are dropped.

    Args:
        text: The text to generate tags for.

    Returns:
        A list of distinct tags in the order the model produced them.
    """
    with torch.no_grad():
        inputs = tag_tokenizer(text, max_length=256, truncation=True, return_tensors="pt")
        output = tag_model.generate(**inputs, num_beams=8, do_sample=True, min_length=10, max_length=64, num_return_sequences=1)
    decoded_output = tag_tokenizer.batch_decode(output, skip_special_tokens=True)[0]

    # dict.fromkeys removes duplicates while keeping first-seen order; a plain
    # set() would return the tags in an arbitrary, run-dependent order.
    candidates = (tag.strip() for tag in decoded_output.split(","))
    return list(dict.fromkeys(tag for tag in candidates if tag))

# Function to fetch YouTube transcript
def _extract_video_id(url):
    """Return the video id contained in a YouTube URL (or a bare id unchanged)."""
    from urllib.parse import parse_qs, urlparse

    candidate = url.strip()
    legacy_id = candidate.split('watch?v=')[-1]
    try:
        # Prefix scheme-less input with '//' so the host is parsed as a host.
        parsed = urlparse(candidate if "://" in candidate else "//" + candidate)
        host = (parsed.hostname or "").lower()
    except ValueError:
        return legacy_id

    # Standard watch URLs: take only the 'v' parameter, ignoring '&t=', '&list=', ...
    query_ids = parse_qs(parsed.query).get("v")
    if query_ids:
        return query_ids[0]

    segments = [segment for segment in parsed.path.split("/") if segment]
    # Short links: https://youtu.be/<id>
    if host.endswith("youtu.be") and segments:
        return segments[0]
    # Path-style links: /embed/<id>, /shorts/<id>, /live/<id>, /v/<id>
    if len(segments) >= 2 and segments[0] in ("embed", "shorts", "live", "v"):
        return segments[1]

    # Anything else (e.g. a bare video id) keeps the previous behaviour.
    return legacy_id


def fetch_transcript(url):
    """Fetch the transcript of a YouTube video as a single string.

    Args:
        url: A YouTube URL (watch, youtu.be, embed, shorts, live) or a bare
            video id.

    Returns:
        The transcript entries joined with spaces on success. On failure the
        exception message is returned instead; it is guaranteed to contain the
        word "error" because the caller in this file detects failures by
        searching the returned text for that word.
    """
    try:
        video_id = _extract_video_id(url)
        transcript = YouTubeTranscriptApi.get_transcript(video_id)
        return ' '.join(entry['text'] for entry in transcript)
    except Exception as e:
        message = str(e)
        # Keep the failure marker the caller looks for, without duplicating it.
        if "error" not in message.lower():
            message = f"Error: {message}"
        return message

# Function to extract keywords and generate hashtags
def extract_keywords(content):
    """Collect the noun-tagged tokens of ``content``.

    The text is lowercased, split into sentences, tokenized and POS-tagged;
    every token whose tag starts with 'NN' is kept.

    Args:
        content: The text to scan.

    Returns:
        A list of lowercase tokens in order of appearance (duplicates kept).
    """
    lowered = content.lower()
    return [
        token
        for sentence in nltk.sent_tokenize(lowered)
        for token, pos in nltk.pos_tag(nltk.word_tokenize(sentence))
        if pos.startswith('NN')
    ]

def generate_hashtags(content, max_hashtags=10):
    """Build hashtags from the keywords extracted from ``content``.

    Each keyword is reduced to its letters, digits and underscores, prefixed
    with '#', and kept only if the result is at most 20 characters long.
    Duplicates are skipped so they do not use up the ``max_hashtags`` slots.

    Args:
        content: The text to derive hashtags from.
        max_hashtags: Maximum number of hashtags to return.

    Returns:
        A list of distinct hashtags in order of first appearance, truncated
        to ``max_hashtags`` entries.
    """
    hashtags = []
    seen = set()
    for keyword in extract_keywords(content):
        # Drop punctuation so tokens such as "%" or "ai-driven" do not yield
        # broken hashtags; tokens with nothing left are skipped.
        cleaned = "".join(ch for ch in keyword if ch.isalnum() or ch == "_")
        if not cleaned:
            continue
        hashtag = "#" + cleaned
        if len(hashtag) > 20 or hashtag in seen:
            continue
        seen.add(hashtag)
        hashtags.append(hashtag)
    return hashtags[:max_hashtags]

# Function to extract point of view
def extract_point_of_view(text):
    """Classify the tone of ``text`` as "Positive", "Negative" or "Neutral".

    English stop words are removed before the text is scored with TextBlob.
    Negation words are deliberately kept: removing them would turn a phrase
    such as "not good" into "good" and invert the result.

    Args:
        text: The text to analyze (converted with ``str()`` before use).

    Returns:
        "Positive" if the polarity is above 0.5, "Negative" if it is below
        -0.5, otherwise "Neutral".
    """
    # NOTE(review): TextBlob's default analyzer is expected to account for
    # these negators when scoring; confirm against the TextBlob version in use.
    negations = {"no", "not", "nor", "never", "n't"}
    stop_words = set(stopwords.words('english')) - negations

    words = word_tokenize(str(text))
    filtered_words = [word for word in words if word.casefold() not in stop_words]
    polarity = TextBlob(' '.join(filtered_words)).sentiment.polarity

    if polarity > 0.5:
        return "Positive"
    if polarity < -0.5:
        return "Negative"
    return "Neutral"

# Page title
st.title("Multi-purpose AI App: WAVE_AI")

# One tab per feature, listed in display order
tab1, tab2, tab3, tab4, tab5 = st.tabs(
    [
        "Text Summarization",
        "Text Tag Generation",
        "Image Captioning",
        "YouTube Transcript",
        "LinkedIn Post Analysis",
    ]
)

# Text Summarization Tab: produces two candidate titles from the entered text
with tab1:
    st.header("Summarize Title Maker")

    input_text = st.text_area("Enter the text to summarize:", height=300)

    if st.button("Generate the Title"):
        # Whitespace-only input carries nothing to summarize.
        if input_text and input_text.strip():
            # Report model failures in the page, as the other tabs do,
            # instead of letting the exception abort the script run.
            try:
                title1 = summarize_text(input_text, 'summary: ')
                title2 = summarize_text(input_text, 'summary brief: ')
                st.write("### Title 1")
                st.write(title1)
                st.write("### Title 2")
                st.write(title2)
            except Exception as e:
                st.error(f"An error occurred: {e}")
        else:
            st.warning("Please enter some text to summarize.")

# Tag generation tab: shows the tags produced for the entered text
with tab2:
    st.header("Tag Generation from Text")

    text = st.text_area("Enter the text for tag extraction:", height=200)

    if st.button("Generate Tags"):
        if not text:
            st.warning("Please enter some text to generate tags.")
        else:
            try:
                tags = generate_tags(text)
                st.write("**Generated Tags:**")
                st.write(tags)
            except Exception as exc:
                # Surface the failure in the page rather than aborting the run.
                st.error(f"An error occurred: {exc}")

# Image captioning tab: displays the image at the given URL and its caption
with tab3:
    st.header("Image Captioning Extractor")

    image_url = st.text_input("Enter the URL of the image:")

    if st.button("Analysis Image"):
        if not image_url:
            st.warning("Please give a image url.")
        else:
            try:
                st.image(image_url, caption="Provided Image", use_column_width=True)
                caption = captioner(image_url)
                st.write("**Generated Caption:**")
                # Only the first result's 'generated_text' is displayed.
                st.write(caption[0]['generated_text'])
            except Exception as exc:
                st.error(f"An error occurred: {exc}")

# YouTube transcript tab: fetches and displays the transcript for a video URL
with tab4:
    st.header("YouTube Video Transcript Extractor")

    youtube_url = st.text_input("Enter YouTube URL:")

    if st.button("Get Transcript"):
        if not youtube_url:
            st.warning("Please enter a URL.")
        else:
            transcript = fetch_transcript(youtube_url)
            # fetch_transcript returns the exception text on failure; a result
            # containing the word "error" is treated as a failed fetch.
            if "error" in transcript.lower():
                st.error(f"An error occurred: {transcript}")
            else:
                st.success("Transcript successfully fetched!")
                st.text_area("Transcript", transcript, height=300)

# LinkedIn Post Analysis Tab: tags, two summaries, hashtags and tone for a post
with tab5:
    st.header("LinkedIn Post Analysis AI")

    text = st.text_area("Enter the LinkedIn Post:")

    if st.button("Analyze:"):
        # Whitespace-only input carries nothing to analyze.
        if text and text.strip():
            # Report failures in the page, as the tag and image tabs do,
            # instead of letting the exception abort the script run.
            try:
                # Generate tags
                tags = generate_tags(text)
                st.subheader("The Most Tracked KeyWords:")
                st.write(tags)

                # Generate summaries
                summary1 = summarize_text(text, 'summary: ')
                summary2 = summarize_text(text, 'summary brief: ')
                st.subheader("Summary Title 1:")
                st.write(summary1)
                st.subheader("Summary Title 2:")
                st.write(summary2)

                # Generate hashtags
                hashtags = generate_hashtags(text)
                st.subheader("Generated Hashtags for the Post")
                st.write(hashtags)

                # Extract point of view
                point_of_view = extract_point_of_view(text)
                st.subheader("Tone of the Post:")
                st.write(point_of_view)
            except Exception as e:
                st.error(f"An error occurred: {e}")
        else:
            st.warning("Please enter text to analyze.")