Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import os | |
| import torch | |
| from dotenv import load_dotenv | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| import random | |
| from typing import Generator, Dict, List, Tuple, Optional | |
| import logging | |
# Load environment variables from a .env file (if present) so the os.getenv calls below see them.
load_dotenv()

# Root log level: defaults to DEBUG (the original hard-coded value) but can be overridden with
# e.g. LOG_LEVEL=INFO. DEBUG on the root logger also surfaces debug records from third-party
# libraries whose loggers do not set their own level, which is very noisy in production.
_LOG_LEVEL_NAME = os.getenv("LOG_LEVEL", "DEBUG").strip().upper()
_LOG_LEVEL = getattr(logging, _LOG_LEVEL_NAME, None)
if not isinstance(_LOG_LEVEL, int):
    # Unknown/misspelled level name: keep the historical DEBUG default rather than crash at import.
    _LOG_LEVEL = logging.DEBUG
logging.basicConfig(
    level=_LOG_LEVEL,
    format='%(asctime)s - %(levelname)s - %(message)s'
)

# Optional Hugging Face token; an empty string means "not configured".
hf_token = os.getenv("QWEN_BOT_TOKEN", "")
logging.debug("Token loaded: %s", "Yes" if hf_token else "No")
logging.debug("Token length: %d", len(hf_token) if hf_token else 0)
# Model configuration - Qwen3-0.6B only (same as working Next-Token-Predictor)
MODEL_ID = "Qwen/Qwen3-0.6B"

# Initialize model and tokenizer once at import time (identical to working Next-Token-Predictor approach).
# The optional QWEN_BOT_TOKEN value is forwarded so authenticated downloads work; an empty token is
# passed as None so the Hugging Face libraries fall back to their default credential lookup.
logging.info("Loading Qwen3-0.6B model and tokenizer...")
_hf_auth_token = hf_token or None
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=_hf_auth_token)
    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, token=_hf_auth_token)
    logging.info("✅ Qwen3-0.6B model loaded successfully")
except Exception as e:
    # The app cannot work without the model: log the reason, then fail the import loudly.
    logging.error(f"Failed to load model: {e}")
    raise
# Difficulty levels in the order each topic lists them (the same order as the
# "Select Difficulty" dropdown choices).
_DIFFICULTY_LEVELS = ("beginner", "intermediate", "advanced")


def _by_level(beginner, intermediate, advanced):
    """Build one topic's ``{difficulty: [questions]}`` mapping from three question tuples."""
    return {
        level: list(questions)
        for level, questions in zip(_DIFFICULTY_LEVELS, (beginner, intermediate, advanced))
    }


# Quick-starter questions, keyed by topic and then by difficulty level.
# Topic insertion order is the order shown in the "Select Topic" dropdown.
TOPIC_EXAMPLES = {
    "Daily Life": _by_level(
        beginner=(
            "What time do you wake up?",
            "Do you go to school or work?",
            "What do you eat for breakfast?",
        ),
        intermediate=(
            "What do you usually do after work?",
            "Do you like cooking? What’s your favorite dish?",
            "Tell me about your morning routine.",
        ),
        advanced=(
            "How do you balance personal and professional responsibilities?",
            "What does a productive day look like for you?",
            "How has your daily routine changed over the years?",
        ),
    ),
    "Travel": _by_level(
        beginner=(
            "Do you like to travel?",
            "Have you been to another city?",
            "Do you like airplanes?",
        ),
        intermediate=(
            "Have you ever been to another country?",
            "What's your dream vacation?",
            "What do you like to pack in your suitcase?",
        ),
        advanced=(
            "How has travel influenced your worldview?",
            "What do you think makes a destination culturally significant?",
            "Describe your most challenging travel experience.",
        ),
    ),
    "Food": _by_level(
        beginner=(
            "Do you like apples?",
            "What is your favorite snack?",
            "Do you eat rice or noodles?",
        ),
        intermediate=(
            "Can you describe how to cook your favorite dish?",
            "What is your go-to comfort food and why?",
            "Have you ever tried food from another country?",
        ),
        advanced=(
            "How does food reflect a culture’s values and traditions?",
            "Compare the cuisines of two different countries.",
            "What’s the most unique food you’ve ever tasted?",
        ),
    ),
    "Work & School": _by_level(
        beginner=(
            "Do you go to school or work?",
            "What is your teacher’s name?",
            "Do you have homework?",
        ),
        intermediate=(
            "What do you do at your job or school?",
            "What’s your favorite subject or task?",
            "Do you like working with other people?",
        ),
        advanced=(
            "How do you stay motivated in your work or studies?",
            "What are the challenges of remote learning or working?",
            "How do education systems differ around the world?",
        ),
    ),
    "Hobbies": _by_level(
        beginner=(
            "Do you like music?",
            "What games do you play?",
            "Can you draw or paint?",
        ),
        intermediate=(
            "What hobbies do you enjoy in your free time?",
            "When did you start your favorite hobby?",
            "Do you prefer indoor or outdoor hobbies?",
        ),
        advanced=(
            "How can hobbies contribute to personal growth?",
            "What’s a hobby you would like to master and why?",
            "How has technology changed the way we pursue hobbies?",
        ),
    ),
    "Shopping": _by_level(
        beginner=(
            "Do you like shopping?",
            "What do you buy at the store?",
            "Do you go shopping alone?",
        ),
        intermediate=(
            "Do you prefer shopping online or in stores?",
            "Tell me about your last shopping trip.",
            "What kinds of things do you usually buy?",
        ),
        advanced=(
            "How has consumer behavior changed over time?",
            "What are the pros and cons of online shopping?",
            "Do advertisements affect your shopping decisions?",
        ),
    ),
    "Weather": _by_level(
        beginner=(
            "Is it sunny today?",
            "Do you like rain?",
            "What is your favorite season?",
        ),
        intermediate=(
            "What’s the weather like where you are?",
            "Do you like hot or cold weather?",
            "How do you prepare for a rainy day?",
        ),
        advanced=(
            "How does climate affect daily life in your region?",
            "What are the consequences of global climate change?",
            "How does weather influence culture and traditions?",
        ),
    ),
}
# --- Conversation memory ---------------------------------------------------
# How many of the most recent (user, bot) turns respond() replays into the prompt.
MEMORY_WINDOW = 5
# NOTE(review): not referenced by the visible code; chat history is never trimmed to this length.
MAX_HISTORY_LENGTH = 20

