import torch
import gradio as gr
from unsloth import FastLanguageModel

# Model Configuration
max_seq_length = 2048
dtype = None  # auto-detect (float16 / bfloat16 depending on the GPU)
model_name_or_path = "michailroussos/model_llama_8d"
# model_name_or_path = "Natassaf/lora_model-llama-new"

# Load Model and Tokenizer
print("Loading model...")
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=model_name_or_path,
    max_seq_length=max_seq_length,
    dtype=dtype,
    load_in_4bit=True,
)
FastLanguageModel.for_inference(model)  # Enable faster inference
print("Model loaded successfully!")


# Gradio Response Function
def respond(
    message,
    max_new_tokens,
    temperature,
    system_message="You are a helpful assistant. You should reply to the user's message without repeating the input.",
):
    try:
        # Prepare input messages
        messages = [{"role": "system", "content": system_message}] if system_message else []
        messages.append({"role": "user", "content": message})

        # Tokenize inputs using the model's chat template
        input_ids = tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt",
        ).to("cuda")

        # Ensure the input tensor has the correct dimensions
        if input_ids.dim() != 2:
            raise ValueError(f"`input_ids` must be a 2D tensor. Found shape: {input_ids.shape}")

        # Generate output (no gradient tracking needed for inference)
        with torch.no_grad():
            output = model.generate(
                input_ids=input_ids,
                max_new_tokens=int(max_new_tokens),
                temperature=temperature,
                do_sample=True,  # sampling must be enabled for the temperature setting to take effect
                use_cache=True,
            )

        # Decode only the newly generated tokens so the prompt is not echoed back
        generated_tokens = output[0][input_ids.shape[-1]:]
        result_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
        print("[DEBUG] Result Text:", result_text)

        return result_text

    except Exception as e:
        # Log errors and surface them in the UI
        print("[ERROR]", str(e))
        return f"Error: {str(e)}"


# Gradio UI
demo = gr.Interface(
    fn=respond,
    inputs=[
        gr.Textbox(label="Your Message", placeholder="Enter your prompt here..."),
        gr.Slider(minimum=1, maximum=512, step=1, value=128, label="Max New Tokens"),
        gr.Slider(minimum=0.1, maximum=2.0, step=0.1, value=1.0, label="Temperature"),
        gr.Textbox(label="System Message", placeholder="Optional system instructions."),
    ],
    outputs="text",
    title="LLama-based Chatbot",
    description="Interact with the model. Enter a prompt and receive a response.",
)

if __name__ == "__main__":
    demo.launch(share=True)