File size: 3,641 Bytes
348f72a
 
d4ff40b
 
 
 
 
 
 
c7bc4b8
348f72a
d4ff40b
 
 
dad8282
2b97f42
dad8282
d4ff40b
348f72a
d4ff40b
348f72a
c7bc4b8
 
 
 
 
d4ff40b
 
 
 
 
 
6bfa9a4
d4ff40b
348f72a
c7bc4b8
 
 
 
 
 
348f72a
c7bc4b8
 
 
 
 
 
 
 
 
8d03f81
c7bc4b8
 
 
8d03f81
c7bc4b8
be244f0
c7bc4b8
 
 
be244f0
 
 
971e74b
be244f0
d4ff40b
 
 
 
348f72a
d4ff40b
 
 
 
 
 
 
 
c7bc4b8
cd9cb87
d4ff40b
 
 
c7bc4b8
d4ff40b
c7bc4b8
d4ff40b
c7bc4b8
d4ff40b
cd9cb87
d4ff40b
cd9cb87
c7bc4b8
 
 
 
 
 
 
 
 
 
d4ff40b
c7bc4b8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import streamlit as st
from transformers import pipeline
import textwrap
import numpy as np
import soundfile as sf
import tempfile
import os
from PIL import Image
import string
import re

# Initialize pipelines with caching
@st.cache_resource
def load_pipelines():
    """Create the three model pipelines the app relies on.

    The function is wrapped in ``st.cache_resource`` so the pipelines are
    cached instead of being rebuilt on every call.

    Returns:
        A ``(captioner, storyer, tts)`` tuple: an image-to-text pipeline, a
        text-generation pipeline and a text-to-speech pipeline, in that order.
    """
    return (
        pipeline("image-to-text", model="Salesforce/blip-image-captioning-large"),
        pipeline("text-generation", model="gpt2"),
        pipeline("text-to-speech", model="facebook/mms-tts-eng"),
    )

# Module-level handles to the captioning, story-generation and text-to-speech
# pipelines; load_pipelines() is resource-cached by Streamlit.
captioner, storyer, tts = load_pipelines()

# Function to clean raw story
def clean_generated_story(raw):
    """Return *raw* reduced to readable word tokens joined by single spaces.

    Tokens are maximal runs of ASCII letters, digits and the punctuation
    ``. , ! ? " ' -``; every other character (whitespace, symbols, non-ASCII
    text) acts as a separator. A token is kept only when it contains at least
    two alphanumeric characters. That discards single-letter words, lone
    digits and punctuation-only noise, while contractions ("isn't") and
    sentence punctuation attached to a word ("dog,") are preserved.

    Args:
        raw: Text produced by the story generator.

    Returns:
        The cleaned story as a single string; empty if no token qualifies.
    """
    token_pattern = re.compile(r'[a-zA-Z0-9.,!?"\'-]+')
    kept_tokens = [
        token
        for token in token_pattern.findall(raw)
        # Count letters/digits rather than raw length so that "a," or "--"
        # are still treated as noise, but "it?" and "I'm" are real words.
        if sum(ch.isalnum() for ch in token) > 1
    ]
    return " ".join(kept_tokens)

