import gradio as gr import torch from transformers import pipeline import os # --- App Configuration --- TITLE = "✍️ AI Story Weaver" DESCRIPTION = """ Enter a prompt, a topic, or the beginning of a story, and get three different continuations from powerful open-source AI models. This app uses: - **Mistral-7B-Instruct-v0.2** - **Google's Gemma-7B-IT** - **Meta's Llama-3-8B-Instruct** **⚠️ Hardware Warning:** These are very large models. Loading them requires a powerful GPU with significant VRAM (ideally > 24GB). The initial loading process may take several minutes. You will also need to install the `accelerate` library: `pip install accelerate` """ # --- Example Prompts for Storytelling --- examples = [ ["The old lighthouse keeper stared into the storm. He'd seen many tempests, but this one was different. This one had eyes..."], ["In a city powered by dreams, a young inventor creates a machine that can record them. His first recording reveals a nightmare that doesn't belong to him."], ["The knight adjusted his helmet, the dragon's roar echoing in the valley. He was ready for the fight, but not for what the dragon said when it finally spoke."], ["She found the old leather-bound journal in her grandfather's attic. The first entry read: 'To relieve stress, I walk in the woods. But today, the woods walked with me.'"], ["The meditation app promised to help her 'delete unhelpful thoughts.' She tapped the button, and to her horror, the memory of her own name began to fade..."] ] # --- Model Initialization --- # This section loads the models. It requires significant hardware resources. # `device_map="auto"` and `torch_dtype="auto"` help manage resources by using available GPUs and half-precision. try: print("Initializing models... This may take several minutes.") # NOTE: For Llama-3, you may need to log in to Hugging Face and accept the license agreement. # from huggingface_hub import login # login("YOUR_HF_TOKEN") generator1 = pipeline("text-generation", model="mistralai/Mistral-7B-Instruct-v0.2", torch_dtype="auto", device_map="auto") print("✅ Mistral-7B loaded.") generator2 = pipeline("text-generation", model="google/gemma-7b-it", torch_dtype="auto", device_map="auto") print("✅ Gemma-7B loaded.") generator3 = pipeline("text-generation", model="meta-llama/Llama-3-8B-Instruct", torch_dtype="auto", device_map="auto") print("✅ Llama-3-8B loaded.") print("All models loaded successfully! 🎉") except Exception as e: print(f"--- 🚨 Error loading models ---") print(f"Error: {e}") print("Please ensure you have 'torch' and 'accelerate' installed, have sufficient VRAM, and are logged into Hugging Face if required.") # Create a dummy function if models fail, so the app can still launch with an error message. def failed_generator(prompt, **kwargs): return [{'generated_text': "A model failed to load. Please check the console for errors. You may need more VRAM or need to accept model license terms on Hugging Face."}] generator1 = generator2 = generator3 = failed_generator # --- App Logic --- def generate_stories(prompt: str) -> tuple[str, str, str]: """Generates text from the three loaded models based on the user's prompt.""" if not prompt: return "Please enter a prompt to start.", "", "" # We use 'max_new_tokens' to control the length of the generated story. # Increased to 200 for more substantial story continuations. params = {"max_new_tokens": 200, "do_sample": True, "temperature": 0.7, "top_p": 0.95} # Generate from all three models out1 = generator1(prompt, **params)[0]['generated_text'] out2 = generator2(prompt, **params)[0]['generated_text'] out3 = generator3(prompt, **params)[0]['generated_text'] return out1, out2, out3 # --- Gradio Interface --- with gr.Blocks(theme=gr.themes.Soft(), css=".gradio-container {max-width: 95% !important;}") as demo: gr.Markdown(f"

{TITLE}

") gr.Markdown(DESCRIPTION) with gr.Row(): with gr.Column(scale=1): input_area = gr.TextArea( lines=5, label="Your Story Prompt 👇", placeholder="e.g., 'The last dragon on Earth lived not in a cave, but in a library...'" ) generate_button = gr.Button("Weave a Story ✨", variant="primary") with gr.Column(scale=2): with gr.Tabs(): with gr.TabItem("Mistral-7B"): gen1_output = gr.TextArea(label="Mistral's Tale", interactive=False, lines=12) with gr.TabItem("Gemma-7B"): gen2_output = gr.TextArea(label="Gemma's Chronicle", interactive=False, lines=12) with gr.TabItem("Llama-3-8B"): gen3_output = gr.TextArea(label="Llama's Legend", interactive=False, lines=12) gr.Examples( examples=examples, inputs=input_area, label="Example Story Starters (Click to use)" ) generate_button.click( fn=generate_stories, inputs=input_area, outputs=[gen1_output, gen2_output, gen3_output], api_name="generate" ) if __name__ == "__main__": demo.launch()