import gradio as gr
import torch
from transformers import pipeline
import os
# --- App Configuration ---
TITLE = "✍️ AI Story Weaver"
DESCRIPTION = """
Enter a prompt, a topic, or the beginning of a story, and get three different continuations from powerful open-source AI models.

This app uses:
- **Mistral-7B-Instruct-v0.2**
- **Google's Gemma-7B-IT**
- **Meta's Llama-3-8B-Instruct**

**⚠️ Hardware Warning:** These are very large models. Loading them requires a powerful GPU with significant VRAM (ideally > 24GB).
The initial loading process may take several minutes. You will also need to install the `accelerate` library: `pip install accelerate`
"""
# --- Example Prompts for Storytelling ---
examples = [
    ["The old lighthouse keeper stared into the storm. He'd seen many tempests, but this one was different. This one had eyes..."],
    ["In a city powered by dreams, a young inventor creates a machine that can record them. His first recording reveals a nightmare that doesn't belong to him."],
    ["The knight adjusted his helmet, the dragon's roar echoing in the valley. He was ready for the fight, but not for what the dragon said when it finally spoke."],
    ["She found the old leather-bound journal in her grandfather's attic. The first entry read: 'To relieve stress, I walk in the woods. But today, the woods walked with me.'"],
    ["The meditation app promised to help her 'delete unhelpful thoughts.' She tapped the button, and to her horror, the memory of her own name began to fade..."],
]
# --- Model Initialization ---
# This section loads the models. It requires significant hardware resources.
# `device_map="auto"` and `torch_dtype="auto"` help manage resources by using
# available GPUs and half-precision.
try:
    print("Initializing models... This may take several minutes.")
    # NOTE: For Llama-3, you must accept the license agreement on Hugging Face
    # and authenticate, e.g. with a token stored in the HF_TOKEN environment variable:
    # from huggingface_hub import login
    # login(token=os.environ.get("HF_TOKEN"))
    generator1 = pipeline("text-generation", model="mistralai/Mistral-7B-Instruct-v0.2", torch_dtype="auto", device_map="auto")
    print("✅ Mistral-7B loaded.")
    generator2 = pipeline("text-generation", model="google/gemma-7b-it", torch_dtype="auto", device_map="auto")
    print("✅ Gemma-7B loaded.")
    generator3 = pipeline("text-generation", model="meta-llama/Meta-Llama-3-8B-Instruct", torch_dtype="auto", device_map="auto")
    print("✅ Llama-3-8B loaded.")
    print("All models loaded successfully! 🎉")
except Exception as e:
    print("--- 🚨 Error loading models ---")
    print(f"Error: {e}")
    print("Please ensure you have 'torch' and 'accelerate' installed, have sufficient VRAM, and are logged into Hugging Face if required.")

    # Fall back to a dummy generator so the app can still launch with an error message.
    def failed_generator(prompt, **kwargs):
        return [{'generated_text': "A model failed to load. Please check the console for errors. You may need more VRAM or need to accept model license terms on Hugging Face."}]

    generator1 = generator2 = generator3 = failed_generator
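
# If VRAM is tight, the models can instead be loaded in 4-bit. This is a minimal
# sketch, not part of the original app; it assumes `bitsandbytes` is installed
# (`pip install bitsandbytes`) and a CUDA GPU is available:
#
# from transformers import BitsAndBytesConfig
# bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
# generator1 = pipeline(
#     "text-generation",
#     model="mistralai/Mistral-7B-Instruct-v0.2",
#     device_map="auto",
#     model_kwargs={"quantization_config": bnb_config},  # quantize weights at load time
# )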
# --- App Logic ---
def generate_stories(prompt: str) -> tuple[str, str, str]:
    """Generates text from the three loaded models based on the user's prompt."""
    if not prompt:
        return "Please enter a prompt to start.", "", ""

    # 'max_new_tokens' caps the length of each continuation (200 tokens gives a
    # substantial passage); sampling with temperature/top_p keeps the stories varied.
    # 'return_full_text=False' returns only the continuation, not the echoed prompt.
    params = {
        "max_new_tokens": 200,
        "do_sample": True,
        "temperature": 0.7,
        "top_p": 0.95,
        "return_full_text": False,
    }
    # Generate from all three models
    out1 = generator1(prompt, **params)[0]['generated_text']
    out2 = generator2(prompt, **params)[0]['generated_text']
    out3 = generator3(prompt, **params)[0]['generated_text']
    return out1, out2, out3
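
# Note: all three models are instruction-tuned, so wrapping the prompt in each
# model's chat template often improves results. A sketch of that variant (an
# assumption, not the original behavior; requires a recent `transformers`
# release whose text-generation pipelines accept chat-style messages):
#
# messages = [{"role": "user", "content": f"Continue this story:\n\n{prompt}"}]
# reply = generator1(messages, max_new_tokens=200)[0]["generated_text"][-1]["content"]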
# --- Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft(), css=".gradio-container {max-width: 95% !important;}") as demo:
    gr.Markdown(f"<h1 style='text-align: center;'>{TITLE}</h1>")
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        with gr.Column(scale=1):
            input_area = gr.TextArea(
                lines=5,
                label="Your Story Prompt 👇",
                placeholder="e.g., 'The last dragon on Earth lived not in a cave, but in a library...'"
            )
            generate_button = gr.Button("Weave a Story ✨", variant="primary")
        with gr.Column(scale=2):
            with gr.Tabs():
                with gr.TabItem("Mistral-7B"):
                    gen1_output = gr.TextArea(label="Mistral's Tale", interactive=False, lines=12)
                with gr.TabItem("Gemma-7B"):
                    gen2_output = gr.TextArea(label="Gemma's Chronicle", interactive=False, lines=12)
                with gr.TabItem("Llama-3-8B"):
                    gen3_output = gr.TextArea(label="Llama's Legend", interactive=False, lines=12)

    gr.Examples(
        examples=examples,
        inputs=input_area,
        label="Example Story Starters (Click to use)"
    )

    generate_button.click(
        fn=generate_stories,
        inputs=input_area,
        outputs=[gen1_output, gen2_output, gen3_output],
        api_name="generate"
    )
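
# For generation jobs this slow, enabling Gradio's request queue helps avoid
# request timeouts (a suggestion beyond the original code, not required to run):
# demo.queue()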
if __name__ == "__main__":
    demo.launch()