WAVE_AI / app.py
wavesoumen's picture
Update app.py
7c7cb02 verified
raw
history blame
3.57 kB
from urllib.parse import parse_qs, urlparse

import nltk
import streamlit as st
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
from youtube_transcript_api import YouTubeTranscriptApi
# Model / data setup, cached so it runs once per server process.
@st.cache_resource
def _load_models():
    """Download NLTK data and load the captioning and tag-generation models.

    Streamlit re-executes this whole script on every widget interaction, so
    loading the models at module level would rebuild them on every click.
    ``st.cache_resource`` runs this function once and hands the same objects
    back on every later rerun.

    Returns:
        A ``(captioner, tokenizer, model)`` tuple.
    """
    # NOTE(review): nothing else visible in this file uses NLTK; the download
    # is kept so behaviour is unchanged -- confirm whether it is still needed.
    nltk.download('punkt')
    # Image captioning pipeline.
    image_captioner = pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")
    # Seq2seq model and tokenizer used for tag generation.
    tag_tokenizer = AutoTokenizer.from_pretrained("fabiochiu/t5-base-tag-generation")
    tag_model = AutoModelForSeq2SeqLM.from_pretrained("fabiochiu/t5-base-tag-generation")
    return image_captioner, tag_tokenizer, tag_model


# Module-level names used by the UI code below.
captioner, tokenizer, model = _load_models()
# Hosts whose URLs carry the video ID in ``?v=`` or in a /shorts|embed|... path.
_YOUTUBE_DOMAINS = ("youtube.com", "youtube-nocookie.com")
_YOUTUBE_ID_PATH_PREFIXES = ("shorts", "embed", "live", "v")


def _extract_video_id(url):
    """Extract the YouTube video ID from *url*.

    Recognises ``youtube.com/watch?v=ID`` (ignoring extra query parameters
    such as ``&t=30s`` or ``&list=...``), ``youtu.be/ID`` short links, and
    ``youtube.com/{shorts,embed,live,v}/ID`` paths, with or without a scheme.
    Anything else falls back to the text after ``watch?v=`` -- or the whole
    stripped input when that marker is absent -- so a bare video ID passes
    through unchanged.
    """
    url = url.strip()
    # urlparse only populates the host when a scheme or a leading "//" exists.
    parsed = urlparse(url if "://" in url else "//" + url)
    host = (parsed.hostname or "").lower()
    path_parts = [part for part in parsed.path.split("/") if part]

    if host == "youtu.be" and path_parts:
        return path_parts[0]
    if any(host == domain or host.endswith("." + domain) for domain in _YOUTUBE_DOMAINS):
        video_ids = parse_qs(parsed.query).get("v")
        if video_ids:
            return video_ids[0]
        if len(path_parts) >= 2 and path_parts[0] in _YOUTUBE_ID_PATH_PREFIXES:
            return path_parts[1]
    return url.split("watch?v=")[-1]


def fetch_transcript(url, *, raise_errors=False):
    """Fetch a YouTube video's transcript as a single space-joined string.

    Args:
        url: A YouTube video URL (any form accepted by ``_extract_video_id``)
            or a bare video ID.
        raise_errors: When False (the default), any failure is reported by
            returning ``str(exception)`` in place of the transcript. When
            True, the exception propagates, so callers can tell failure from
            success without inspecting the returned text.

    Returns:
        The transcript text, or the error message on failure when
        ``raise_errors`` is False.
    """
    try:
        video_id = _extract_video_id(url)
        transcript = YouTubeTranscriptApi.get_transcript(video_id)
        return ' '.join(entry['text'] for entry in transcript)
    except Exception as e:
        if raise_errors:
            raise
        return str(e)


@st.cache_data(max_entries=64)
def _caption_image(image_url):
    """Return the generated caption for the image at *image_url*.

    The body of every tab runs on each Streamlit rerun, so without caching
    the caption would be recomputed whenever any widget in any tab changes.
    """
    return captioner(image_url)[0]['generated_text']


def _parse_tags(decoded_output):
    """Split comma-separated model output into unique, non-empty tags.

    Duplicates are dropped while keeping first-occurrence order. Iterating a
    ``set`` of strings gives an order that can change between processes.
    """
    stripped = (tag.strip() for tag in decoded_output.split(","))
    return list(dict.fromkeys(tag for tag in stripped if tag))


# Streamlit app title
st.title("Multi-purpose Machine Learning App")
# One tab per feature.
tab1, tab2, tab3 = st.tabs(["Image Captioning", "Text Tag Generation", "YouTube Transcript"])

# Image Captioning Tab
with tab1:
    st.header("Image Captioning")
    image_url = st.text_input("Enter the URL of the image:").strip()
    if image_url:
        try:
            st.image(image_url, caption="Provided Image", use_column_width=True)
            caption = _caption_image(image_url)
            st.write("**Generated Caption:**")
            st.write(caption)
        except Exception as e:
            st.error(f"An error occurred: {e}")

# Text Tag Generation Tab
with tab2:
    st.header("Text Tag Generation")
    text = st.text_area("Enter the text for tag extraction:", height=200)
    if st.button("Generate Tags"):
        if text.strip():
            try:
                # Inputs longer than 512 tokens are truncated.
                inputs = tokenizer([text], max_length=512, truncation=True, return_tensors="pt")
                # do_sample=True: tags can differ from one click to the next.
                output = model.generate(**inputs, num_beams=8, do_sample=True, min_length=10, max_length=64)
                decoded_output = tokenizer.batch_decode(output, skip_special_tokens=True)[0]
                tags = _parse_tags(decoded_output)
                st.write("**Generated Tags:**")
                st.write(tags)
            except Exception as e:
                st.error(f"An error occurred: {e}")
        else:
            st.warning("Please enter some text to generate tags.")

# YouTube Transcript Tab
with tab3:
    st.header("YouTube Video Transcript Extractor")
    youtube_url = st.text_input("Enter YouTube URL:").strip()
    if st.button("Get Transcript"):
        if youtube_url:
            # A real transcript may itself contain "error" (e.g. "terror"), so
            # success is decided by whether an exception was raised, not by
            # searching the returned text.
            try:
                transcript = fetch_transcript(youtube_url, raise_errors=True)
            except Exception as e:
                st.error(f"An error occurred: {e}")
            else:
                st.success("Transcript successfully fetched!")
                st.text_area("Transcript", transcript, height=300)
        else:
            st.warning("Please enter a URL.")