Spaces:
Running
Running
import gradio as gr | |
import os | |
from huggingface_hub import InferenceClient, __version__ as hf_version | |
import random | |
from typing import Generator, Dict, List, Tuple, Optional | |
import logging # Added logging for better debugging | |
# Emit everything down to DEBUG and record which huggingface_hub release is
# installed, so the Space logs show the client library version in use.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.DEBUG,
)
logging.debug("Using huggingface_hub version: %s", hf_version)

# Hugging Face access token from the environment (None when not set).
hf_token = os.getenv("HF_TOKEN")
client = InferenceClient("HuggingFaceH4/zephyr-7b-beta", token=hf_token)
# Genre name -> list of starter prompts for that genre.
# Key order is the order the genre dropdown lists them in, and the "fantasy"
# entry is the fallback used for unknown genres.
GENRE_EXAMPLES = {
    "fairy tale": [
        "I follow the shimmer of fairy dust into a hidden forest",
        "I meet a talking rabbit who claims to know a secret about the king's lost crown",
        "A tiny dragon appears at my window, asking for help to find its mother",
        "I step into a clearing where the trees whisper ancient riddles",
        "A friendly witch invites me into her cozy cottage, offering a warm cup of tea",
    ],
    "fantasy": [
        "I enter the ancient forest seeking the wizard's tower",
        "I approach the dragon cautiously with my shield raised",
        "I examine the mysterious runes carved into the stone altar",
        "I try to bargain with the elven council for safe passage",
    ],
    "sci-fi": [
        "I hack into the space station's mainframe",
        "I investigate the strange signal coming from the abandoned planet",
        "I negotiate with the alien ambassador about the peace treaty",
        "I try to repair my damaged spacecraft before oxygen runs out",
    ],
    "mystery": [
        "I examine the crime scene for overlooked evidence",
        "I question the nervous butler about the night of the murder",
        "I follow the suspicious figure through the foggy streets",
        "I check the victim's diary for hidden clues",
    ],
    "horror": [
        "I slowly open the creaking door to the basement",
        "I read the forbidden text while the candles flicker",
        "I hide under the bed as footsteps approach",
        "I investigate the strange noises coming from the attic",
    ],
    "western": [
        "I challenge the outlaw to a duel at high noon",
        "I track the bandits through the desert canyon",
        "I enter the saloon looking for information",
        "I defend the stagecoach from the approaching raiders",
    ],
    "cyberpunk": [
        "I jack into the corporate mainframe to steal data",
        "I negotiate with the street gang for cybernetic upgrades",
        "I hide in the neon-lit alleyway from corporate security",
        "I meet my mysterious client in the underground bar",
    ],
    "historical": [
        "I attend the royal ball hoping to meet the mysterious count",
        "I join the resistance against the occupying forces",
        "I navigate the dangerous politics of the royal court",
        "I set sail on a voyage to discover new lands",
    ],
    "post-apocalyptic": [
        "I scavenge the abandoned shopping mall for supplies",
        "I approach the fortified settlement seeking shelter",
        "I navigate through the radioactive zone using my old map",
        "I hide from the approaching group of raiders",
    ],
    "steampunk": [
        "I pilot my airship through the lightning storm",
        "I present my new invention to the Royal Academy",
        "I investigate the mysterious clockwork automaton",
        "I sneak aboard the emperor's armored train",
    ],
}
# Generation limits and sampling parameters, kept in one place for tuning.
MAX_HISTORY_LENGTH = 20    # NOTE(review): currently unused in this module.
MEMORY_WINDOW = 5          # most recent (player, story) exchanges sent as context
MAX_TOKENS = 1024          # max_tokens passed to chat_completion per reply
TEMPERATURE = 0.7          # sampling temperature passed to chat_completion
TOP_P = 0.95               # nucleus-sampling top_p passed to chat_completion
MIN_RESPONSE_LENGTH = 100  # NOTE(review): currently unused in this module.
def get_examples_for_genre(genre):
    """Return the starter prompts registered for *genre*.

    Any genre that is not a key of GENRE_EXAMPLES (including None) falls back
    to the "fantasy" prompts.
    """
    if genre in GENRE_EXAMPLES:
        return GENRE_EXAMPLES[genre]
    return GENRE_EXAMPLES["fantasy"]
def get_enhanced_system_prompt(genre=None):
    """Build the storyteller system prompt for *genre*.

    A falsy genre (None, "") is treated as "fantasy". The prompt asks for
    second-person narration and requires every reply to end with exactly
    three numbered choices that start with "You".
    """
    story_genre = genre if genre else "fantasy"
    prompt_lines = [
        f"You are an interactive storyteller creating an immersive {story_genre} choose-your-own-adventure story.",
        "For each response you MUST:",
        "1. Write 100-200 words describing the scene, using vivid sensory details",
        '2. Always use second-person perspective ("you", "your") to maintain reader immersion',
        "3. Include dialogue or your character's thoughts that reveal personality and motivations",
        f"4. Create a strong sense of atmosphere appropriate for {story_genre}",
        "5. End EVERY response with exactly three numbered choices like this:",
        '1. [Complete sentence in second-person starting with a verb (e.g., "You decide to..."/"You attempt to...")]',
        '2. [Complete sentence in second-person starting with a verb (e.g., "You sneak towards..."/"You call out to...")]',
        '3. [Complete sentence in second-person starting with a verb (e.g., "You examine..."/"You reach for...")]',
        "IMPORTANT:",
        "- Always maintain second-person perspective throughout the narrative",
        "- Always end with exactly three numbered choices",
        "- Never skip the choices or respond with just narrative",
        '- Each choice must start with "You" followed by a verb',
        "- Format choices exactly as shown above with numbers 1-3",
        "Keep the story cohesive by referencing previous events and choices.",
    ]
    return "\n".join(prompt_lines)
