import os

import gradio as gr
import torch
from transformers import pipeline

# --- App Configuration ---
TITLE = "✍️ AI Story Weaver"
DESCRIPTION = """
Enter a prompt, a topic, or the beginning of a story, and get three different continuations from powerful open-source AI models.
This app uses:
- **Mistral-7B-Instruct-v0.2**
- **Google's Gemma-7B-IT**
- **Meta's Llama-3-8B-Instruct**
**⚠️ Hardware Warning:** These are large models. In half precision, each 7–8B model needs roughly 14–16 GB of VRAM, so loading all three at once requires on the order of 45 GB — a single 24 GB GPU is not enough without quantizing or loading fewer models.
The initial loading process may take several minutes. You will also need the `accelerate` library: `pip install accelerate`
"""
# --- Example Prompts for Storytelling ---
examples = [
["The old lighthouse keeper stared into the storm. He'd seen many tempests, but this one was different. This one had eyes..."],
["In a city powered by dreams, a young inventor creates a machine that can record them. His first recording reveals a nightmare that doesn't belong to him."],
["The knight adjusted his helmet, the dragon's roar echoing in the valley. He was ready for the fight, but not for what the dragon said when it finally spoke."],
["She found the old leather-bound journal in her grandfather's attic. The first entry read: 'To relieve stress, I walk in the woods. But today, the woods walked with me.'"],
["The meditation app promised to help her 'delete unhelpful thoughts.' She tapped the button, and to her horror, the memory of her own name began to fade..."]
]
# --- Model Initialization ---
# This section loads the models. It requires significant hardware resources.
# `device_map="auto"` and `torch_dtype="auto"` help manage resources by using available GPUs and half-precision.
try:
print("Initializing models... This may take several minutes.")
# NOTE: For Llama-3, you may need to log in to Hugging Face and accept the license agreement.
# from huggingface_hub import login
# login("YOUR_HF_TOKEN")
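    # Optional sketch: pick up a token from the environment instead of
    # hard-coding it (assumes you export HF_TOKEN before launching the app).
    hf_token = os.environ.get("HF_TOKEN")
    if hf_token:
        from huggingface_hub import login
        login(token=hf_token)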
generator1 = pipeline("text-generation", model="mistralai/Mistral-7B-Instruct-v0.2", torch_dtype="auto", device_map="auto")
print("✅ Mistral-7B loaded.")
generator2 = pipeline("text-generation", model="google/gemma-7b-it", torch_dtype="auto", device_map="auto")
print("✅ Gemma-7B loaded.")
    generator3 = pipeline("text-generation", model="meta-llama/Meta-Llama-3-8B-Instruct", torch_dtype="auto", device_map="auto")
print("✅ Llama-3-8B loaded.")
print("All models loaded successfully! 🎉")
except Exception as e:
print(f"--- 🚨 Error loading models ---")
print(f"Error: {e}")
print("Please ensure you have 'torch' and 'accelerate' installed, have sufficient VRAM, and are logged into Hugging Face if required.")
# Create a dummy function if models fail, so the app can still launch with an error message.
def failed_generator(prompt, **kwargs):
return [{'generated_text': "A model failed to load. Please check the console for errors. You may need more VRAM or need to accept model license terms on Hugging Face."}]
generator1 = generator2 = generator3 = failed_generator
# --- App Logic ---
def generate_stories(prompt: str) -> tuple[str, str, str]:
"""Generates text from the three loaded models based on the user's prompt."""
if not prompt:
return "Please enter a prompt to start.", "", ""
    # 'max_new_tokens' caps the length of each continuation; 200 tokens is
    # enough for a substantial story beat without long generation times.
params = {"max_new_tokens": 200, "do_sample": True, "temperature": 0.7, "top_p": 0.95}
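    # Note: by default the pipeline's 'generated_text' includes the prompt
    # itself; add "return_full_text": False to params to get only the new text.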
# Generate from all three models
out1 = generator1(prompt, **params)[0]['generated_text']
out2 = generator2(prompt, **params)[0]['generated_text']
out3 = generator3(prompt, **params)[0]['generated_text']
return out1, out2, out3
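
# Optional: all three are instruct-tuned models, and they generally respond
# better when the prompt is wrapped in each model's own chat template. A
# minimal sketch of a hypothetical helper (not wired into generate_stories):
def to_chat_prompt(generator, prompt: str) -> str:
    """Render a single user turn with the model's chat template."""
    messages = [{"role": "user", "content": prompt}]
    return generator.tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )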
# --- Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft(), css=".gradio-container {max-width: 95% !important;}") as demo:
gr.Markdown(f"<h1 style='text-align: center;'>{TITLE}</h1>")
gr.Markdown(DESCRIPTION)
with gr.Row():
with gr.Column(scale=1):
input_area = gr.TextArea(
lines=5,
label="Your Story Prompt 👇",
placeholder="e.g., 'The last dragon on Earth lived not in a cave, but in a library...'"
)
generate_button = gr.Button("Weave a Story ✨", variant="primary")
with gr.Column(scale=2):
with gr.Tabs():
with gr.TabItem("Mistral-7B"):
gen1_output = gr.TextArea(label="Mistral's Tale", interactive=False, lines=12)
with gr.TabItem("Gemma-7B"):
gen2_output = gr.TextArea(label="Gemma's Chronicle", interactive=False, lines=12)
with gr.TabItem("Llama-3-8B"):
gen3_output = gr.TextArea(label="Llama's Legend", interactive=False, lines=12)
gr.Examples(
examples=examples,
inputs=input_area,
label="Example Story Starters (Click to use)"
)
generate_button.click(
fn=generate_stories,
inputs=input_area,
outputs=[gen1_output, gen2_output, gen3_output],
api_name="generate"
)
if __name__ == "__main__":
demo.launch()