Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import os | |
| from huggingface_hub import InferenceClient, __version__ as hf_version | |
| import random | |
| from typing import Generator, Dict, List, Tuple, Optional | |
| import logging # Added logging for better debugging | |
# Verbose (DEBUG-level) logging so prompts and API round-trips are visible;
# the installed huggingface_hub version is logged once at startup.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.DEBUG,
)
logging.debug("Using huggingface_hub version: %s", hf_version)

# Read the Hugging Face token from the environment (None when HF_TOKEN is unset)
# and build the shared inference client used by the story generator.
hf_token = os.getenv("HF_TOKEN")
client = InferenceClient("HuggingFaceH4/zephyr-7b-beta", token=hf_token)
# Starter prompts offered for each genre. The keys double as the choices in
# the genre dropdown, and every genre supplies three prompts (one per button).
GENRE_EXAMPLES = {
    "fairy tale": [
        "I follow the shimmer of fairy dust into a hidden forest",
        "A tiny dragon appears at my window, asking for help to find its mother",
        "A friendly witch invites me into her cozy cottage, offering a warm cup of tea",
    ],
    "fantasy": [
        "I enter the ancient forest seeking the wizard's tower",
        "I approach the dragon cautiously with my shield raised",
        "I try to bargain with the elven council for safe passage",
    ],
    "sci-fi": [
        "I investigate the strange signal coming from the abandoned planet",
        "I negotiate with the alien ambassador about the peace treaty",
        "I try to repair my damaged spacecraft before oxygen runs out",
    ],
    "mystery": [
        "I examine the crime scene for overlooked evidence",
        "I question the nervous butler about the night of the murder",
        "I follow the suspicious figure through the foggy streets",
    ],
    "horror": [
        "I read the forbidden text while the candles flicker",
        "I hide under the bed as footsteps approach",
        "I investigate the strange noises coming from the attic",
    ],
    "western": [
        "I challenge the outlaw to a duel at high noon",
        "I track the bandits through the desert canyon",
        "I enter the saloon looking for information",
    ],
    "cyberpunk": [
        "I jack into the corporate mainframe to steal data",
        "I hide in the neon-lit alleyway from corporate security",
        "I meet my mysterious client in the underground bar",
    ],
    "historical": [
        "I join the resistance against the occupying forces",
        "I navigate the dangerous politics of the royal court",
        "I set sail on a voyage to discover new lands",
    ],
    "post-apocalyptic": [
        "I scavenge the abandoned shopping mall for supplies",
        "I navigate through the radioactive zone using my old map",
        "I hide from the approaching group of raiders",
    ],
    "steampunk": [
        "I pilot my airship through the lightning storm",
        "I present my new invention to the Royal Academy",
        "I sneak aboard the emperor's armored train",
    ],
}
# Tunables for conversation memory and model sampling, collected here rather
# than repeated as literals at the call sites.
MAX_HISTORY_LENGTH = 20    # intended upper bound on past exchanges kept as context
MEMORY_WINDOW = 5          # number of most recent exchanges sent to the model
MAX_TOKENS = 1024          # max_tokens passed to each chat_completion call
TEMPERATURE = 0.7          # sampling temperature passed to chat_completion
TOP_P = 0.95               # nucleus-sampling top_p passed to chat_completion
MIN_RESPONSE_LENGTH = 100  # not referenced elsewhere in this file; unit unconfirmed
def get_examples_for_genre(genre):
    """Return the starter prompts for *genre*.

    Unknown genres (including None) fall back to the "fantasy" prompts.
    """
    try:
        return GENRE_EXAMPLES[genre]
    except KeyError:
        return GENRE_EXAMPLES["fantasy"]
