"""Gradio app that turns a story prompt into several Markdown story outlines.

The text-generation model is loaded once at import time. If loading fails, a
stand-in generator is installed so the UI still launches and shows the error
message in every output slot.
"""

import os

import gradio as gr
import torch
from transformers import pipeline

# --- App Configuration ---
TITLE = "✍️ AI Story Outliner"
DESCRIPTION = """
Enter a prompt and get 10 unique story outlines from a CPU-friendly AI model.
The app uses **TinyLlama-1.1B** to generate creative outlines formatted in Markdown.

**How it works:**
1. Enter your story idea.
2. The AI will generate 10 different story outlines.
3. Each outline has a dramatic beginning and is concise, like a song.
"""

# Number of outlines generated per request; must match the number of output
# components wired to the button below.
NUM_STORIES = 10

# Instructions given to the model as the system turn of the chat.
SYSTEM_PROMPT = """You are an expert storyteller. Your task is to take a user's prompt and write a short story as a Markdown outline. The story must have a dramatic arc and be the length of a song. Use emojis to highlight the story sections.

**Your Story Outline Structure:**
- 🎬 **The Hook:** A dramatic opening.
- 🎼 **The Ballad:** The main story, told concisely.
- 🔚 **The Finale:** A clear and satisfying ending."""

# --- Example Prompts for Storytelling ---
examples = [
    ["The old lighthouse keeper stared into the storm. He'd seen many tempests, but this one was different. This one had eyes..."],
    ["In a city powered by dreams, a young inventor creates a machine that can record them. His first recording reveals a nightmare that doesn't belong to him."],
    ["The knight adjusted his helmet, the dragon's roar echoing in the valley. He was ready for the fight, but not for what the dragon said when it finally spoke."],
    ["She found the old leather-bound journal in her grandfather's attic. The first entry read: 'To relieve stress, I walk in the woods. But today, the woods walked with me.'"],
    ["The meditation app promised to help her 'delete unhelpful thoughts.' She tapped the button, and to her horror, the memory of her own name began to fade..."],
]

# --- Model Initialization ---
# Loads a small chat model. On Hugging Face Spaces, an HF_TOKEN secret (if set)
# is passed through for authenticated downloads.
try:
    print("Initializing model... This may take a moment.")
    hf_token = os.environ.get("HF_TOKEN")
    generator = pipeline(
        "text-generation",
        model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        # Halves weight memory vs float32.
        # NOTE(review): bf16 math may be slow on CPUs without native bf16 support.
        torch_dtype=torch.bfloat16,
        device_map="auto",  # Use a GPU if one is available, otherwise CPU.
        token=hf_token,
    )
    print("✅ TinyLlama model loaded successfully!")
except Exception as e:
    print("--- 🚨 Error loading models ---")
    print(f"Error: {e}")
    # Python unbinds `e` when the except block exits, so the message must be
    # captured now; a closure that referenced `e` directly would raise
    # NameError when called later.
    _load_error_message = (
        f"Model failed to load. Please check the console for errors. Error: {e}"
    )

    def failed_generator(prompt, **kwargs):
        """Stand-in for the pipeline that returns the load error as its only output."""
        return [{"generated_text": _load_error_message}]

    generator = failed_generator


# --- App Logic ---
def _build_prompt(prompt: str) -> str:
    """Format the system instructions and the user's prompt as a chat prompt string."""
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": prompt},
    ]
    tokenizer = getattr(generator, "tokenizer", None)
    if tokenizer is not None and getattr(tokenizer, "chat_template", None):
        # Let the model's own template insert the correct role and end-of-turn tokens.
        return tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
    # Fallback (e.g. the model failed to load): the TinyLlama chat layout,
    # with each turn terminated by </s>.
    return f"<|system|>\n{SYSTEM_PROMPT}</s>\n<|user|>\n{prompt}</s>\n<|assistant|>\n"


def generate_stories(prompt: str) -> list[str]:
    """Generate NUM_STORIES story outlines from the loaded model for *prompt*.

    Args:
        prompt: The user's story idea. Empty or whitespace-only input clears the
            output panels without calling the model.

    Returns:
        Exactly NUM_STORIES strings, one per output Markdown component. Slots the
        model did not fill contain a placeholder message.
    """
    if not prompt or not prompt.strip():
        return [""] * NUM_STORIES

    params = {
        "max_new_tokens": 250,
        "num_return_sequences": NUM_STORIES,
        "do_sample": True,
        "temperature": 0.8,
        "top_k": 50,
        "top_p": 0.95,
        # Return only the completion. Removing the prompt from the full text
        # afterwards is brittle, because the decoded text may not match the
        # prompt string exactly.
        "return_full_text": False,
    }
    outputs = generator(_build_prompt(prompt), **params)

    # Gradio expects exactly one value per output component: truncate extras
    # and pad missing slots.
    stories = [out["generated_text"].strip() for out in outputs][:NUM_STORIES]
    stories += ["Failed to generate a story for this slot."] * (NUM_STORIES - len(stories))
    return stories


# --- Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft(), css=".gradio-container {max-width: 95% !important;}") as demo:
    gr.Markdown(f"# {TITLE}")
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        with gr.Column(scale=1):
            input_area = gr.TextArea(
                lines=5,
                label="Your Story Prompt 👇",
                placeholder="e.g., 'The last dragon on Earth lived not in a cave, but in a library...'",
            )
            generate_button = gr.Button("Generate 10 Outlines ✨", variant="primary")

    gr.Markdown("---")
    gr.Markdown("## 📖 Your 10 Story Outlines")

    # One Markdown component per outline, split across two columns.
    story_outputs = []
    half = (NUM_STORIES + 1) // 2
    with gr.Row():
        for column_indices in (range(0, half), range(half, NUM_STORIES)):
            with gr.Column():
                for i in column_indices:
                    story_outputs.append(gr.Markdown(label=f"Story Outline {i + 1}"))

    gr.Examples(
        examples=examples,
        inputs=input_area,
        label="Example Story Starters (Click to use)",
    )

    generate_button.click(
        fn=generate_stories,
        inputs=input_area,
        outputs=story_outputs,
        api_name="generate",
    )

if __name__ == "__main__":
    demo.launch()