# Storyteller-1 / app.py
import streamlit as st
from transformers import pipeline
import textwrap
import numpy as np
import soundfile as sf
import tempfile
import os
from PIL import Image
import re
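
# Likely runtime dependencies for this app (an assumption inferred from the imports and
# models, for requirements.txt): streamlit, transformers, torch, pillow, numpy, soundfile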
# Initialize pipelines with caching
@st.cache_resource
def load_pipelines():
    captioner = pipeline("image-to-text", model="Salesforce/blip-image-captioning-large")
    storyer = pipeline("text-generation", model="gpt2")
    tts = pipeline("text-to-speech", model="facebook/mms-tts-eng")
    return captioner, storyer, tts
captioner, storyer, tts = load_pipelines()
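
# Optional sketch (an assumption, not part of the original app): each pipeline accepts a
# device index, so the models can be placed on a GPU when one is available, e.g.
#   pipeline("image-to-text", model="Salesforce/blip-image-captioning-large", device=0)
# The default (device=-1) runs everything on CPU.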

# Clean the raw story: keep only tokens made of letters, digits and basic punctuation,
# and drop single-character tokens
def clean_generated_story(raw):
    allowed_pattern = re.compile(r'[a-zA-Z0-9.,!?"\'-]+\b(?<!\b\w\b)')
    return ' '.join(word for word in re.findall(allowed_pattern, raw) if len(word) > 1)
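
# Illustrative behaviour of clean_generated_story (a documentation example only, never
# called by the app): the trailing word boundary strips punctuation that has no word
# character after it, and single-character tokens such as "a" or "7" are discarded.
#   clean_generated_story("Once upon a time, 7 pigs built a house!")
#   # -> 'Once upon time pigs built house'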

# Generate a caption, a story, and narrated audio from an uploaded image
def generate_content(image):
    pil_image = Image.open(image)

    # Generate caption
    caption = captioner(pil_image)[0]["generated_text"]
    st.write("**What's in the picture: 🧐**")
    st.write(caption)

    # Create prompt for story
    prompt = (
        f"Write a funny, interesting children's story centered on this scene: {caption}\n"
        f"Story in third-person narrative, describing this scene exactly: {caption} "
        f"Mention the exact place, location, or venue within {caption}. "
        f"Avoid numbers, random letter combinations, and single-letter words."
    )

    # Generate the raw story; do_sample=True is needed for temperature/top_p to take effect
    raw = storyer(
        prompt,
        max_new_tokens=100,
        do_sample=True,
        temperature=0.6,
        top_p=0.85,
        no_repeat_ngram_size=0,  # 0 leaves the repeated-n-gram constraint disabled
        return_full_text=False,
    )[0]["generated_text"].strip()

    # Clean the story and synthesize the audio narration
    caption, story, audio_path = generate_story(raw, caption, tts)
    return caption, story, audio_path

# Clean and trim the generated text, display it, and synthesize speech for it
def generate_story(raw, caption, tts):
    # Clean and trim the story to at most 100 words
    story = clean_generated_story(raw)
    words = story.split()
    story = " ".join(words[:100])

    # Display story in Streamlit
    st.write("**Your funny story: 📝**")
    st.write(story)

    # Synthesize audio in ~200-character chunks and join the waveforms
    chunks = textwrap.wrap(story, width=200)
    audio = np.concatenate([tts(chunk)["audio"].squeeze() for chunk in chunks])

    # Save audio to a temporary WAV file
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
        sf.write(temp_file.name, audio, tts.model.config.sampling_rate)
        temp_file_path = temp_file.name

    return caption, story, temp_file_path
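
# Note (illustrative alternative, not used above): the text-to-speech pipeline also reports
# the sampling rate in its output, so the file could equivalently be written as:
#   out = tts(story)
#   sf.write(path, out["audio"].squeeze(), out["sampling_rate"])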
# Streamlit UI
st.title("😎 Story Maker")
st.markdown("Upload a picture, I will generate a story for you")
uploaded_image = st.file_uploader("Choose your picture", type=["jpg", "jpeg", "png"])
# Display image
if uploaded_image is None:
    st.image("https://via.placeholder.com/300", caption="Upload your picture here!", use_container_width=True)
else:
    st.image(uploaded_image, caption="Your Picture", use_container_width=True)

if st.button("Generate a story"):
    if uploaded_image is not None:
        with st.spinner("Processing"):
            try:
                caption, story, audio_path = generate_content(uploaded_image)
                st.success("Your story is ready! 😊")
                st.audio(audio_path, format="audio/wav")
                try:
                    os.remove(audio_path)
                except OSError as e:
                    st.warning(f"Failed to delete temporary audio file: {e}")
            except Exception as e:
                st.error(f"Error generating story: {e}")
    else:
        st.warning("Please upload a picture first!")