def get_enhanced_system_prompt(genre=None):
    """Build the storyteller system prompt for the given genre.

    Args:
        genre: Genre name interpolated into the prompt. Falsy values
            (None, "") fall back to "fantasy".

    Returns:
        The complete system prompt string.
    """
    selected_genre = genre or "fantasy"
    # The choice template below must agree with the last CRITICAL RULE
    # (choices start with "You" followed by a verb). Conflicting formatting
    # instructions make the model's output format unpredictable.
    system_message = f"""You are an interactive storyteller creating an immersive {selected_genre} choose-your-own-adventure story.
For each response you MUST:
1. Write 100-200 words describing the scene, using vivid sensory details
2. Always use second-person perspective ("you", "your") to maintain reader immersion
3. Include dialogue or your character's thoughts that reveal personality and motivations
4. Create a strong sense of atmosphere appropriate for {selected_genre}
5. End with EXACTLY THREE numbered choices and NOTHING ELSE AFTER THEM:
1. [Complete second-person sentence starting with "You" followed by a verb]
2. [Complete second-person sentence starting with "You" followed by a verb]
3. [Complete second-person sentence starting with "You" followed by a verb]
CRITICAL RULES:
- Provide only ONE set of three choices at the very end of your response
- Never continue the story after giving choices
- Never provide additional choices
- Keep all narrative before the choices
- End every response with exactly three numbered options
- Each choice must start with "You" followed by a verb
Remember: The story continues ONLY when the player makes a choice."""
    return system_message
def create_story_summary(chat_history):
    """Build a system message asking the model to summarise the story so far.

    Args:
        chat_history: Either (player message, story response) pairs, which is
            the format the Chatbot history uses in this file, or a flat list
            that alternates player and story messages.

    Returns:
        None when the history is empty or has at most two entries. Otherwise
        a {"role": "system", "content": ...} message whose content is the
        summary request followed by the transcript to be summarised.
    """
    if not chat_history or len(chat_history) <= 2:
        return None

    if all(isinstance(entry, (list, tuple)) and len(entry) == 2 for entry in chat_history):
        pairs = [tuple(entry) for entry in chat_history]
    else:
        # Flat alternating list; a trailing unpaired message is ignored.
        pairs = list(zip(chat_history[0::2], chat_history[1::2]))

    story_text = "\n\n".join(f"User: {user}\nStory: {story}" for user, story in pairs)
    instruction = (
        "The conversation history is getting long. Please create a brief summary "
        "of the key plot points and character development so far to help maintain "
        "context without exceeding token limits."
    )
    # Include the transcript itself; without it the model has nothing to summarise.
    return {
        "role": "system",
        "content": f"{instruction}\n\nStory so far:\n\n{story_text}",
    }
def respond(
    message: str,
    chat_history: List[Tuple[str, str]],
    genre: Optional[str] = None,
    use_full_memory: bool = True,
) -> Tuple[str, List[Tuple[str, str]]]:
    """Ask the model for the next story beat and append it to the history.

    Args:
        message: The player's latest input. Blank (or None) input is ignored.
        chat_history: Prior exchanges as (player message, story response)
            pairs, as held by the Chatbot component. None is treated as empty.
        genre: Genre forwarded to get_enhanced_system_prompt.
        use_full_memory: When True, send up to MAX_HISTORY_LENGTH of the most
            recent exchanges. When False, send only the last MEMORY_WINDOW.
            This matches the "Full Story Memory" checkbox description.

    Returns:
        ("", history): an empty string that clears the input box, plus a new
        history list with the latest exchange appended. API failures are
        logged and reported as a story message instead of being raised.
    """
    history = list(chat_history or [])
    if not message or not message.strip():
        return "", history

    try:
        api_messages = [{"role": "system", "content": get_enhanced_system_prompt(genre)}]
        logging.debug("System Message: %s", api_messages[0])

        window = MAX_HISTORY_LENGTH if use_full_memory else MEMORY_WINDOW
        # Guard window <= 0: history[-0:] would select the entire history.
        recent = history[-window:] if window > 0 else []
        for user_msg, bot_msg in recent:
            api_messages.append({"role": "user", "content": str(user_msg)})
            # A missing bot turn is sent as "" rather than the text "None".
            api_messages.append(
                {"role": "assistant", "content": "" if bot_msg is None else str(bot_msg)}
            )
        if recent:
            logging.debug("Chat History Messages: %s", api_messages[1:])

        api_messages.append({"role": "user", "content": str(message)})
        logging.debug("Final Message List: %s", api_messages)

        logging.debug("Making API call...")
        response = client.chat_completion(
            messages=api_messages,
            max_tokens=MAX_TOKENS,
            temperature=TEMPERATURE,
            top_p=TOP_P,
        )
        logging.debug("API call completed")

        # Guard against a None content so slicing/appending below cannot fail.
        bot_message = response.choices[0].message.content or ""
        logging.debug("Bot Response: %s...", bot_message[:100])
    except Exception as e:
        logging.error("Error in respond function", exc_info=True)
        error_msg = f"Story magic temporarily interrupted. Please try again. (Error: {str(e)})"
        return "", history + [(message, error_msg)]

    return "", history + [(message, bot_message)]