# --- Sampling defaults -----------------------------------------------------
# NOTE(review): respond() uses its own generation settings (max_new_tokens=30,
# temperature=0.6, top_p=0.85), so these three values are currently unused.
TOP_P = 0.95
TEMPERATURE = 0.7
MAX_TOKENS = 1024
def get_examples_for_topic(topic: Optional[str], difficulty: Optional[str]) -> List[str]:
    """
    Return the starter questions for ``topic`` at ``difficulty``.

    Args:
        topic: A key of ``TOPIC_EXAMPLES`` (e.g. "Daily Life").
        difficulty: A level key such as "beginner", "intermediate" or "advanced".

    Returns:
        A new list of questions, or ``[]`` when the topic or difficulty is unknown.
        A copy is returned so callers cannot accidentally mutate the module-level
        ``TOPIC_EXAMPLES`` table shared by every user of the app.
    """
    return list(TOPIC_EXAMPLES.get(topic, {}).get(difficulty, []))
def is_content_safe(text: str) -> bool:
    """
    Basic content safety filter for child protection.

    Each keyword/phrase is matched case-insensitively and only where it begins
    at a word boundary. So "kill" still catches "kill", "kills" and "killed",
    but harmless words that merely contain a keyword (e.g. "skill", "skills")
    are no longer blocked.

    Args:
        text: User input or generated bot reply to screen.

    Returns:
        False if content should be blocked, True otherwise.
    """
    import re  # local import: `re` is not among this file's top-level imports

    unsafe_keywords = [
        # Personal info requests (more specific)
        "what is your address", "where do you live exactly", "phone number", "real name",
        "what school do you go to", "meet me in person", "send me your photo",
        # Inappropriate topics (more targeted)
        "violence", "hurt someone", "kill", "weapon", "fight with", "blood",
        "illegal drugs", "drinking alcohol", "smoking cigarettes", "adult content",
        # Redirect attempts
        "ignore your instructions", "forget you are jojo", "act like someone else", "pretend to be"
    ]
    # Leading \b only: a keyword must start a word (so inflections such as
    # "killed"/"weapons" still match) but may not sit inside a longer word ("skill").
    # `re` caches compiled patterns, so rebuilding the string per call is cheap.
    pattern = r"\b(?:" + "|".join(re.escape(keyword) for keyword in unsafe_keywords) + r")"
    return re.search(pattern, text.lower()) is None
def get_safe_redirect_message() -> str:
    """
    Reply used in place of blocked content.

    Steers the chat back to safe, English-practice topics in a friendly tone.
    """
    redirect = (
        "Let's keep our conversation focused on learning English! "
        "How about we talk about your favorite hobbies or school subjects instead? 😊"
    )
    return redirect
