awacke1 commited on
Commit
1084118
·
verified ·
1 Parent(s): f941deb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +62 -80
app.py CHANGED
@@ -4,128 +4,110 @@ from transformers import pipeline
4
  import os
5
 
6
# --- App Configuration ---
# Page heading and intro copy rendered by the Gradio UI further down.
title = "📗 Health and Mindful Story Gen ❤️"
description = """
Enter a topic or a starting sentence related to health, mindfulness, or well-being.
The app will generate continuations from three different language models.
**Note:** These models are very large. Initial loading and first-time generation may take a few minutes.
"""
13
 
14
# Example prompts shown as one-click starters in the interface.
# Gradio's Examples component expects one row (list) per example,
# hence the list-of-lists shape built below.
_prompt_texts = (
    "Mental Body Scan",
    "Stretch, Calm, Breath",
    "Relaxed Seat Breath",
    "Walk Feel",
    "Brain gamification",
    "Alleviating stress",
    "Helping breathing, satisfaction",
    "Relieve Stress, Build Support",
    "The Relaxation Response",
    "Taking Deep Breaths",
    "Delete Not Helpful Thoughts",
    "Strengthen Helpful Thoughts",
    "Reprogram Pain and Stress Reactions",
    "How to Sleep Better and Find Joy",
    "Yoga for deep sleep",
    "Being a Happier and Healthier Person",
    "Relieve chronic pain by",
    "Use Mindfulness to Affect Well Being",
    "Build and Boost Mental Strength",
    "Spending Time Outdoors",
    "Daily Routine Tasks",
    "Eating and Drinking - Find Healthy Nutrition Habits",
    "Drinking - Find Reasons and Cut Back or Quit Entirely",
    "Feel better each day when you awake by",
    "Feel better physically by",
    "Practicing mindfulness each day",
    "Be happier by",
    "Meditation can improve health by",
    "Spending time outdoors helps to",
    "Stress is relieved by quieting your mind, getting exercise and time with nature",
    "Break the cycle of stress and anxiety",
    "Feel calm in stressful situations",
    "Deal with work pressure by",
    "Learn to reduce feelings of being overwhelmed",
)
examples = [[text] for text in _prompt_texts]
51
 
52
# --- Model Initialization ---
# WARNING: Loading these models requires significant hardware (ideally a GPU with >24GB VRAM).
# 'device_map="auto"' and 'torch_dtype' require the 'accelerate' library.
# Install dependencies: pip install gradio transformers torch accelerate
try:
    print("Initializing models... This may take several minutes.")

    # device_map="auto" spreads weights over available GPUs; torch_dtype="auto"
    # loads half-precision (float16/bfloat16) weights to save memory.
    _model_specs = [
        ("gpt2-large", {"device_map": "auto"}, "GPT-2 Large loaded."),
        ("EleutherAI/gpt-neo-2.7B", {"torch_dtype": "auto", "device_map": "auto"}, "GPT-Neo 2.7B loaded."),
        ("EleutherAI/gpt-j-6B", {"torch_dtype": "auto", "device_map": "auto"}, "GPT-J 6B loaded."),
    ]
    _pipelines = []
    for _model_id, _extra_kwargs, _done_msg in _model_specs:
        _pipelines.append(pipeline("text-generation", model=_model_id, **_extra_kwargs))
        print(_done_msg)
    generator1, generator2, generator3 = _pipelines

    print("All models loaded successfully! ")

except Exception as e:
    print(f"Error loading models: {e}")
    print("Please ensure you have 'torch' and 'accelerate' installed and have sufficient VRAM.")
    # Fallback stub so the app can still launch and explain the failure.
    def failed_generator(prompt, **kwargs):
        return [{'generated_text': "Model failed to load. Check console for errors."}]
    generator1 = generator2 = generator3 = failed_generator
80
 
81
 
