# Scraped Hugging Face Spaces page banner (not part of the program):
# "Spaces: Sleeping"
import gradio as gr | |
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer | |
from PIL import Image | |
# Image-captioning pipeline (ViT encoder + GPT-2 decoder, per the model id).
# NOTE: despite the variable name, this model writes a caption describing the
# picture; it does not transcribe text that appears in the image.
ocr_model = pipeline(
    task="image-to-text",
    model="nlpconnect/vit-gpt2-image-captioning",
)

# Causal language model (and its tokenizer) that turns a caption into a story.
story_model_name = "EleutherAI/gpt-neo-2.7B"
story_tokenizer, story_model = (
    AutoTokenizer.from_pretrained(story_model_name),
    AutoModelForCausalLM.from_pretrained(story_model_name),
)
# Function to extract text description from an image | |
def extract_description(image_array): | |
try: | |
# Convert the NumPy array to a PIL image | |
image = Image.fromarray(image_array) | |
# Use the OCR model to extract a caption/description from the image | |
result = ocr_model(image) | |
return result[0]["generated_text"] | |
except Exception as e: | |
return f"Error extracting description: {e}" | |
# Function to generate a story based on the extracted description | |
def generate_story(description): | |
try: | |
# Format the input prompt for the story | |
prompt = f"Create a creative story based on this description: {description}" | |
# Use the story model to generate text | |
inputs = story_tokenizer(prompt, return_tensors="pt", truncation=True) | |
outputs = story_model.generate(inputs["input_ids"], max_length=300, num_return_sequences=1, temperature=0.8) | |
story = story_tokenizer.decode(outputs[0], skip_special_tokens=True) | |
return story | |
except Exception as e: | |
return f"Error generating story: {e}" | |
# Main function to process the image and generate a story | |
def create_story(image): | |
try: | |
# Step 1: Extract a description from the image | |
description = extract_description(image) | |
if not description or "Error" in description: | |
return description | |
# Step 2: Generate a story from the extracted description | |
story = generate_story(description) | |
# Combine the description and story for the output | |
output = f"π· Extracted Description:\n{description}\n\nπ Generated Story:\n{story}" | |
return output | |
except Exception as e: | |
return f"Error processing the image: {e}" | |
# Gradio interface | |
interface = gr.Interface( | |
fn=create_story, | |
inputs=gr.Image(label="Upload an Image (PNG, JPG, JPEG)"), | |
outputs=gr.Textbox(label="Generated Story"), | |
title="Text-Based Story Creator", | |
description=( | |
"Upload an image, and this app will generate a creative story based on the description of the image. " | |
"It uses advanced AI models for image-to-text conversion and story generation." | |
), | |
allow_flagging="never" | |
) | |
if __name__ == "__main__": | |
interface.launch() | |