# AI Story Outliner — Hugging Face Space (Gradio app)
import os

import gradio as gr
import torch
from transformers import pipeline
# --- App Configuration ---
# Static UI text and example prompts; no runtime logic lives here.
TITLE = "✍️ AI Story Outliner"

DESCRIPTION = """
Enter a prompt and get 10 unique story outlines from a CPU-friendly AI model.
The app uses **TinyLlama-1.1B** to generate creative outlines formatted in Markdown.
**How it works:**
1. Enter your story idea.
2. The AI will generate 10 different story outlines.
3. Each outline has a dramatic beginning and is concise, like a song.
"""

# --- Example Prompts for Storytelling ---
# Each entry is a one-element list because gr.Examples maps one column per input.
examples = [
    ["The old lighthouse keeper stared into the storm. He'd seen many tempests, but this one was different. This one had eyes..."],
    ["In a city powered by dreams, a young inventor creates a machine that can record them. His first recording reveals a nightmare that doesn't belong to him."],
    ["The knight adjusted his helmet, the dragon's roar echoing in the valley. He was ready for the fight, but not for what the dragon said when it finally spoke."],
    ["She found the old leather-bound journal in her grandfather's attic. The first entry read: 'To relieve stress, I walk in the woods. But today, the woods walked with me.'"],
    ["The meditation app promised to help her 'delete unhelpful thoughts.' She tapped the button, and to her horror, the memory of her own name began to fade..."],
]
# --- Model Initialization ---
# Loads a small, CPU-friendly chat model. On Hugging Face Spaces the HF_TOKEN
# secret is picked up from the environment automatically.
try:
    print("Initializing model... This may take a moment.")
    # Load the token from environment variables if it exists (for HF Spaces secrets).
    hf_token = os.environ.get("HF_TOKEN", None)
    # A small model that is more suitable for running without a high-end GPU.
    generator = pipeline(
        "text-generation",
        model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        torch_dtype=torch.bfloat16,  # more memory-efficient dtype
        device_map="auto",  # uses GPU if available, otherwise CPU
        token=hf_token,
    )
    print("✅ TinyLlama model loaded successfully!")
except Exception as e:
    print("--- 🚨 Error loading models ---")
    print(f"Error: {e}")
    # BUG FIX: capture the message now. Python unbinds the `except ... as e`
    # name when the handler exits, so referencing `e` later inside the
    # fallback function would raise NameError instead of showing the error.
    _load_error = str(e)

    def failed_generator(prompt, **kwargs):
        """Fallback generator so the app can still launch and surface the load error.

        Mirrors the pipeline's return shape: a list of dicts with a
        'generated_text' key.
        """
        error_message = f"Model failed to load. Please check the console for errors. Error: {_load_error}"
        return [{'generated_text': error_message}]

    generator = failed_generator
# --- App Logic ---
def generate_stories(prompt: str) -> list[str]:
    """Generate 10 story outlines from the loaded model for the user's prompt.

    Args:
        prompt: The user's story idea. Blank/whitespace-only input clears the UI.

    Returns:
        A list of exactly 10 Markdown strings (empty strings when the prompt
        is blank; placeholder text for any slot the model failed to fill).
    """
    if not prompt or not prompt.strip():
        # Return a list of 10 empty strings to clear the outputs.
        return [""] * 10

    # A detailed system prompt to guide the model's output format and structure.
    # Uses the TinyLlama chat template markers (<|system|>, <|user|>, <|assistant|>).
    system_prompt = f"""
<|system|>
You are an expert storyteller. Your task is to take a user's prompt and write
a short story as a Markdown outline. The story must have a dramatic arc and be
the length of a song. Use emojis to highlight the story sections.
**Your Story Outline Structure:**
- 🎬 **The Hook:** A dramatic opening.
- 🎼 **The Ballad:** The main story, told concisely.
- 🔚 **The Finale:** A clear and satisfying ending.</s>
<|user|>
{prompt}</s>
<|assistant|>
"""

    # Sampling parameters chosen to produce 10 diverse results.
    params = {
        "max_new_tokens": 250,
        "num_return_sequences": 10,
        "do_sample": True,
        "temperature": 0.8,
        "top_k": 50,
        "top_p": 0.95,
    }

    # Generate 10 different story variations.
    outputs = generator(system_prompt, **params)

    # The pipeline echoes the prompt at the start of each result; strip it off.
    stories = [out['generated_text'].replace(system_prompt, "").strip() for out in outputs]

    # Guarantee exactly 10 entries: pad short results, truncate any excess,
    # so the count always matches the 10 Markdown output components.
    stories += ["Failed to generate a story for this slot."] * max(0, 10 - len(stories))
    return stories[:10]
# --- Gradio Interface ---
# Layout: prompt input + button on top, then 10 Markdown outlines in two
# five-row columns. The click handler fans one input out to all 10 outputs.
with gr.Blocks(theme=gr.themes.Soft(), css=".gradio-container {max-width: 95% !important;}") as demo:
    gr.Markdown(f"<h1 style='text-align: center;'>{TITLE}</h1>")
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        with gr.Column(scale=1):
            input_area = gr.TextArea(
                lines=5,
                label="Your Story Prompt 👇",
                placeholder="e.g., 'The last dragon on Earth lived not in a cave, but in a library...'"
            )
            generate_button = gr.Button("Generate 10 Outlines ✨", variant="primary")

    gr.Markdown("---")
    gr.Markdown("## 📖 Your 10 Story Outlines")

    # Create 10 Markdown components, split across two columns; the list order
    # (1-5 left, 6-10 right) must match the order of generate_stories' output.
    story_outputs = []
    with gr.Row():
        with gr.Column():
            for i in range(5):
                story_outputs.append(gr.Markdown(label=f"Story Outline {i + 1}"))
        with gr.Column():
            for i in range(5, 10):
                story_outputs.append(gr.Markdown(label=f"Story Outline {i + 1}"))

    gr.Examples(
        examples=examples,
        inputs=input_area,
        label="Example Story Starters (Click to use)"
    )

    generate_button.click(
        fn=generate_stories,
        inputs=input_area,
        outputs=story_outputs,
        api_name="generate"  # exposes this handler on the Space's API
    )

if __name__ == "__main__":
    demo.launch()