# awacke1's picture
# Update app.py
# f873e60 verified
# raw
# history blame
# 5.69 kB
import gradio as gr
import torch
from transformers import pipeline
import os
# --- App Configuration ---
# Text used for the page heading.
TITLE = "✍️ AI Story Outliner"

# Markdown introduction rendered under the heading.
DESCRIPTION = """
Enter a prompt and get 10 unique story outlines from a CPU-friendly AI model.
The app uses **TinyLlama-1.1B** to generate creative outlines formatted in Markdown.
**How it works:**
1. Enter your story idea.
2. The AI will generate 10 different story outlines.
3. Each outline has a dramatic beginning and is concise, like a song.
"""

# --- Example Prompts for Storytelling ---
# One single-element row per example; the element is the prompt text.
examples = [
    [
        "The old lighthouse keeper stared into the storm. He'd seen many tempests, but this one was different. This one had eyes..."
    ],
    [
        "In a city powered by dreams, a young inventor creates a machine that can record them. His first recording reveals a nightmare that doesn't belong to him."
    ],
    [
        "The knight adjusted his helmet, the dragon's roar echoing in the valley. He was ready for the fight, but not for what the dragon said when it finally spoke."
    ],
    [
        "She found the old leather-bound journal in her grandfather's attic. The first entry read: 'To relieve stress, I walk in the woods. But today, the woods walked with me.'"
    ],
    [
        "The meditation app promised to help her 'delete unhelpful thoughts.' She tapped the button, and to her horror, the memory of her own name began to fade..."
    ],
]
# --- Model Initialization ---
# This section loads a smaller, CPU-friendly model once, at import time.
# If loading fails for any reason, `generator` is bound to a stand-in with the
# same call shape so the app can still launch and show the failure to the user.
try:
    print("Initializing model... This may take a moment.")
    # Load the token from environment variables if it exists (for HF Spaces secrets)
    hf_token = os.environ.get("HF_TOKEN", None)
    # Using a smaller model that is more suitable for running without a high-end GPU.
    generator = pipeline(
        "text-generation",
        model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        torch_dtype=torch.bfloat16,  # More efficient dtype
        device_map="auto",  # Will use GPU if available, otherwise CPU
        token=hf_token,
    )
    print("✅ TinyLlama model loaded successfully!")
except Exception as e:
    print("--- 🚨 Error loading models ---")
    print(f"Error: {e}")
    # Build the message now. Python unbinds `e` as soon as this except block
    # ends, so a fallback that looked `e` up lazily would raise NameError when
    # it is eventually called instead of returning the message.
    load_error_message = (
        f"Model failed to load. Please check the console for errors. Error: {e}"
    )

    def failed_generator(prompt, **kwargs):
        """Stand-in for the pipeline: ignore the arguments and report the load error.

        Returns a one-element list of ``{'generated_text': ...}`` so callers can
        read it the same way they read a real pipeline result.
        """
        return [{'generated_text': load_error_message}]

    generator = failed_generator
# --- App Logic ---
# Number of outlines requested from the model and returned to the caller.
_NUM_OUTLINES = 10

# Chat-style prompt that guides the model's output format and structure.
# `{prompt}` is filled with the user's text via str.format. Kept at module
# level so the text is not affected by the indentation of the function body.
_STORY_PROMPT_TEMPLATE = """
<|system|>
You are an expert storyteller. Your task is to take a user's prompt and write
a short story as a Markdown outline. The story must have a dramatic arc and be
the length of a song. Use emojis to highlight the story sections.
**Your Story Outline Structure:**
- 🎬 **The Hook:** A dramatic opening.
- 🎼 **The Ballad:** The main story, told concisely.
- 🔚 **The Finale:** A clear and satisfying ending.</s>
<|user|>
{prompt}</s>
<|assistant|>
"""


def generate_stories(prompt: str) -> list[str]:
    """Generate 10 story outlines for ``prompt`` using the module-level ``generator``.

    Args:
        prompt: The user's story idea. ``None``, an empty string, or a
            whitespace-only string is treated as "no input".

    Returns:
        A list of exactly 10 strings. With no input, all 10 are empty strings.
        Otherwise each entry is one generated text with the echoed prompt
        removed and surrounding whitespace stripped; if fewer than 10 results
        come back, the remaining slots hold a failure notice.
    """
    # Blank input: return empty strings without calling the generator.
    if not prompt or not prompt.strip():
        return [""] * _NUM_OUTLINES

    full_prompt = _STORY_PROMPT_TEMPLATE.format(prompt=prompt)

    # Parameters for the pipeline to generate diverse results via sampling.
    params = {
        "max_new_tokens": 250,
        "num_return_sequences": _NUM_OUTLINES,
        "do_sample": True,
        "temperature": 0.8,
        "top_k": 50,
        "top_p": 0.95,
    }
    outputs = generator(full_prompt, **params)

    # Strip only a leading echo of the prompt; text that does not start with
    # the prompt (e.g. a load-error message) is passed through unchanged.
    stories = [
        out['generated_text'].removeprefix(full_prompt).strip()
        for out in outputs
    ]

    # Return exactly _NUM_OUTLINES entries: drop extras, pad what is missing.
    stories = stories[:_NUM_OUTLINES]
    missing = _NUM_OUTLINES - len(stories)
    stories.extend(["Failed to generate a story for this slot."] * missing)
    return stories
# --- Gradio Interface ---
# Page layout, top to bottom: heading and description, the prompt box with its
# button, ten Markdown slots split across two columns, then clickable examples.
with gr.Blocks(
    theme=gr.themes.Soft(),
    css=".gradio-container {max-width: 95% !important;}",
) as demo:
    gr.Markdown(f"<h1 style='text-align: center;'>{TITLE}</h1>")
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        with gr.Column(scale=1):
            input_area = gr.TextArea(
                lines=5,
                label="Your Story Prompt 👇",
                placeholder="e.g., 'The last dragon on Earth lived not in a cave, but in a library...'",
            )
            generate_button = gr.Button("Generate 10 Outlines ✨", variant="primary")

    gr.Markdown("---")
    gr.Markdown("## 📖 Your 10 Story Outlines")

    # Build the ten output slots, five per column, keeping them in label order
    # so they line up with the list returned by generate_stories.
    story_outputs = []
    with gr.Row():
        for outline_numbers in (range(1, 6), range(6, 11)):
            with gr.Column():
                column_slots = [
                    gr.Markdown(label=f"Story Outline {number}")
                    for number in outline_numbers
                ]
                story_outputs.extend(column_slots)

    gr.Examples(
        examples=examples,
        inputs=input_area,
        label="Example Story Starters (Click to use)",
    )

    # Clicking the button sends the prompt to generate_stories and writes the
    # returned strings into the ten Markdown slots.
    generate_button.click(
        fn=generate_stories,
        inputs=input_area,
        outputs=story_outputs,
        api_name="generate",
    )

if __name__ == "__main__":
    demo.launch()