Story_teller / app.py
Mr-Vicky-01's picture
Create app.py
6364e8b verified
raw
history blame
2.14 kB
import os
import tempfile

import gradio as gr
import numpy as np
from gtts import gTTS
from IPython.display import Audio
from langchain import LLMChain, PromptTemplate
from langchain.llms import GooglePalm
from PIL import Image
from transformers import BlipForConditionalGeneration, BlipProcessor
# Hugging Face checkpoint shared by the captioning processor and model, so the
# two are always loaded from the same source.
_BLIP_CHECKPOINT = "Salesforce/blip-image-captioning-base"

# Load the BLIP image-captioning components once, at import time; the
# functions below reuse these module-level objects on every request.
processor = BlipProcessor.from_pretrained(_BLIP_CHECKPOINT)
model = BlipForConditionalGeneration.from_pretrained(_BLIP_CHECKPOINT)
def generate_caption_from_image(image_path):
    """Return a caption for the image at *image_path* using the BLIP model.

    Args:
        image_path: Path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        The decoded caption text, with special tokens stripped.
    """
    # Open inside a context manager so the file handle is released once the
    # pixels are converted; ``convert`` returns an independent in-memory image.
    with Image.open(image_path) as img:
        raw_image = img.convert("RGB")
    # Preprocess into PyTorch tensors for the module-level BLIP model.
    inputs = processor(raw_image, return_tensors="pt")
    out = model.generate(**inputs)
    # A single image was passed in, so the first generated sequence is its caption.
    return processor.decode(out[0], skip_special_tokens=True)
# Prompt sent to the LLM; ``{scenario}`` is filled with the image caption.
_STORY_PROMPT_TEMPLATE = """You are a story teller;
You can generate a short story based on a simple narrative, the story should be between 30 and 50 words;
CONTEXT: {scenario}
Story: """


def generate_story_from_caption(caption):
    """Generate a short story (prompted for 30-50 words) from an image caption.

    Args:
        caption: Text describing the scene; used as the prompt's ``scenario``.

    Returns:
        The story text produced by the Google PaLM LLM chain.
    """
    # Key comes from the GOOGLE_API environment variable. When it is unset this
    # is None and key resolution is left to GooglePalm itself
    # (NOTE(review): presumably its own env-var default — confirm).
    api_key = os.getenv("GOOGLE_API")
    prompt = PromptTemplate(
        template=_STORY_PROMPT_TEMPLATE, input_variables=["scenario"]
    )
    llm_chain = LLMChain(
        prompt=prompt,
        llm=GooglePalm(google_api_key=api_key, temperature=0.8),
    )
    return llm_chain.run(caption)
def text_to_speech(text, *, lang="en", output_path="output.mp3"):
    """Synthesize *text* to an MP3 file with gTTS and return the file's path.

    Args:
        text: The text to speak.
        lang: gTTS language code; defaults to ``"en"`` as before.
        output_path: Destination MP3 path. Defaults to ``"output.mp3"``
            (relative to the current working directory), which is overwritten
            on every call that uses the default.

    Returns:
        ``output_path``, the path of the written MP3 file.
    """
    gTTS(text=text, lang=lang).save(output_path)
    return output_path
def generate_story_from_image(image_input):
    """Turn an uploaded image into a narrated short story.

    Pipeline: caption the image, expand the caption into a short story, then
    synthesize the story to speech.

    Args:
        image_input: The value delivered by the ``gr.Image`` input — an
            array accepted by ``Image.fromarray`` (presumably HxWxC uint8;
            confirm for the Gradio version in use), or None if no image was
            uploaded.

    Returns:
        Path of the MP3 file produced by ``text_to_speech``.

    Raises:
        gr.Error: If no image was provided.
    """
    if image_input is None:
        # Show a readable message in the UI instead of a PIL traceback.
        raise gr.Error("Please upload an image.")
    # JPEG cannot store an alpha channel, so normalize to RGB before saving.
    input_image = Image.fromarray(image_input).convert("RGB")
    # A private per-call directory keeps concurrent requests from overwriting a
    # shared file, and removes the intermediate JPEG once captioning is done.
    with tempfile.TemporaryDirectory() as tmp_dir:
        image_path = os.path.join(tmp_dir, "input_image.jpg")
        input_image.save(image_path)
        caption = generate_caption_from_image(image_path)
    story = generate_story_from_caption(caption)
    return text_to_speech(story)
# UI wiring: a single image input feeds the pipeline, whose MP3 path is
# rendered by an audio output component.
inputs = gr.Image(label="Image")
outputs = gr.Audio(label="Story Audio")

demo = gr.Interface(
    fn=generate_story_from_image,
    inputs=inputs,
    outputs=outputs,
    title="Story Teller",
)

# Start the app with debug and share enabled.
demo.launch(debug=True, share=True)