82
# --- App Logic ---
def generate_outputs(input_text: str) -> tuple[str, str, str]:
    """Return one continuation of *input_text* from each of the three models."""
    # 'max_new_tokens' caps only the newly generated text, unlike 'max_length',
    # which would also count the prompt itself.
    params = {"max_new_tokens": 60, "num_return_sequences": 1}

    first, second, third = (
        gen(input_text, **params)[0]['generated_text']
        for gen in (generator1, generator2, generator3)
    )
    return first, second, third
93
 
94
# --- Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(f"<h1 style='text-align: center;'>{title}</h1>")
    gr.Markdown(description)

    with gr.Row():
        # Left column: prompt entry plus the trigger button.
        with gr.Column(scale=1):
            prompt_box = gr.TextArea(
                lines=3,
                label="Your starting prompt 👇",
                placeholder="e.g., 'To relieve stress, I will try...'",
            )
            run_btn = gr.Button("Generate ✨", variant="primary")

        # Right column: one tab per model's output.
        with gr.Column(scale=2):
            with gr.Tabs():
                with gr.TabItem("GPT-2 Large"):
                    out_gpt2 = gr.TextArea(label="GPT-2 Large Output", interactive=False, lines=7)
                with gr.TabItem("GPT-Neo 2.7B"):
                    out_neo = gr.TextArea(label="GPT-Neo 2.7B Output", interactive=False, lines=7)
                with gr.TabItem("GPT-J 6B"):
                    out_gptj = gr.TextArea(label="GPT-J 6B Output", interactive=False, lines=7)

    gr.Examples(
        examples=examples,
        inputs=prompt_box,
        label="Example Prompts (Click to use)",
    )

    run_btn.click(
        fn=generate_outputs,
        inputs=prompt_box,
        outputs=[out_gpt2, out_neo, out_gptj],
        api_name="generate",
    )

if __name__ == "__main__":
    demo.launch()
 
4
  import os
5
 
6
# --- App Configuration ---
# Constants consumed by the Gradio layout: the page heading and the
# markdown blurb shown under it (model list + hardware warning).
TITLE = "✍️ AI Story Weaver"
DESCRIPTION = """
Enter a prompt, a topic, or the beginning of a story, and get three different continuations from powerful open-source AI models.
This app uses:
- **Mistral-7B-Instruct-v0.2**
- **Google's Gemma-7B-IT**
- **Meta's Llama-3-8B-Instruct**

**⚠️ Hardware Warning:** These are very large models. Loading them requires a powerful GPU with significant VRAM (ideally > 24GB).
The initial loading process may take several minutes. You will also need to install the `accelerate` library: `pip install accelerate`
"""
18
 
19
# --- Example Prompts for Storytelling ---
# One-click story starters; Gradio's Examples component wants one
# row (list) per example, so each prompt is wrapped in its own list.
_story_starters = (
    "The old lighthouse keeper stared into the storm. He'd seen many tempests, but this one was different. This one had eyes...",
    "In a city powered by dreams, a young inventor creates a machine that can record them. His first recording reveals a nightmare that doesn't belong to him.",
    "The knight adjusted his helmet, the dragon's roar echoing in the valley. He was ready for the fight, but not for what the dragon said when it finally spoke.",
    "She found the old leather-bound journal in her grandfather's attic. The first entry read: 'To relieve stress, I walk in the woods. But today, the woods walked with me.'",
    "The meditation app promised to help her 'delete unhelpful thoughts.' She tapped the button, and to her horror, the memory of her own name began to fade...",
)
examples = [[starter] for starter in _story_starters]
27
 