# Function to generate content from an image
def generate_content(image):
    """Caption an image, generate a story from the caption, and voice it.

    The caption is written to the Streamlit page here; the story display and
    audio synthesis are delegated to ``generate_story``.

    Args:
        image: Anything ``PIL.Image.open`` accepts (the UI passes the
            uploaded file object).

    Returns:
        The ``(caption, story, audio_path)`` tuple produced by
        ``generate_story``.
    """
    picture = Image.open(image)

    # Describe the picture and show the description to the user.
    caption_results = captioner(picture)
    caption = caption_results[0]["generated_text"]
    st.write("**What's in the picture: 🧐**")
    st.write(caption)

    # Assemble the story prompt around the caption.
    prompt_parts = [
        f"Write a funny, interesting children's story centered on this scene: {caption}\n",
        f"Story in third-person narrative, describing this scene exactly: {caption} ",
        f"Mention the exact place, location, or venue within {caption}. ",
        "Avoid numbers, random letter combinations, and single-letter words.",
    ]
    prompt = "".join(prompt_parts)

    # Ask the text generator for the continuation only (prompt excluded).
    generation_options = {
        "max_new_tokens": 100,
        "temperature": 0.6,
        "top_p": 0.85,
        "no_repeat_ngram_size": 0,
        "return_full_text": False,
    }
    generated = storyer(prompt, **generation_options)
    raw = generated[0]["generated_text"].strip()

    # Clean, display and synthesise the story.
    scene, story, audio_path = generate_story(raw, caption, tts)
    return scene, story, audio_path

# Function to generate story and audio
def generate_story(raw, caption, tts):
    """Clean the raw generated text, display it, and synthesise it to a WAV file.

    Args:
        raw: Unprocessed text returned by the story-generation pipeline.
        caption: Image caption; returned unchanged alongside the story.
        tts: Text-to-speech pipeline. It is called once per text chunk and its
            result's ``"audio"`` array is used; ``tts.model.config.sampling_rate``
            is written as the WAV sample rate.

    Returns:
        A ``(caption, story, audio_path)`` tuple. ``audio_path`` names a
        temporary ``.wav`` file created with ``delete=False``, so the caller
        is responsible for removing it.

    Raises:
        ValueError: If nothing is left of the story after cleaning (there
            would be no audio to synthesise).
    """
    # Clean the text and cap it at 100 words.
    story = " ".join(clean_generated_story(raw).split()[:100])
    if not story:
        # textwrap.wrap("") yields no chunks and np.concatenate([]) would fail
        # with an unhelpful message, so report the real problem instead.
        raise ValueError(
            "The generated story was empty after cleaning; please try again."
        )

    # Display story in Streamlit. No space after the opening "**", otherwise
    # Markdown does not render the heading in bold.
    st.write("**Your funny story: 📝**")
    st.write(story)

    # Synthesise in chunks of at most 200 characters and stitch the audio
    # segments back together.
    chunks = textwrap.wrap(story, width=200)
    audio = np.concatenate([tts(chunk)["audio"].squeeze() for chunk in chunks])

    # Reserve a temp file name, then write once the handle is closed so the
    # audio writer is the only one holding the file open.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
        temp_file_path = temp_file.name
    try:
        sf.write(temp_file_path, audio, tts.model.config.sampling_rate)
    except Exception:
        # Do not leave an orphaned temp file behind when writing fails.
        try:
            os.remove(temp_file_path)
        except OSError:
            pass
        raise

    return caption, story, temp_file_path

# Streamlit UI
st.title("😎 Story Maker")
st.markdown("Upload a picture, I will generate a story for you")

uploaded_image = st.file_uploader("Choose your picture", type=["jpg", "jpeg", "png"])
has_upload = uploaded_image is not None

# Preview: the user's picture when there is one, a placeholder otherwise.
if has_upload:
    st.image(uploaded_image, caption="Your Picture", use_container_width=True)
else:
    st.image("https://via.placeholder.com/300", caption="Upload your picture here!", use_container_width=True)

generate_clicked = st.button("Generate a story")

if generate_clicked and not has_upload:
    st.warning("Please upload a picture first!")
elif generate_clicked:
    with st.spinner("Processing"):
        try:
            caption, story, audio_path = generate_content(uploaded_image)
            st.success("Your story is ready! 😊")
            st.audio(audio_path, format="audio/wav")
            # The audio has been handed to Streamlit; try to drop the temp
            # file and only warn if that fails.
            try:
                os.remove(audio_path)
            except OSError as e:
                st.warning(f"Failed to delete temporary audio file: {e}")
        except Exception as e:
            # Top-level boundary: surface any failure to the user.
            st.error(f"Error generating story: {e}")