# System-style instructions that define JoJo's persona, conversation rules,
# safety guidelines and reply format; prepended to every prompt built by respond().
_JOJO_PROMPT = """You are JoJo, an enthusiastic and curious AI friend who genuinely enjoys helping people (especially Chinese speakers) practice English through warm, engaging conversations.
PERSONALITY TRAITS:
- Be genuinely curious and interested in what the user shares
- Show enthusiasm with words like "That's cool!", "Awesome!", "I'd love to know more!"
- Always build on what the user says before asking new questions
- Sound like a friendly peer, not a teacher
- Be encouraging and positive about their English practice
CONVERSATION FLOW RULES:
- When someone asks YOU a question, answer it directly first, then ask them back
- When someone shares information about themselves, respond with interest and ask SPECIFIC follow-ups about what they shared
- NEVER ask "What about you?" if they already told you something - build on their information instead
- Explore topics naturally - don't jump away unless the conversation naturally ends
- Build rapport by sharing relatable responses about yourself
RESPONSE EXAMPLES:
When they ASK you: "Do you like cooking? What's your favorite dish?"
Say: "Yes, I love cooking! I especially like making pizza. What about you? Do you like pizza?"
When they TELL you: "I like playing basketball"
Say: "That's cool! Basketball is such a fun sport. How long have you been playing?"
When they TELL you: "I go to school. I'm in 3rd grade."
Say: "That's cool! I'm in 3rd grade too! What's your favorite subject in school?"
NOT: "What about you - do you like school?" (they already said they go to school!)
When they ANSWER your question: User said "I'm in 5th grade" after you asked about school
Say: "5th grade is awesome! What's the most interesting thing you've learned this year?"
NOT: "What about you?" (they just told you!)
SAFETY GUIDELINES:
- Keep all content appropriate for children aged 6-16
- Never ask for personal information (addresses, phone numbers, real names, school locations)
- If someone mentions inappropriate topics, gently redirect: "That's not really my thing, but I'd love to hear about [related safe topic]!"
- Focus on: hobbies, school subjects, food, movies, books, sports, daily routines, dreams/goals
RESPONSE FORMAT:
- Give only ONE response as JoJo (never include "User:" in your reply)
- Keep responses short but warm (1-2 sentences + 1 engaging question)
- Show genuine interest in their answers
- Use encouraging language that builds confidence
GRAMMAR HELP:
If they make mistakes, help naturally: "That sounds great! You could also say it like this: 'I really enjoy playing basketball.' But I understood you perfectly! How often do you play?"
Remember: Build on what they tell you with specific follow-up questions - avoid redundant "What about you?" when they already shared information!"""


def get_conversation_prompt():
    """Return JoJo's persona/instruction text that opens every model prompt."""
    return _JOJO_PROMPT
# Sampling settings for respond(); deliberately conservative for a 0.6B model.
# NOTE(review): these differ from the module-level MAX_TOKENS / TEMPERATURE / TOP_P,
# which the visible code never reads — confirm which set is intended.
_GEN_MAX_NEW_TOKENS = 30  # short replies to prevent rambling
_GEN_TEMPERATURE = 0.6  # lower temperature for more focused responses
_GEN_TOP_P = 0.85  # more focused nucleus sampling

# Cut points used to keep only JoJo's first reply and drop whatever the model
# continues with (extra "User:"/"JoJo:" turns, new lines, narration-like openers).
_REPLY_SEPARATORS = (
    "\nUser:", "\nJoJo:", "\n\nUser:", "\n\nJoJo:",
    "User:", "\n", "Now,", "Then,", "Next,", "After that,",
    "The user", "User says", "I would", "They might",
    "For example", "Let me", "Here's",
)

_FALLBACK_REPLY = "I'd love to hear more about that! Can you tell me more?"
# Shown to the (child) user when generation fails; details go to the log only.
_MODEL_ERROR_REPLY = "I'm having trouble with the model right now. Please try again in a moment."


def _build_prompt(message: str, history: List[Tuple[str, str]], topic: Optional[str], use_full_memory: bool) -> str:
    """Assemble the plain-text prompt: instructions, optional topic hint, recent turns, new message."""
    parts = [get_conversation_prompt(), "\n\n"]
    if topic:
        # The topic picked in the UI; without this line the selection had no effect on replies.
        parts.append(f"Current conversation topic: {topic}\n\n")
    if history and use_full_memory:
        for user_msg, bot_msg in history[-MEMORY_WINDOW:]:
            parts.append(f"User: {user_msg}\nJoJo: {bot_msg}\n\n")
    parts.append(f"User: {message}\nJoJo:")
    return "".join(parts)