28
# --- Model Initialization ---
# This section loads the models. It requires significant hardware resources.
# `device_map="auto"` spreads layers across available GPUs and
# `torch_dtype="auto"` loads half-precision weights to save VRAM.
try:
    print("Initializing models... This may take several minutes.")

    # NOTE: Mistral, Gemma and Llama-3 are gated on the Hugging Face Hub —
    # you may need to log in and accept each model's license agreement first:
    # from huggingface_hub import login
    # login("YOUR_HF_TOKEN")

    generator1 = pipeline("text-generation", model="mistralai/Mistral-7B-Instruct-v0.2", torch_dtype="auto", device_map="auto")
    print("✅ Mistral-7B loaded.")

    generator2 = pipeline("text-generation", model="google/gemma-7b-it", torch_dtype="auto", device_map="auto")
    print("✅ Gemma-7B loaded.")

    # BUG FIX: the Hub repo id is "meta-llama/Meta-Llama-3-8B-Instruct";
    # "meta-llama/Llama-3-8B-Instruct" does not exist, so loading always
    # failed and the app silently fell back to the dummy generator.
    generator3 = pipeline("text-generation", model="meta-llama/Meta-Llama-3-8B-Instruct", torch_dtype="auto", device_map="auto")
    print("✅ Llama-3-8B loaded.")

    print("All models loaded successfully! 🎉")

except Exception as e:
    print(f"--- 🚨 Error loading models ---")
    print(f"Error: {e}")
    print("Please ensure you have 'torch' and 'accelerate' installed, have sufficient VRAM, and are logged into Hugging Face if required.")
    # Dummy generator so the app can still launch and surface the failure.
    def failed_generator(prompt, **kwargs):
        return [{'generated_text': "A model failed to load. Please check the console for errors. You may need more VRAM or need to accept model license terms on Hugging Face."}]
    generator1 = generator2 = generator3 = failed_generator
57
 
58
 
59
# --- App Logic ---
def generate_stories(prompt: str) -> tuple[str, str, str]:
    """Generate one continuation of *prompt* from each of the three models.

    Args:
        prompt: The user's story starter text.

    Returns:
        A 3-tuple of generated texts (generator1, generator2, generator3).
        If the prompt is empty or whitespace-only, a hint message is
        returned instead and the models are not invoked.
    """
    # Robustness fix: also reject whitespace-only prompts; the old
    # `if not prompt` check let "   " through to the models.
    if not prompt or not prompt.strip():
        return "Please enter a prompt to start.", "", ""

    # 'max_new_tokens' bounds only the generated continuation; sampling with
    # temperature/top_p gives each model some creative variety.
    params = {"max_new_tokens": 200, "do_sample": True, "temperature": 0.7, "top_p": 0.95}

    # Generate from all three models.
    out1 = generator1(prompt, **params)[0]['generated_text']
    out2 = generator2(prompt, **params)[0]['generated_text']
    out3 = generator3(prompt, **params)[0]['generated_text']

    return out1, out2, out3
75
 
76
# --- Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft(), css=".gradio-container {max-width: 95% !important;}") as demo:
    gr.Markdown(f"<h1 style='text-align: center;'>{TITLE}</h1>")
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        # Left column: prompt entry plus the trigger button.
        with gr.Column(scale=1):
            prompt_box = gr.TextArea(
                lines=5,
                label="Your Story Prompt 👇",
                placeholder="e.g., 'The last dragon on Earth lived not in a cave, but in a library...'",
            )
            weave_btn = gr.Button("Weave a Story ✨", variant="primary")

        # Right column: one tab per model's output.
        with gr.Column(scale=2):
            with gr.Tabs():
                with gr.TabItem("Mistral-7B"):
                    mistral_out = gr.TextArea(label="Mistral's Tale", interactive=False, lines=12)
                with gr.TabItem("Gemma-7B"):
                    gemma_out = gr.TextArea(label="Gemma's Chronicle", interactive=False, lines=12)
                with gr.TabItem("Llama-3-8B"):
                    llama_out = gr.TextArea(label="Llama's Legend", interactive=False, lines=12)

    gr.Examples(
        examples=examples,
        inputs=prompt_box,
        label="Example Story Starters (Click to use)",
    )

    weave_btn.click(
        fn=generate_stories,
        inputs=prompt_box,
        outputs=[mistral_out, gemma_out, llama_out],
        api_name="generate",
    )

if __name__ == "__main__":
    demo.launch()