import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import os
import spaces

# --- Configuration ---
# IMPORTANT: Replace with the path to your locally downloaded model or a Hugging Face model ID.
# Examples:
# LOCAL_MODEL_PATH = "/path/to/your/downloaded/qwen-1.5b-instruct"
# HUGGINGFACE_MODEL_ID = "Qwen/Qwen1.5-1.8B-Chat"  # A smaller Qwen model for local testing
HUGGINGFACE_MODEL_ID = "HuggingFaceH4/Qwen2.5-1.5B-Instruct-gkd"

# You might need to adjust TORCH_DTYPE based on your GPU and model support.
# torch.float16 (FP16) is common for inference; torch.bfloat16 suits newer GPUs.
TORCH_DTYPE = torch.float16  # or torch.bfloat16 or torch.float32

# Generation parameters (can be adjusted for different response styles)
MAX_NEW_TOKENS = 512
DO_SAMPLE = True
TEMPERATURE = 0.7
TOP_K = 50
TOP_P = 0.95

# --- Global variables for the model and tokenizer ---
tokenizer = None
model = None


# --- Load Model and Tokenizer Function ---
@spaces.GPU
def load_model_and_tokenizer():
    """
    Loads the language model and tokenizer from the Hugging Face Hub or a local path.
    This function is called once when the Gradio app starts up.
    """
    global tokenizer, model

    if tokenizer is not None and model is not None:
        print("Model and tokenizer already loaded.")
        return

    print(f"Loading tokenizer from: {HUGGINGFACE_MODEL_ID}")
    try:
        tokenizer = AutoTokenizer.from_pretrained(HUGGINGFACE_MODEL_ID)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
            print(f"Set tokenizer.pad_token to tokenizer.eos_token ({tokenizer.pad_token_id})")

        print(f"Loading model from: {HUGGINGFACE_MODEL_ID}...")
        model = AutoModelForCausalLM.from_pretrained(
            HUGGINGFACE_MODEL_ID,
            torch_dtype=TORCH_DTYPE,
            device_map="auto"  # Automatically maps the model to GPU if available, else CPU
        )
        model.eval()  # Set model to evaluation mode
        print("Model loaded successfully.")
    except Exception as e:
        print(f"Error loading model or tokenizer: {e}")
        print("Please ensure the model ID is correct and that you have an internet connection for the initial download, or that the local path is valid.")
        tokenizer = None
        model = None
        raise RuntimeError("Failed to load model. Check your model ID/path and internet connection.")


# --- Generate Response Function ---
@spaces.GPU
def generate_response(
    message: str,  # Current user message
    history: list  # Gradio Chatbot history (list of dictionaries with 'role' and 'content')
) -> list:  # Returns the updated history for the Chatbot
    """
    Generates a text response from the loaded model based on user input and chat history.
    """
    global tokenizer, model

    # Load the model and tokenizer if they are not already loaded
    if tokenizer is None or model is None:
        load_model_and_tokenizer()
        if tokenizer is None or model is None:  # Check again in case loading failed
            # For a 'messages'-type history, append role/content dictionaries
            # rather than [message, reply] pairs:
            history.append({"role": "user", "content": message})
            history.append({"role": "assistant", "content": "Error: Chatbot model not loaded. Please check logs."})
            return history

    # Format messages for the model's chat template (e.g., for Instruct models).
    # With type='messages', 'history' already contains role/content dictionaries.
    # Build the full message list: prior history plus the current user turn
    messages = history + [{"role": "user", "content": message}]

    # Apply the chat template and tokenize
    try:
        input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    except Exception as e:
        print(f"Error applying chat template: {e}")
        # Fallback if the chat template fails (e.g., for non-chat models):
        # reconstruct the prompt manually for models without an explicit chat template.
        input_text = ""
        for item in messages:
            if item["role"] == "user":
                input_text += f"User: {item['content']}\n"
            elif item["role"] == "assistant":
                input_text += f"Assistant: {item['content']}\n"
        input_text += "Assistant:"

    input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(model.device)

    # Generate the response
    with torch.no_grad():  # Disable gradient calculations for inference
        output_ids = model.generate(
            input_ids,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=DO_SAMPLE,
            temperature=TEMPERATURE,
            top_k=TOP_K,
            top_p=TOP_P,
            pad_token_id=tokenizer.eos_token_id  # Important for generation to stop cleanly
        )

    # Decode the generated text, excluding the input prompt
    generated_token_ids = output_ids[0][input_ids.shape[-1]:]
    generated_text = tokenizer.decode(generated_token_ids, skip_special_tokens=True).strip()

    # --- Update Chat History ---
    # Append the current user message and the generated response, each with its role
    history.append({"role": "user", "content": message})
    history.append({"role": "assistant", "content": generated_text})
    return history


# --- Gradio Interface ---
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # HuggingFaceH4/Qwen2.5-1.5B-Instruct-gkd chat bot
        Type your message below and chat with the model!
        """
    )

    # Set type='messages' so the chatbot uses OpenAI-style role/content dictionaries
    chatbot = gr.Chatbot(label="Conversation", type='messages')

    with gr.Row():
        text_input = gr.Textbox(
            label="Your message",
            placeholder="Type your message here...",
            scale=4
        )
        submit_button = gr.Button("Send", scale=1)

    # Link the text input and button to the generation function.
    # 'inputs' are the current message and the full history (as 'messages' type);
    # 'outputs' is the updated full history.
    submit_button.click(
        fn=generate_response,
        inputs=[text_input, chatbot],  # text_input is the new message, chatbot is the history
        outputs=[chatbot],
        queue=True  # Queue requests for better concurrency
    )
    text_input.submit(  # Also trigger on the Enter key
        fn=generate_response,
        inputs=[text_input, chatbot],
        outputs=[chatbot],
        queue=True
    )

    # Clear button
    def clear_chat():
        # With type='messages', clearing returns an empty list for the history
        # and an empty string for the text input.
        return [], ""

    clear_button = gr.Button("Clear Chat")
    clear_button.click(clear_chat, inputs=None, outputs=[chatbot, text_input])

# Load the model when the app starts so it is ready when the first request comes in.
load_model_and_tokenizer()

# Launch the Gradio app
# demo.queue().launch()  # For local development, use launch()
demo.queue().launch(server_name="0.0.0.0")
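
# A minimal sketch of how this app might be run locally. The file name app.py and
# the exact package list are assumptions based on the imports above (accelerate is
# needed for device_map="auto"); they are not taken from this file:
#   pip install gradio torch transformers accelerate spaces
#   python app.py
# Gradio serves on port 7860 by default, so the UI should be reachable at
# http://localhost:7860 once the model has finished loading.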