def save_story(chat_history):
    """Render the chat history as Markdown and write it to a downloadable file.

    Args:
        chat_history: Sequence of (player message, story response) pairs as
            held by the Chatbot component.

    Returns:
        Path to a freshly written "my_story.md" file. Returns None when there
        is nothing to save or when writing fails; None leaves the File output
        empty. The gr.File.update API used previously was removed in Gradio 4.
    """
    if not chat_history:
        return None

    import tempfile

    try:
        parts = ["# My Adventure\n\n"]
        for user_msg, bot_msg in chat_history:
            parts.append(f"**Player:** {user_msg}\n\n")
            parts.append(f"**Story:** {bot_msg}\n\n---\n\n")
        story_text = "".join(parts)

        # A private directory per call keeps concurrent sessions from
        # overwriting each other's file, while the download keeps the
        # friendly "my_story.md" name.
        out_dir = tempfile.mkdtemp(prefix="story_")
        temp_file = os.path.join(out_dir, "my_story.md")
        with open(temp_file, "w", encoding="utf-8") as f:
            f.write(story_text)
        return temp_file
    except Exception as e:
        # Best-effort: a failed export must not break the UI.
        logging.error("Error saving story: %s", e, exc_info=True)
        return None
def get_storyteller_avatar_url():
    """Return the DiceBear image URL used as the storyteller's chat avatar."""
    # Assemble the URL from its endpoint and query parts for readability.
    endpoint = "https://api.dicebear.com/7.x/bottts/svg"
    query = "seed=wizard&backgroundColor=b6e3f4&eyes=bulging"
    return f"{endpoint}?{query}"
