# WAVE_AI / app.py
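"""Streamlit demo app with three tabs: BLIP image captioning from an image URL,
T5-based tag generation for free text, and YouTube transcript extraction."""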
import streamlit as st
import requests
from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration, AutoTokenizer, AutoModelForSeq2SeqLM
import nltk
from youtube_transcript_api import YouTubeTranscriptApi
# Download NLTK data
nltk.download('punkt')
# Initialize the image captioning processor and model
caption_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
caption_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
# Initialize the tokenizer and model for tag generation
tag_tokenizer = AutoTokenizer.from_pretrained("fabiochiu/t5-base-tag-generation")
tag_model = AutoModelForSeq2SeqLM.from_pretrained("fabiochiu/t5-base-tag-generation")
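# Note: Streamlit re-runs this whole script on every interaction, so the model loads
# above are repeated each time. A minimal sketch of loading a model only once, assuming
# a Streamlit version that provides st.cache_resource:
#
#     @st.cache_resource
#     def load_caption_model():
#         processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
#         model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
#         return processor, model
#
#     caption_processor, caption_model = load_caption_model()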
# Function to generate captions for an image
def generate_caption(img_url, text="a photography of"):
    try:
        raw_image = Image.open(requests.get(img_url, stream=True).raw).convert('RGB')
    except Exception as e:
        st.error(f"Error loading image: {e}")
        return None, None

    # Conditional image captioning
    inputs_conditional = caption_processor(raw_image, text, return_tensors="pt")
    out_conditional = caption_model.generate(**inputs_conditional)
    caption_conditional = caption_processor.decode(out_conditional[0], skip_special_tokens=True)

    # Unconditional image captioning
    inputs_unconditional = caption_processor(raw_image, return_tensors="pt")
    out_unconditional = caption_model.generate(**inputs_unconditional)
    caption_unconditional = caption_processor.decode(out_unconditional[0], skip_special_tokens=True)

    return caption_conditional, caption_unconditional
# Function to fetch YouTube transcript
def fetch_transcript(url):
    # Extract the video id from a ".../watch?v=<id>" URL, dropping any extra query parameters
    video_id = url.split('watch?v=')[-1].split('&')[0]
    try:
        transcript = YouTubeTranscriptApi.get_transcript(video_id)
        return ' '.join(entry['text'] for entry in transcript)
    except Exception as e:
        st.error(f"Error fetching transcript: {e}")
        return None
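# Note: the split above only handles ".../watch?v=<id>" URLs. A minimal sketch that
# also accepts youtu.be short links, assuming only the standard-library urllib.parse:
#
#     from urllib.parse import urlparse, parse_qs
#
#     def extract_video_id(url):
#         parsed = urlparse(url)
#         if parsed.hostname == "youtu.be":
#             return parsed.path.lstrip("/")
#         return parse_qs(parsed.query).get("v", [""])[0]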
# Streamlit app title
st.title("Multi-purpose Machine Learning App")
# Create tabs for different functionalities
tab1, tab2, tab3 = st.tabs(["Image Captioning", "Text Tag Generation", "YouTube Transcript"])
# Image Captioning Tab
with tab1:
    st.header("Image Captioning")

    # Input for image URL
    img_url = st.text_input("Enter Image URL:")

    # If an image URL is provided
    if st.button("Generate Captions", key='caption_button'):
        if img_url:
            caption_conditional, caption_unconditional = generate_caption(img_url)
            if caption_conditional and caption_unconditional:
                st.success("Captions successfully generated!")
                st.image(img_url, caption="Input Image", use_column_width=True)
                st.write("### Conditional Caption")
                st.write(caption_conditional)
                st.write("### Unconditional Caption")
                st.write(caption_unconditional)
        else:
            st.warning("Please enter an image URL.")
# Text Tag Generation Tab
with tab2:
    st.header("Text Tag Generation")

    # Text area for user input
    text = st.text_area("Enter the text for tag extraction:", height=200)

    # Button to generate tags
    if st.button("Generate Tags", key='tag_button'):
        if text:
            try:
                # Tokenize and encode the input text
                inputs = tag_tokenizer([text], max_length=512, truncation=True, return_tensors="pt")

                # Generate tags
                output = tag_model.generate(**inputs, num_beams=8, do_sample=True, min_length=10, max_length=64)

                # Decode the output
                decoded_output = tag_tokenizer.batch_decode(output, skip_special_tokens=True)[0]

                # Extract unique tags
                tags = list(set(decoded_output.strip().split(", ")))

                # Display the tags
                st.write("**Generated Tags:**")
                st.write(tags)
            except Exception as e:
                st.error(f"An error occurred: {e}")
        else:
            st.warning("Please enter some text to generate tags.")
# YouTube Transcript Tab
with tab3:
    st.header("YouTube Video Transcript Extractor")

    # Input for YouTube URL
    youtube_url = st.text_input("Enter YouTube URL:")

    # Button to get transcript
    if st.button("Get Transcript", key='transcript_button'):
        if youtube_url:
            transcript = fetch_transcript(youtube_url)
            if transcript:
                st.success("Transcript successfully fetched!")
                st.text_area("Transcript", transcript, height=300)
        else:
            st.warning("Please enter a URL.")