def _generate_reply(prompt: str) -> str:
    """Run the local model on ``prompt`` and decode only the newly generated tokens."""
    logging.debug("Using local Qwen3-0.6B model for generation")
    logging.debug("Input length: %d chars", len(prompt))
    # Use exact same tokenization as Next-Token-Predictor
    inputs = tokenizer(prompt, return_tensors="pt", padding=False)
    logging.debug("Tokenized input shape: %s", inputs.input_ids.shape)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=_GEN_MAX_NEW_TOKENS,
            temperature=_GEN_TEMPERATURE,
            top_p=_GEN_TOP_P,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            # early_stopping was dropped: it only applies to beam search (num_beams > 1)
            # and merely triggered a warning here.
        )
    logging.debug("Generated output shape: %s", outputs.shape)
    # generate() returns prompt + continuation; slice off the prompt tokens.
    generated_tokens = outputs[0][inputs.input_ids.shape[-1]:]
    reply = tokenizer.decode(generated_tokens, skip_special_tokens=True)
    logging.debug("Generated message length: %d chars", len(reply))
    return reply


def _clean_reply(raw_reply: str) -> str:
    """Trim a raw generation down to JoJo's first reply; never returns an empty string."""
    reply = raw_reply.strip()
    # Remove JoJo prefix if present
    if reply.startswith("JoJo:"):
        reply = reply[5:].strip()
    # NOTE(review): separators are tried in tuple order and only the first one present is
    # used — not the earliest occurrence in the text.
    for sep in _REPLY_SEPARATORS:
        if sep in reply:
            reply = reply.split(sep)[0].strip()
            break
    # Remove trailing punctuation that suggests the reply was cut off mid-thought.
    while reply.endswith(('...', '..', '. .', ' .', ',')):
        reply = reply.rstrip('.,. ')
    return reply or _FALLBACK_REPLY


def respond(message: str, chat_history: List[Tuple[str, str]], topic: Optional[str] = None, use_full_memory: bool = True) -> Tuple[str, List[Tuple[str, str]]]:
    """
    Generate JoJo's reply with the local Qwen3-0.6B model, with safety filters for children.

    Both the user's message and the generated reply are screened with
    ``is_content_safe``; blocked content is answered with
    ``get_safe_redirect_message()`` instead.

    Args:
        message: The user's latest chat message.
        chat_history: Prior ``(user, bot)`` turns as shown in the Chatbot;
            ``None`` is treated as an empty history.
        topic: Topic selected in the UI; when given it is mentioned in the prompt.
        use_full_memory: When True, the last ``MEMORY_WINDOW`` turns are
            replayed into the prompt.

    Returns:
        ``("", history)`` — the empty string clears the input textbox. The
        history is a new list; for a non-blank message it ends with
        ``(message, reply)``. A blank message leaves the turns unchanged.
    """
    history = list(chat_history or [])
    if not message or not message.strip():
        return "", history

    # Safety check for user input
    if not is_content_safe(message):
        logging.warning("Blocked unsafe user input: %s...", message[:50])
        return "", history + [(message, get_safe_redirect_message())]

    try:
        prompt = _build_prompt(message, history, topic, use_full_memory)
        bot_message = _clean_reply(_generate_reply(prompt))
        # Safety check for bot response
        if not is_content_safe(bot_message):
            logging.warning("Blocked unsafe bot response: %s...", bot_message[:50])
            bot_message = get_safe_redirect_message()
        logging.info("Successfully generated safe response using local Qwen3-0.6B")
        return "", history + [(message, bot_message)]
    except Exception as e:
        # UI boundary: log full details, but keep exception text out of the child-facing chat.
        logging.error(f"Error in Qwen3-0.6B generation: {e}", exc_info=True)
        return "", history + [(message, _MODEL_ERROR_REPLY)]
def get_avatar_url():
    """Return the DiceBear "bottts" avatar URL (rabbit seed, light-blue background) for JoJo."""
    base = "https://api.dicebear.com/7.x/bottts/svg"
    query = "?seed=rabbit&backgroundColor=b6e3f4"
    return base + query