def create_story_summary(chat_history):
    """Build a system message asking the model to summarize a long story.

    Args:
        chat_history: Sequence of (player message, story reply) pairs.

    Returns:
        None when there are two or fewer exchanges (or no history at all);
        otherwise a ``{"role": "system", "content": ...}`` message whose
        content is the summary request followed by the story transcript.
    """
    if not chat_history or len(chat_history) <= 2:
        return None
    # Carry the transcript in the request itself so the model actually has
    # the story text it is being asked to summarize.
    transcript = "".join(
        f"User: {user_msg}\nStory: {bot_msg}\n\n" for user_msg, bot_msg in chat_history
    )
    instruction = (
        "The conversation history is getting long. Please create a brief summary "
        "of the key plot points and character development so far to help maintain "
        "context without exceeding token limits."
    )
    return {"role": "system", "content": f"{instruction}\n\n{transcript.rstrip()}"}
def format_history_for_gradio(history_tuples):
    """Normalize chat history into Gradio's tuple-style Chatbot format.

    Each (user, bot) pair becomes a tuple whose entries are converted to
    ``str``. This is the "tuples" format, not the openai-style "messages"
    (role/content dict) format.

    ``None`` entries are kept as ``None`` instead of becoming the literal
    text "None"; Gradio's tuple format uses ``None`` for "no message to show".
    """
    def _as_text(value):
        # Preserve None; stringify everything else.
        return None if value is None else str(value)

    return [(_as_text(user_msg), _as_text(bot_msg)) for user_msg, bot_msg in history_tuples]
def respond(message: str, chat_history: List[Tuple[str, str]], genre: Optional[str] = None, use_full_memory: bool = True) -> Tuple[str, List[Tuple[str, str]]]:
    """Generate the next story passage for *message* and append it to the history.

    Args:
        message: The player's action. Empty, whitespace-only or None input is
            ignored and the history is returned unchanged.
        chat_history: Prior (player, story) pairs; None is treated as empty.
        genre: Story genre, forwarded to get_enhanced_system_prompt().
        use_full_memory: When True, the last MEMORY_WINDOW exchanges are sent
            to the model as context; when False, no earlier exchanges are sent.

    Returns:
        ("", history). The empty string clears the input textbox. On success,
        history has (message, reply) appended; if anything fails, it has
        (message, error notice) appended instead.
    """
    history = list(chat_history or [])
    if not message or not message.strip():
        return "", history

    try:
        api_messages = [{"role": "system", "content": get_enhanced_system_prompt(genre)}]
        logging.debug("System Message: %s", api_messages[0])

        if history and use_full_memory:
            # Only a bounded window of recent exchanges is sent, to keep the
            # prompt short.
            for user_msg, bot_msg in history[-MEMORY_WINDOW:]:
                api_messages.append({"role": "user", "content": str(user_msg)})
                api_messages.append({"role": "assistant", "content": str(bot_msg)})
            logging.debug("Chat History Messages: %s", api_messages[1:])

        api_messages.append({"role": "user", "content": str(message)})
        logging.debug("Final Message List: %s", api_messages)

        logging.debug("Making API call...")
        response = client.chat_completion(
            messages=api_messages,
            max_tokens=MAX_TOKENS,
            temperature=TEMPERATURE,
            top_p=TOP_P,
        )
        logging.debug("API call completed")

        # The returned message content may be None; store "" so the slice
        # below and the chat display do not break.
        bot_message = response.choices[0].message.content or ""
        logging.debug("Bot Response: %s...", bot_message[:100])  # first 100 chars
        return "", history + [(message, bot_message)]
    except Exception as e:
        # UI boundary: log the traceback and surface a friendly notice in the
        # chat instead of crashing the event handler.
        logging.exception("Error in respond function")
        error_msg = f"Story magic temporarily interrupted. Please try again. (Error: {str(e)})"
        return "", history + [(message, error_msg)]
def save_story(chat_history):
    """Render the chat history as a Markdown document.

    With no history, returns a short placeholder sentence. Otherwise returns a
    "# My Interactive Adventure" heading followed by one Player/Story section
    per exchange, each closed by a horizontal rule.
    """
    if chat_history:
        sections = [
            f"**Player:** {player}\n\n**Story:** {story}\n\n---\n\n"
            for player, story in chat_history
        ]
        return "# My Interactive Adventure\n\n" + "".join(sections)
    return "No story to save yet!"
