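# Header text captured from the hosting page (kept as comments so the module parses):
# author: awacke1 · commit: "Update app.py" (5fa7137, verified) · size: 5.17 kB
# page links: raw · history · blame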
import gradio as gr
import torch
from transformers import pipeline
import os
# --- App Configuration ---
# Page heading shown at the top of the app.
title = "📗 Health and Mindful Story Gen ❤️"

# Markdown blurb rendered under the heading. The leading and trailing
# newlines are intentional and match the app's original triple-quoted text.
description = (
    "\n"
    "Enter a topic or a starting sentence related to health, mindfulness, or well-being.\n"
    "The app will generate continuations from three different language models.\n"
    "**Note:** These models are very large. Initial loading and first-time generation may take a few minutes.\n"
)
# Example prompts offered in the UI. Each example is a row holding one value
# per input component; the app has a single text input, so every row is a
# one-element list.
examples = [
    [prompt]
    for prompt in (
        "Mental Body Scan",
        "Stretch, Calm, Breath",
        "Relaxed Seat Breath",
        "Walk Feel",
        "Brain gamification",
        "Alleviating stress",
        "Helping breathing, satisfaction",
        "Relieve Stress, Build Support",
        "The Relaxation Response",
        "Taking Deep Breaths",
        "Delete Not Helpful Thoughts",
        "Strengthen Helpful Thoughts",
        "Reprogram Pain and Stress Reactions",
        "How to Sleep Better and Find Joy",
        "Yoga for deep sleep",
        "Being a Happier and Healthier Person",
        "Relieve chronic pain by",
        "Use Mindfulness to Affect Well Being",
        "Build and Boost Mental Strength",
        "Spending Time Outdoors",
        "Daily Routine Tasks",
        "Eating and Drinking - Find Healthy Nutrition Habits",
        "Drinking - Find Reasons and Cut Back or Quit Entirely",
        "Feel better each day when you awake by",
        "Feel better physically by",
        "Practicing mindfulness each day",
        "Be happier by",
        "Meditation can improve health by",
        "Spending time outdoors helps to",
        "Stress is relieved by quieting your mind, getting exercise and time with nature",
        "Break the cycle of stress and anxiety",
        "Feel calm in stressful situations",
        "Deal with work pressure by",
        "Learn to reduce feelings of being overwhelmed",
    )
]
# --- Model Initialization ---
# WARNING: Loading these models requires significant hardware (ideally a GPU with >24GB VRAM).
# device_map="auto" requires the 'accelerate' library.
# Install dependencies: pip install gradio transformers torch accelerate


def failed_generator(prompt, **kwargs):
    """Stand in for a model that could not be loaded.

    Returns the same shape the app reads from a text-generation pipeline
    (``result[0]['generated_text']``), so the UI shows an error message in
    that model's output instead of crashing.
    """
    return [{'generated_text': "Model failed to load. Check console for errors."}]


def _load_generator(model_name, display_name, **pipeline_kwargs):
    """Load one text-generation pipeline, or return ``failed_generator`` on error.

    Args:
        model_name: Hugging Face Hub model id passed to ``pipeline``.
        display_name: Human-readable name used in console messages.
        **pipeline_kwargs: Extra keyword arguments forwarded to ``pipeline``.

    Returns:
        The loaded pipeline, or ``failed_generator`` if loading raised.

    Each model is loaded independently. One failure (for example, running
    out of VRAM on the largest model) therefore no longer discards models
    that loaded successfully.
    """
    try:
        # device_map="auto" lets accelerate place the weights on the available device(s).
        generator = pipeline("text-generation", model=model_name, device_map="auto", **pipeline_kwargs)
    except Exception as e:  # app-level boundary: report it and keep the app launchable
        print(f"Error loading {display_name} ({model_name}): {e}")
        print("Please ensure you have 'torch' and 'accelerate' installed and have sufficient VRAM.")
        return failed_generator
    print(f"{display_name} loaded.")
    return generator


print("Initializing models... This may take several minutes.")
generator1 = _load_generator("gpt2-large", "GPT-2 Large")
# torch_dtype="auto" loads weights in the dtype recorded for the checkpoint
# instead of always up-casting to float32. This saves memory when the
# checkpoint is stored in half precision.
generator2 = _load_generator("EleutherAI/gpt-neo-2.7B", "GPT-Neo 2.7B", torch_dtype="auto")
generator3 = _load_generator("EleutherAI/gpt-j-6B", "GPT-J 6B", torch_dtype="auto")
if all(g is not failed_generator for g in (generator1, generator2, generator3)):
    print("All models loaded successfully! ✅")
# --- App Logic ---
def generate_outputs(input_text: str) -> tuple[str, str, str]:
    """Generate a continuation of the prompt from each of the three models.

    Args:
        input_text: The prompt from the input box. Blank (empty or
            whitespace-only) or ``None`` input is not sent to the models.

    Returns:
        A 3-tuple of the ``'generated_text'`` values from ``generator1``,
        ``generator2`` and ``generator3``, in that order.
        - If the prompt is blank, every entry is a short request for input.
        - If one model raises while generating, its entry is an error
          message and the other models still run.
    """
    # A blank prompt gives the models nothing to condition on (and may make
    # generation error out), so ask the user for input instead.
    if not input_text or not input_text.strip():
        notice = "Please enter a prompt to generate from."
        return notice, notice, notice

    # 'max_new_tokens' bounds only the generated continuation; 'max_length'
    # would also count the prompt's tokens.
    params = {"max_new_tokens": 60, "num_return_sequences": 1}

    def _run(generator) -> str:
        # Isolate failures (e.g. running out of GPU memory on the largest
        # model) so one broken model does not blank the other two outputs.
        try:
            return generator(input_text, **params)[0]['generated_text']
        except Exception as e:
            print(f"Generation failed: {e}")
            return f"Generation failed: {e}"

    return _run(generator1), _run(generator2), _run(generator3)
# --- Gradio Interface ---
# Page layout:
# - a centred title and the description;
# - a row with the prompt input and Generate button on the left (scale=1)
#   and one read-only output tab per model on the right (scale=2);
# - clickable example prompts, then the button's event wiring.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(f"<h1 style='text-align: center;'>{title}</h1>")
    gr.Markdown(description)
    with gr.Row():
        # Left column: prompt entry and the button that triggers generation.
        with gr.Column(scale=1):
            input_area = gr.TextArea(
                lines=3,
                label="Your starting prompt 👇",
                placeholder="e.g., 'To relieve stress, I will try...'"
            )
            generate_button = gr.Button("Generate ✨", variant="primary")
        # Right column: one non-editable output per model. The tab order
        # matches the outputs list passed to generate_button.click below.
        with gr.Column(scale=2):
            with gr.Tabs():
                with gr.TabItem("GPT-2 Large"):
                    gen1_output = gr.TextArea(label="GPT-2 Large Output", interactive=False, lines=7)
                with gr.TabItem("GPT-Neo 2.7B"):
                    gen2_output = gr.TextArea(label="GPT-Neo 2.7B Output", interactive=False, lines=7)
                with gr.TabItem("GPT-J 6B"):
                    gen3_output = gr.TextArea(label="GPT-J 6B Output", interactive=False, lines=7)
    # Example prompts tied to the prompt box; each row in `examples` supplies
    # the value for input_area.
    gr.Examples(
        examples=examples,
        inputs=input_area,
        label="Example Prompts (Click to use)"
    )
    # One prompt in, three texts back. generate_outputs returns a 3-tuple
    # that is mapped positionally onto these outputs. api_name gives the
    # event an explicit endpoint name ("generate").
    generate_button.click(
        fn=generate_outputs,
        inputs=input_area,
        outputs=[gen1_output, gen2_output, gen3_output],
        api_name="generate"
    )
def _main() -> None:
    """Start the Gradio server for the `demo` app (script entry point)."""
    demo.launch()


# Only launch when run as a script, not when imported.
if __name__ == "__main__":
    _main()