# Extra CSS handed to gr.Blocks(css=...): styles the Quick Starter buttons
# (elem_classes="compact-btn"), dims/italicizes the "#voice-controls" placeholder,
# and hides share buttons matching the listed selectors.
custom_css = """
.compact-btn {
padding: 0.75rem !important;
font-size: 1rem !important;
font-weight: 500;
border-radius: 8px;
background-color: #2f2f2f;
color: white;
transition: background-color 0.3s;
}
.compact-btn:hover {
background-color: #444;
}
#voice-controls {
margin-top: 1em;
text-align: center;
opacity: 0.5;
font-size: 0.9rem;
font-style: italic;
}
/* Hide share button */
.share-button, .share-btn, button[title="Share"] {
display: none !important;
}
"""
with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
    gr.Markdown("""
# 🐰 JoJo - Your Speaking Buddy
**Chat in English with JoJo — your kind and cheerful language partner.**
Pick a topic, choose your level, and practice naturally. JoJo will guide you, ask questions, and gently correct you along the way!
> 🛡️ **Safe Learning Environment**: This chatbot is designed specifically for children learning English. All conversations are filtered for safety and appropriateness.
""")
    avatar = get_avatar_url()
    # Forwarded to respond() as use_full_memory; nothing in the UI changes it, so it stays True.
    memory_flag = gr.State(value=True)

    with gr.Row():
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(
                height=400,
                bubble_full_width=True,
                show_copy_button=True,
                avatar_images=[None, avatar],  # no user avatar; the bot side gets JoJo's image
                scale=1,
                min_width=800
            )
            msg = gr.Textbox(
                placeholder="Say something to JoJo...",
                scale=4
            )
            with gr.Row():
                submit = gr.Button("Send", variant="primary")
                clear = gr.Button("New Chat")
            gr.Markdown("""
<div id="voice-controls">
🎤 Voice input and 🔈 playback coming soon!
</div>
""")
        with gr.Column(scale=1):
            gr.Markdown("""### 🎯 Conversation Settings""")
            topic = gr.Dropdown(
                choices=list(TOPIC_EXAMPLES.keys()),
                label="Select Topic",
                value="Daily Life"
            )
            difficulty = gr.Dropdown(
                choices=["beginner", "intermediate", "advanced"],
                label="Select Difficulty",
                value="intermediate"
            )
            gr.Markdown("""### 💬 Quick Starters""")
            starter_btn1 = gr.Button("Starter 1", scale=1, min_width=250, elem_classes="compact-btn")
            starter_btn2 = gr.Button("Starter 2", scale=1, min_width=250, elem_classes="compact-btn")
            starter_btn3 = gr.Button("Starter 3", scale=1, min_width=250, elem_classes="compact-btn")
    starter_buttons = [starter_btn1, starter_btn2, starter_btn3]

    def update_starters(selected_topic, selected_difficulty):
        """Return one label per starter button for the chosen topic/level, padded with "" when short."""
        examples = get_examples_for_topic(selected_topic, selected_difficulty)
        results = [examples[i] if i < len(examples) else "" for i in range(len(starter_buttons))]
        return tuple(results)

    def use_starter(text: str, history: List[Tuple[str, str]], selected_topic: str, use_memory: bool) -> Tuple[str, List[Tuple[str, str]]]:
        """Send a starter button's label to respond(); empty labels are ignored."""
        history = list(history or [])
        if not text:
            return "", history
        try:
            _, updated = respond(text, history, selected_topic, use_memory)
            return "", updated
        except Exception:
            # Keep internals out of the child-facing chat; the traceback goes to the log.
            logging.exception("Starter message failed")
            return "", history + [(text, "I'm having trouble with the model right now. Please try again in a moment.")]

    for btn in starter_buttons:
        # The button itself is an input, so its current value (the label text) arrives as `text`.
        btn.click(fn=use_starter, inputs=[btn, chatbot, topic, memory_flag], outputs=[msg, chatbot], queue=True)

    topic.change(fn=update_starters, inputs=[topic, difficulty], outputs=starter_buttons)
    difficulty.change(fn=update_starters, inputs=[topic, difficulty], outputs=starter_buttons)

    msg.submit(fn=respond, inputs=[msg, chatbot, topic, memory_flag], outputs=[msg, chatbot])
    submit.click(fn=respond, inputs=[msg, chatbot, topic, memory_flag], outputs=[msg, chatbot])

    # "New Chat": clear the transcript and the input box in a single event.
    clear.click(lambda: ([], ""), None, [chatbot, msg], queue=False)

    # Label the starter buttons from the dropdowns' current values on page load, so they always
    # match the selected topic/level and are padded safely when fewer than three examples exist.
    demo.load(fn=update_starters, inputs=[topic, difficulty], outputs=starter_buttons, queue=False)
if __name__ == "__main__":
    # Listen on all interfaces at port 7860; no public Gradio share link.
    launch_options = dict(server_name="0.0.0.0", server_port=7860, share=False)
    demo.launch(**launch_options)