Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -4,128 +4,110 @@ from transformers import pipeline
|
|
4 |
import os
|
5 |
|
6 |
# --- App Configuration ---
|
7 |
-
|
8 |
-
|
9 |
-
Enter a topic or a
|
10 |
-
|
11 |
-
**
|
|
|
|
|
|
|
|
|
|
|
12 |
"""
|
13 |
|
14 |
-
#
|
15 |
examples = [
|
16 |
-
["
|
17 |
-
["
|
18 |
-
["
|
19 |
-
["
|
20 |
-
["
|
21 |
-
["Alleviating stress"],
|
22 |
-
["Helping breathing, satisfaction"],
|
23 |
-
["Relieve Stress, Build Support"],
|
24 |
-
["The Relaxation Response"],
|
25 |
-
["Taking Deep Breaths"],
|
26 |
-
["Delete Not Helpful Thoughts"],
|
27 |
-
["Strengthen Helpful Thoughts"],
|
28 |
-
["Reprogram Pain and Stress Reactions"],
|
29 |
-
["How to Sleep Better and Find Joy"],
|
30 |
-
["Yoga for deep sleep"],
|
31 |
-
["Being a Happier and Healthier Person"],
|
32 |
-
["Relieve chronic pain by"],
|
33 |
-
["Use Mindfulness to Affect Well Being"],
|
34 |
-
["Build and Boost Mental Strength"],
|
35 |
-
["Spending Time Outdoors"],
|
36 |
-
["Daily Routine Tasks"],
|
37 |
-
["Eating and Drinking - Find Healthy Nutrition Habits"],
|
38 |
-
["Drinking - Find Reasons and Cut Back or Quit Entirely"],
|
39 |
-
["Feel better each day when you awake by"],
|
40 |
-
["Feel better physically by"],
|
41 |
-
["Practicing mindfulness each day"],
|
42 |
-
["Be happier by"],
|
43 |
-
["Meditation can improve health by"],
|
44 |
-
["Spending time outdoors helps to"],
|
45 |
-
["Stress is relieved by quieting your mind, getting exercise and time with nature"],
|
46 |
-
["Break the cycle of stress and anxiety"],
|
47 |
-
["Feel calm in stressful situations"],
|
48 |
-
["Deal with work pressure by"],
|
49 |
-
["Learn to reduce feelings of being overwhelmed"]
|
50 |
]
|
51 |
|
52 |
# --- Model Initialization ---
|
53 |
-
#
|
54 |
-
#
|
55 |
-
# Install dependencies: pip install gradio transformers torch accelerate
|
56 |
try:
|
57 |
print("Initializing models... This may take several minutes.")
|
58 |
|
59 |
-
#
|
60 |
-
#
|
61 |
-
|
62 |
-
|
63 |
-
|
|
|
64 |
|
65 |
-
generator2 = pipeline("text-generation", model="
|
66 |
-
print("
|
67 |
|
68 |
-
generator3 = pipeline("text-generation", model="
|
69 |
-
print("
|
70 |
|
71 |
-
print("All models loaded successfully!
|
72 |
|
73 |
except Exception as e:
|
74 |
-
print(f"Error loading models
|
75 |
-
print("
|
76 |
-
|
|
|
77 |
def failed_generator(prompt, **kwargs):
|
78 |
-
return [{'generated_text': "
|
79 |
generator1 = generator2 = generator3 = failed_generator
|
80 |
|
81 |
|
82 |
# --- App Logic ---
|
83 |
-
def
|
84 |
-
"""Generates text from the three loaded models."""
|
85 |
-
|
86 |
-
|
|
|
|
|
|
|
|
|
87 |
|
88 |
-
|
89 |
-
|
90 |
-
|
|
|
91 |
|
92 |
return out1, out2, out3
|
93 |
|
94 |
# --- Gradio Interface ---
|
95 |
-
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
96 |
-
gr.Markdown(f"<h1 style='text-align: center;'>{
|
97 |
-
gr.Markdown(
|
98 |
|
99 |
with gr.Row():
|
100 |
with gr.Column(scale=1):
|
101 |
input_area = gr.TextArea(
|
102 |
-
lines=
|
103 |
-
label="Your
|
104 |
-
placeholder="e.g., '
|
105 |
)
|
106 |
-
generate_button = gr.Button("
|
107 |
|
108 |
with gr.Column(scale=2):
|
109 |
with gr.Tabs():
|
110 |
-
with gr.TabItem("
|
111 |
-
gen1_output = gr.TextArea(label="
|
112 |
-
with gr.TabItem("
|
113 |
-
gen2_output = gr.TextArea(label="
|
114 |
-
with gr.TabItem("
|
115 |
-
gen3_output = gr.TextArea(label="
|
116 |
|
117 |
gr.Examples(
|
118 |
examples=examples,
|
119 |
inputs=input_area,
|
120 |
-
label="Example
|
121 |
)
|
122 |
|
123 |
generate_button.click(
|
124 |
-
fn=
|
125 |
inputs=input_area,
|
126 |
outputs=[gen1_output, gen2_output, gen3_output],
|
127 |
api_name="generate"
|
128 |
)
|
129 |
|
130 |
if __name__ == "__main__":
|
131 |
-
demo.launch()
|
|
|
4 |
import os
|
5 |
|
6 |
# --- App Configuration ---
|
7 |
+
TITLE = "✍️ AI Story Weaver"
|
8 |
+
DESCRIPTION = """
|
9 |
+
Enter a prompt, a topic, or the beginning of a story, and get three different continuations from powerful open-source AI models.
|
10 |
+
This app uses:
|
11 |
+
- **Mistral-7B-Instruct-v0.2**
|
12 |
+
- **Google's Gemma-7B-IT**
|
13 |
+
- **Meta's Llama-3-8B-Instruct**
|
14 |
+
|
15 |
+
**⚠️ Hardware Warning:** These are very large models. Loading them requires a powerful GPU with significant VRAM (ideally > 24GB).
|
16 |
+
The initial loading process may take several minutes. You will also need to install the `accelerate` library: `pip install accelerate`
|
17 |
"""
|
18 |
|
19 |
+
# --- Example Prompts for Storytelling ---
|
20 |
examples = [
|
21 |
+
["The old lighthouse keeper stared into the storm. He'd seen many tempests, but this one was different. This one had eyes..."],
|
22 |
+
["In a city powered by dreams, a young inventor creates a machine that can record them. His first recording reveals a nightmare that doesn't belong to him."],
|
23 |
+
["The knight adjusted his helmet, the dragon's roar echoing in the valley. He was ready for the fight, but not for what the dragon said when it finally spoke."],
|
24 |
+
["She found the old leather-bound journal in her grandfather's attic. The first entry read: 'To relieve stress, I walk in the woods. But today, the woods walked with me.'"],
|
25 |
+
["The meditation app promised to help her 'delete unhelpful thoughts.' She tapped the button, and to her horror, the memory of her own name began to fade..."]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
]
|
27 |
|
28 |
# --- Model Initialization ---
|
29 |
+
# This section loads the models. It requires significant hardware resources.
|
30 |
+
# `device_map="auto"` and `torch_dtype="auto"` help manage resources by using available GPUs and half-precision.
|
|
|
31 |
try:
|
32 |
print("Initializing models... This may take several minutes.")
|
33 |
|
34 |
+
# NOTE: For Llama-3, you may need to log in to Hugging Face and accept the license agreement.
|
35 |
+
# from huggingface_hub import login
|
36 |
+
# login("YOUR_HF_TOKEN")
|
37 |
+
|
38 |
+
generator1 = pipeline("text-generation", model="mistralai/Mistral-7B-Instruct-v0.2", torch_dtype="auto", device_map="auto")
|
39 |
+
print("✅ Mistral-7B loaded.")
|
40 |
|
41 |
+
generator2 = pipeline("text-generation", model="google/gemma-7b-it", torch_dtype="auto", device_map="auto")
|
42 |
+
print("✅ Gemma-7B loaded.")
|
43 |
|
44 |
+
generator3 = pipeline("text-generation", model="meta-llama/Llama-3-8B-Instruct", torch_dtype="auto", device_map="auto")
|
45 |
+
print("✅ Llama-3-8B loaded.")
|
46 |
|
47 |
+
print("All models loaded successfully! 🎉")
|
48 |
|
49 |
except Exception as e:
|
50 |
+
print(f"--- 🚨 Error loading models ---")
|
51 |
+
print(f"Error: {e}")
|
52 |
+
print("Please ensure you have 'torch' and 'accelerate' installed, have sufficient VRAM, and are logged into Hugging Face if required.")
|
53 |
+
# Create a dummy function if models fail, so the app can still launch with an error message.
|
54 |
def failed_generator(prompt, **kwargs):
|
55 |
+
return [{'generated_text': "A model failed to load. Please check the console for errors. You may need more VRAM or need to accept model license terms on Hugging Face."}]
|
56 |
generator1 = generator2 = generator3 = failed_generator
|
57 |
|
58 |
|
59 |
# --- App Logic ---
|
60 |
+
def generate_stories(prompt: str) -> tuple[str, str, str]:
|
61 |
+
"""Generates text from the three loaded models based on the user's prompt."""
|
62 |
+
if not prompt:
|
63 |
+
return "Please enter a prompt to start.", "", ""
|
64 |
+
|
65 |
+
# We use 'max_new_tokens' to control the length of the generated story.
|
66 |
+
# Increased to 200 for more substantial story continuations.
|
67 |
+
params = {"max_new_tokens": 200, "do_sample": True, "temperature": 0.7, "top_p": 0.95}
|
68 |
|
69 |
+
# Generate from all three models
|
70 |
+
out1 = generator1(prompt, **params)[0]['generated_text']
|
71 |
+
out2 = generator2(prompt, **params)[0]['generated_text']
|
72 |
+
out3 = generator3(prompt, **params)[0]['generated_text']
|
73 |
|
74 |
return out1, out2, out3
|
75 |
|
76 |
# --- Gradio Interface ---
|
77 |
+
with gr.Blocks(theme=gr.themes.Soft(), css=".gradio-container {max-width: 95% !important;}") as demo:
|
78 |
+
gr.Markdown(f"<h1 style='text-align: center;'>{TITLE}</h1>")
|
79 |
+
gr.Markdown(DESCRIPTION)
|
80 |
|
81 |
with gr.Row():
|
82 |
with gr.Column(scale=1):
|
83 |
input_area = gr.TextArea(
|
84 |
+
lines=5,
|
85 |
+
label="Your Story Prompt 👇",
|
86 |
+
placeholder="e.g., 'The last dragon on Earth lived not in a cave, but in a library...'"
|
87 |
)
|
88 |
+
generate_button = gr.Button("Weave a Story ✨", variant="primary")
|
89 |
|
90 |
with gr.Column(scale=2):
|
91 |
with gr.Tabs():
|
92 |
+
with gr.TabItem("Mistral-7B"):
|
93 |
+
gen1_output = gr.TextArea(label="Mistral's Tale", interactive=False, lines=12)
|
94 |
+
with gr.TabItem("Gemma-7B"):
|
95 |
+
gen2_output = gr.TextArea(label="Gemma's Chronicle", interactive=False, lines=12)
|
96 |
+
with gr.TabItem("Llama-3-8B"):
|
97 |
+
gen3_output = gr.TextArea(label="Llama's Legend", interactive=False, lines=12)
|
98 |
|
99 |
gr.Examples(
|
100 |
examples=examples,
|
101 |
inputs=input_area,
|
102 |
+
label="Example Story Starters (Click to use)"
|
103 |
)
|
104 |
|
105 |
generate_button.click(
|
106 |
+
fn=generate_stories,
|
107 |
inputs=input_area,
|
108 |
outputs=[gen1_output, gen2_output, gen3_output],
|
109 |
api_name="generate"
|
110 |
)
|
111 |
|
112 |
if __name__ == "__main__":
|
113 |
+
demo.launch()
|