# Stylesheet passed to gr.Blocks(css=...). The selectors target the
# elem_classes used by the download File widget ("compact-file-output") and
# the story-starter buttons ("compact-btn"). They reduce padding and minimum
# heights so those widgets stay compact.
custom_css = """
.compact-file-output > div {
min-height: 0 !important;
padding: 0 !important;
}
.compact-file-output .file-preview {
margin: 0 !important;
display: flex;
align-items: center;
}
.compact-btn {
padding: 0.5rem !important;
min-height: 0 !important;
height: auto !important;
line-height: 1.2 !important;
}
"""
with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
    # Header with a short explanation of how to play.
    gr.Markdown("""
# 🔮 AI Story Studio
**Collaborate with AI to craft your own adventure, one scene at a time.**
Pick a genre, start with a prompt or write your own, and guide the story with your choices.
> **Tip:** The more detail you provide, the deeper the story becomes.
""")
    wizard_avatar = get_storyteller_avatar_url()

    def update_starter_buttons(selected_genre):
        """Return exactly three starter prompts for *selected_genre*, padded with "" if fewer exist."""
        examples = list(get_examples_for_genre(selected_genre))[:3]
        return tuple(examples + [""] * (3 - len(examples)))

    # The default genre's starters are known at build time, so the buttons
    # are created with those labels directly. No page-load event is needed,
    # and no "Starter N" placeholder flashes before the labels arrive.
    initial_button_data = update_starter_buttons("fantasy")

    with gr.Row():
        with gr.Column(scale=3):
            # Chat window + user input; history is a list of (user, bot) pairs.
            chatbot = gr.Chatbot(
                height=400,
                bubble_full_width=True,
                show_copy_button=True,
                avatar_images=(None, wizard_avatar),
                container=True,
                scale=1,
                min_width=800,
                value=[],  # start with an empty conversation
                render=True,
            )
            msg = gr.Textbox(
                placeholder="Describe your next move...",
                container=False,
                scale=4,
            )
            with gr.Row():
                submit = gr.Button("Continue Story", variant="primary")
                clear = gr.Button("Start New Adventure")
        with gr.Column(scale=1):
            gr.Markdown("## Adventure Settings")
            genre = gr.Dropdown(
                choices=list(GENRE_EXAMPLES.keys()),
                label="Story Genre",
                info="Choose the theme of your next adventure",
                value="fantasy",
            )
            full_memory = gr.Checkbox(
                label="Full Story Memory",
                value=True,
                info="When enabled, the AI tries to remember the entire story. If disabled, only the last few exchanges are used.",
            )
            gr.Markdown("## Story Starters")
            starter_btn1, starter_btn2, starter_btn3 = (
                gr.Button(label, scale=1, min_width=250, elem_classes="compact-btn")
                for label in initial_button_data
            )
            starter_buttons = [starter_btn1, starter_btn2, starter_btn3]

    def use_starter(starter_text: str, history: List[Tuple[str, str]], selected_genre: str, memory_flag: bool) -> Tuple[str, List[Tuple[str, str]]]:
        """Send a clicked starter prompt (the button's label) through respond().

        Returns ("", updated_history) to clear the textbox and refresh the chat.
        An empty label leaves the history unchanged.
        """
        if not starter_text:
            return "", history
        try:
            _, updated_history = respond(
                message=starter_text,
                chat_history=history,
                genre=selected_genre,
                use_full_memory=memory_flag,
            )
        except Exception as e:
            # respond() catches errors from its own try block; this is a
            # last-resort guard, logged so failures are not silent.
            logging.error("Error in use_starter", exc_info=True)
            error_msg = f"Story magic temporarily interrupted. Please try again. (Error: {str(e)})"
            return "", list(history or []) + [(starter_text, error_msg)]
        return "", updated_history

    # Each starter button passes its own label as the first input.
    for starter_button in starter_buttons:
        starter_button.click(
            fn=use_starter,
            inputs=[starter_button, chatbot, genre, full_memory],
            outputs=[msg, chatbot],
            queue=True,
        )

    # Relabel the starter buttons whenever the genre changes.
    genre.change(
        fn=update_starter_buttons,
        inputs=[genre],
        outputs=starter_buttons,
    )

    # Pressing Enter and clicking "Continue Story" do the same thing.
    chat_inputs = [msg, chatbot, genre, full_memory]
    msg.submit(fn=respond, inputs=chat_inputs, outputs=[msg, chatbot])
    submit.click(fn=respond, inputs=chat_inputs, outputs=[msg, chatbot])

    # A single handler resets both the conversation and the input box.
    clear.click(fn=lambda: ([], ""), inputs=None, outputs=[chatbot, msg], queue=False)

    # "Download My Story" row: wide button plus a compact file slot.
    with gr.Row(equal_height=True):
        with gr.Column(scale=4):
            save_btn = gr.Button("Download My Story", variant="secondary", size="lg")
        with gr.Column(scale=1):
            story_output = gr.File(
                label=None,  # no label, to avoid extra height
                file_count="single",
                file_types=[".md"],
                interactive=False,
                visible=True,
                elem_classes="compact-file-output",
            )

    save_btn.click(
        fn=save_story,
        inputs=[chatbot],
        outputs=story_output,
        queue=False,  # process immediately
    )
# When executed as a script, serve the app on all network interfaces at port 7860.
if __name__ == "__main__":
    demo.launch(server_port=7860, server_name="0.0.0.0")