# Gradio UI. respond(), save_story() and the starter handler all read and write
# history as (player, story) pairs, so the Chatbot uses Gradio's "tuples"
# format. The "messages" format expects role/content dicts and rejects the
# tuple lists these handlers return.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🔮 Interactive Story Time")
    with gr.Row():
        status_message = gr.Markdown("Ready to begin your adventure...", visible=True)
    gr.Markdown("Create a completely unique literary world, one choice at a time. Dare to explore the unknown.")

    def update_starter_buttons(selected_genre):
        """Return exactly four starter prompts for *selected_genre*.

        Genres with more than four examples are truncated. Genres with fewer
        are padded with "" so every starter button receives a value.
        """
        examples = list(get_examples_for_genre(selected_genre)[:4])
        return tuple(examples + [""] * (4 - len(examples)))

    # Default-genre labels, used as the buttons' initial text so they never
    # display (or submit) placeholder labels before the page finishes loading.
    initial_examples = get_examples_for_genre("fantasy")
    initial_button_data = update_starter_buttons("fantasy")

    with gr.Row():
        with gr.Column(scale=3):
            # Chat window + user input
            chatbot = gr.Chatbot(
                height=500,
                bubble_full_width=True,
                show_copy_button=True,
                # avatar_images expects image file paths/URLs (an emoji string
                # is neither), so no avatars are configured.
                type="tuples",
                container=True,
                scale=1,
                min_width=800,
                value=[],
                render=True,
            )
            msg = gr.Textbox(
                placeholder="Describe what you want to do next in the story...",
                container=False,
                scale=4,
            )
            with gr.Row():
                submit = gr.Button("Continue Story", variant="primary")
                clear = gr.Button("Start New Adventure")
        with gr.Column(scale=1):
            gr.Markdown("## Adventure Settings")
            genre = gr.Dropdown(
                choices=list(GENRE_EXAMPLES.keys()),
                label="Story Genre",
                info="Choose the theme of your next adventure",
                value="fantasy",
            )
            full_memory = gr.Checkbox(
                label="Full Story Memory",
                value=True,
                # Describes respond(): when enabled, a window of recent
                # exchanges is sent as context; when disabled, none are.
                info="When enabled, the AI remembers the most recent exchanges of your story. If disabled, each turn is written without earlier story context.",
            )
            gr.Markdown("## Story Starters")
            starter_buttons = [gr.Button(label) for label in initial_button_data]
            starter_btn1, starter_btn2, starter_btn3, starter_btn4 = starter_buttons

    def use_starter(starter_text, history, selected_genre, memory_flag):
        """Send a clicked starter prompt through respond().

        Returns ("", history) so the textbox is cleared and the chat updated.
        Blank starters (padding for short genre lists) are ignored.
        """
        if not starter_text:
            return "", history
        try:
            _, updated_history = respond(
                message=starter_text,
                chat_history=history,
                genre=selected_genre,
                use_full_memory=memory_flag,
            )
            return "", updated_history
        except Exception as e:
            # respond() already catches its own errors; this only guards
            # against anything unexpected escaping it.
            logging.exception("Error in use_starter")
            error_msg = f"Story magic temporarily interrupted. Please try again. (Error: {str(e)})"
            return "", (history or []) + [(starter_text, error_msg)]

    # Each starter button submits its own label as the player's message.
    for starter_button in starter_buttons:
        starter_button.click(
            fn=use_starter,
            inputs=[starter_button, chatbot, genre, full_memory],
            outputs=[msg, chatbot],
            queue=True,
        )

    # Refresh the starter labels whenever the genre changes.
    genre.change(fn=update_starter_buttons, inputs=[genre], outputs=starter_buttons)

    # Typed actions: Enter in the textbox or the "Continue Story" button.
    msg.submit(fn=respond, inputs=[msg, chatbot, genre, full_memory], outputs=[msg, chatbot])
    submit.click(fn=respond, inputs=[msg, chatbot, genre, full_memory], outputs=[msg, chatbot])

    # Start a new adventure: clear the story and the input box in one event.
    clear.click(lambda: ([], ""), None, [chatbot, msg], queue=False)

    # "Download My Story": render the story as Markdown and reveal it.
    with gr.Row():
        save_btn = gr.Button("Download My Story", variant="secondary")
    # Hidden until first use; elem_id gives the scroll script a DOM target.
    story_output = gr.Markdown(visible=False, elem_id="story_output")

    def show_story(history):
        """Render *history* with save_story() and make the output visible."""
        return gr.update(value=save_story(history), visible=True)

    # Scroll only after the content has been written and made visible.
    save_btn.click(fn=show_story, inputs=[chatbot], outputs=[story_output]).then(
        fn=None,
        js="() => { document.getElementById('story_output')?.scrollIntoView(); }",
        queue=False,
    )
# Script entry point: serve the UI on all network interfaces, port 7860.
if __name__ == "__main__":
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860}
    demo.launch(**launch_options)