# tryAI / app.py — Hugging Face Space file by wavesoumen
# (last commit ae83552 "Update app.py", 2.4 kB)
import streamlit as st
import torch
import requests
from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import nltk
@st.cache_resource
def _ensure_nltk_data():
    """Download NLTK's ``punkt`` data once per server process.

    Streamlit re-executes this whole script on every widget interaction, so a
    bare module-level ``nltk.download`` call would re-run the downloader check
    (and print its status) on each rerun. Caching it as a resource makes the
    call happen only once.

    Returns:
        Whatever ``nltk.download`` returns for the ``punkt`` package.
    """
    # NOTE(review): nothing in the visible code uses nltk; confirm whether
    # this download is still needed at all.
    return nltk.download('punkt')


_ensure_nltk_data()
@st.cache_resource
def load_models():
    """Load the captioning and tag-generation models.

    Decorated with ``st.cache_resource`` so the loaded objects are reused
    across Streamlit reruns instead of being rebuilt each time.

    Returns:
        A 4-tuple ``(processor, model, tokenizer, model2)``: the BLIP processor
        and BLIP captioning model, followed by the T5 tokenizer and seq2seq
        model used for tag generation.
    """
    caption_checkpoint = "Salesforce/blip-image-captioning-base"
    tagger_checkpoint = "fabiochiu/t5-base-tag-generation"

    blip_processor = BlipProcessor.from_pretrained(caption_checkpoint)
    blip_model = BlipForConditionalGeneration.from_pretrained(caption_checkpoint)
    tag_tokenizer = AutoTokenizer.from_pretrained(tagger_checkpoint)
    tag_model = AutoModelForSeq2SeqLM.from_pretrained(tagger_checkpoint)

    return blip_processor, blip_model, tag_tokenizer, tag_model
# Load the models at import time (reused via st.cache_resource on reruns).
# These module-level names are read by get_image_caption_and_tags.
processor, model, tokenizer, model2 = load_models()
def get_image_caption_and_tags(img_url, timeout=10):
    """Fetch an image, caption it with BLIP, and generate tags from the captions.

    Args:
        img_url: HTTP(S) URL of the image to analyse.
        timeout: Seconds to wait for the image server before giving up.

    Returns:
        A tuple ``(raw_image, conditional_caption, unconditional_caption, tags)``
        where ``raw_image`` is an RGB ``PIL.Image``, both captions are strings,
        and ``tags`` is a list of unique tag strings in generation order.

    Raises:
        ValueError: If ``img_url`` is empty or blank.
        requests.RequestException: If the download fails, times out, or the
            server returns an HTTP error status.
        PIL.UnidentifiedImageError: If the response body is not an image.
    """
    if not img_url or not img_url.strip():
        raise ValueError("Please provide an image URL.")

    raw_image = _fetch_rgb_image(img_url.strip(), timeout)

    # Conditional captioning: BLIP continues the given text prompt.
    conditional_caption = _caption(raw_image, "a photography of")
    # Unconditional captioning: BLIP describes the image with no prompt.
    unconditional_caption = _caption(raw_image)

    # Feed the captions (not the fixed prompt) to the tagger so the tags
    # actually describe this image.
    tag_source = f"{conditional_caption}. {unconditional_caption}"
    tags = _generate_tags(tag_source)

    return raw_image, conditional_caption, unconditional_caption, tags


def _fetch_rgb_image(img_url, timeout):
    """Download ``img_url`` and decode it as an RGB PIL image."""
    import io

    with requests.get(img_url, timeout=timeout) as response:
        # Fail with a clear HTTP error instead of letting PIL choke on an
        # HTML error page.
        response.raise_for_status()
        # ``response.content`` is the content-decoded body; ``response.raw``
        # would hand PIL still-compressed bytes for gzip/deflate responses.
        image = Image.open(io.BytesIO(response.content))
    return image.convert('RGB')


def _caption(raw_image, prompt=None):
    """Return BLIP's caption for ``raw_image``, optionally continuing ``prompt``."""
    if prompt is None:
        inputs = processor(raw_image, return_tensors="pt")
    else:
        inputs = processor(raw_image, prompt, return_tensors="pt")
    out = model.generate(**inputs)
    return processor.decode(out[0], skip_special_tokens=True)


def _generate_tags(text):
    """Generate a de-duplicated list of tags for ``text`` with the T5 tagger."""
    inputs = tokenizer([text], max_length=512, truncation=True, return_tensors="pt")
    # do_sample=True means the tags can vary between calls for the same text.
    output = model2.generate(**inputs, num_beams=8, do_sample=True, min_length=10, max_length=64)
    decoded_output = tokenizer.batch_decode(output, skip_special_tokens=True)[0]
    # dict.fromkeys de-duplicates while keeping generation order (a set would
    # shuffle tags between runs); blank fragments are dropped.
    candidates = (tag.strip() for tag in decoded_output.split(","))
    return list(dict.fromkeys(tag for tag in candidates if tag))
# --- Streamlit UI: read a URL, run captioning + tagging, show the results ---
st.title('Image Captioning and Tag Generation')
img_url = st.text_input("Enter Image URL:")
if st.button("Generate Captions and Tags"):
    url = img_url.strip()
    if not url:
        # Give a friendly prompt instead of a low-level requests error when
        # the button is pressed with an empty input.
        st.warning("Please enter an image URL first.")
    else:
        with st.spinner('Processing...'):
            try:
                image, cond_caption, uncond_caption, tags = get_image_caption_and_tags(url)
            except Exception as e:
                # Top-level UI boundary: surface any download/decode/model
                # failure to the user instead of crashing the page.
                st.error(f"An error occurred: {e}")
            else:
                st.image(image, caption='Input Image', use_column_width=True)
                st.subheader("Conditional Caption:")
                st.write(cond_caption)
                st.subheader("Unconditional Caption:")
                st.write(uncond_caption)
                st.subheader("Generated Tags:")
                st.write(", ".join(tags))