Spaces:
Running
Running
import gradio as gr | |
import torch | |
from transformers import pipeline | |
import os | |
# --- App Configuration ---
title = "📗 Health and Mindful Story Gen ❤️"

# Markdown shown under the page title.
description = """
Enter a topic or a starting sentence related to health, mindfulness, or well-being.
The app will generate continuations from three different language models.
**Note:** These models are very large. Initial loading and first-time generation may take a few minutes.
"""

# Clickable prompt suggestions. gr.Examples takes one row per example with one
# value per input component; there is a single input, so each row below is a
# one-element list.
_EXAMPLE_PROMPTS = (
    "Mental Body Scan",
    "Stretch, Calm, Breath",
    "Relaxed Seat Breath",
    "Walk Feel",
    "Brain gamification",
    "Alleviating stress",
    "Helping breathing, satisfaction",
    "Relieve Stress, Build Support",
    "The Relaxation Response",
    "Taking Deep Breaths",
    "Delete Not Helpful Thoughts",
    "Strengthen Helpful Thoughts",
    "Reprogram Pain and Stress Reactions",
    "How to Sleep Better and Find Joy",
    "Yoga for deep sleep",
    "Being a Happier and Healthier Person",
    "Relieve chronic pain by",
    "Use Mindfulness to Affect Well Being",
    "Build and Boost Mental Strength",
    "Spending Time Outdoors",
    "Daily Routine Tasks",
    "Eating and Drinking - Find Healthy Nutrition Habits",
    "Drinking - Find Reasons and Cut Back or Quit Entirely",
    "Feel better each day when you awake by",
    "Feel better physically by",
    "Practicing mindfulness each day",
    "Be happier by",
    "Meditation can improve health by",
    "Spending time outdoors helps to",
    "Stress is relieved by quieting your mind, getting exercise and time with nature",
    "Break the cycle of stress and anxiety",
    "Feel calm in stressful situations",
    "Deal with work pressure by",
    "Learn to reduce feelings of being overwhelmed",
)
examples = [[prompt] for prompt in _EXAMPLE_PROMPTS]
# --- Model Initialization ---
# WARNING: Loading these models requires significant hardware (ideally a GPU with >24GB VRAM).
# 'device_map="auto"' requires the 'accelerate' library.
# Install dependencies: pip install gradio transformers torch accelerate

# torch_dtype="auto" keeps each checkpoint's stored precision, and the default
# revisions of these checkpoints are float32, so "auto" does not halve memory.
# Ask for float16 explicitly when a CUDA GPU is present. On CPU stay in float32,
# because half-precision kernels there are slow or (on older torch) missing.
_MODEL_DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32


def failed_generator(prompt, **kwargs):
    """Stand-in used in place of a pipeline that failed to load.

    Returns the same shape as a text-generation pipeline call (a one-element
    list holding a dict with a 'generated_text' key). Callers can therefore
    index it unchanged, and the UI shows a message instead of crashing.
    """
    return [{'generated_text': "Model failed to load. Check console for errors."}]


def _load_generator(model_id: str, label: str):
    """Load one text-generation pipeline, falling back to ``failed_generator``.

    Each model is loaded in isolation. One failure (for example running out
    of VRAM on the 6B model) therefore no longer discards the models that
    did load.

    Args:
        model_id: Hugging Face Hub model identifier passed to ``pipeline``.
        label: Human-readable model name used in console messages.

    Returns:
        The loaded pipeline, or ``failed_generator`` if loading raised.
    """
    try:
        generator = pipeline(
            "text-generation",
            model=model_id,
            torch_dtype=_MODEL_DTYPE,
            device_map="auto",  # let accelerate place weights on the available device(s)
        )
    except Exception as e:  # app boundary: report and keep the UI launchable
        print(f"Error loading {label} ({model_id}): {e}")
        print("Please ensure you have 'torch' and 'accelerate' installed and have sufficient VRAM.")
        return failed_generator
    print(f"{label} loaded.")
    return generator


print("Initializing models... This may take several minutes.")
generator1 = _load_generator("gpt2-large", "GPT-2 Large")
generator2 = _load_generator("EleutherAI/gpt-neo-2.7B", "GPT-Neo 2.7B")
generator3 = _load_generator("EleutherAI/gpt-j-6B", "GPT-J 6B")
if all(g is not failed_generator for g in (generator1, generator2, generator3)):
    print("All models loaded successfully! ✅")
# --- App Logic ---
def generate_outputs(input_text: str) -> tuple[str, str, str]:
    """Generate a continuation of *input_text* from each of the three models.

    Args:
        input_text: The prompt typed into the input TextArea.

    Returns:
        A 3-tuple of texts in the order generator1, generator2, generator3
        (GPT-2 Large, GPT-Neo 2.7B, GPT-J 6B), matching the output tabs.
        Each entry is the pipeline's 'generated_text' field, or a short
        message when the prompt is blank or that model raised. One model
        failing no longer blanks the other two outputs.
    """
    if not (input_text or "").strip():
        # A blank prompt would only yield unconditioned text unrelated to the
        # app's theme, at the cost of three large-model runs.
        notice = "Please enter a prompt to generate text."
        return notice, notice, notice

    # 'max_new_tokens' bounds only the generated part; 'max_length' would
    # also count the prompt's tokens.
    params = {"max_new_tokens": 60, "num_return_sequences": 1}

    def _generate(generator) -> str:
        """Run one model, turning a runtime failure into a displayable message."""
        try:
            return generator(input_text, **params)[0]['generated_text']
        except Exception as e:  # UI boundary: e.g. CUDA OOM on the largest model
            print(f"Generation error: {e}")
            return f"Generation failed: {e}"

    return _generate(generator1), _generate(generator2), _generate(generator3)
# --- Gradio Interface ---
# Output tab titles, in the same order as the tuple generate_outputs() returns.
_MODEL_TABS = ("GPT-2 Large", "GPT-Neo 2.7B", "GPT-J 6B")

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(f"<h1 style='text-align: center;'>{title}</h1>")
    gr.Markdown(description)

    with gr.Row():
        # Left column: prompt entry and the button that starts generation.
        with gr.Column(scale=1):
            input_area = gr.TextArea(
                lines=3,
                label="Your starting prompt 👇",
                placeholder="e.g., 'To relieve stress, I will try...'",
            )
            generate_button = gr.Button("Generate ✨", variant="primary")

        # Right column: one read-only output box per model, each in its own tab.
        with gr.Column(scale=2):
            model_outputs = []
            with gr.Tabs():
                for model_name in _MODEL_TABS:
                    with gr.TabItem(model_name):
                        model_outputs.append(
                            gr.TextArea(label=f"{model_name} Output", interactive=False, lines=7)
                        )
            gen1_output, gen2_output, gen3_output = model_outputs

    # No fn is attached, so clicking an example only fills the prompt box.
    gr.Examples(
        examples=examples,
        inputs=input_area,
        label="Example Prompts (Click to use)",
    )

    # Also exposed as the "generate" API endpoint.
    generate_button.click(
        fn=generate_outputs,
        inputs=input_area,
        outputs=[gen1_output, gen2_output, gen3_output],
        api_name="generate",
    )

if __name__ == "__main__":
    demo.launch()