import gradio as gr
from transformers import TextStreamer
from unsloth import FastLanguageModel
import torch
# Model Configuration
max_seq_length = 2048
dtype = None
model_name_or_path = "michailroussos/model_llama_8d"
#model_name_or_path = "Natassaf/lora_model-llama-new"
# Load Model and Tokenizer
print("Loading model...")
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=model_name_or_path,
    max_seq_length=max_seq_length,
    dtype=dtype,
    load_in_4bit=True,
)
FastLanguageModel.for_inference(model) # Enable faster inference
print("Model loaded successfully!")
# Gradio Response Function
def respond(
    message,
    max_new_tokens,
    temperature,
    system_message="You are a helpful assistant. You should reply to the user's message without repeating the input.",
):
    try:
        # Prepare input messages
        messages = [{"role": "system", "content": system_message}] if system_message else []
        messages.append({"role": "user", "content": message})
        # Tokenize inputs
        input_ids = tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt",
        ).to("cuda")
        # Ensure the input tensor has the correct dimensions
        if input_ids.dim() != 2:
            raise ValueError(f"`input_ids` must be a 2D tensor. Found shape: {input_ids.shape}")
        # Generate output directly
        with torch.no_grad():  # No need to track gradients for inference
            output = model.generate(
                input_ids=input_ids,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                do_sample=True,  # temperature only takes effect when sampling is enabled
                use_cache=True,
            )
        # Decode the generated tokens back to text
        generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
        print("[DEBUG] Generated Text:", generated_text)

        # Strip everything up to and including the echoed prompt followed by the
        # "assistant" header, so only the model's reply is returned
        prompt = messages[0]["content"] + "assistant"
        print("[DEBUG] prompt with assistant:", prompt)
        start_pos = generated_text.find(prompt)
        if start_pos != -1:
            result_text = generated_text[start_pos + len(prompt) + 2:]
        else:
            # Fall back to the full decoded text if the prompt marker is not found
            result_text = generated_text
        print("[DEBUG] Result Text:", result_text)

        return result_text
    except Exception as e:
        # Debug: Log errors
        print("[ERROR]", str(e))
        return f"Error: {str(e)}"
# Gradio UI
demo = gr.Interface(
    fn=respond,
    inputs=[
        gr.Textbox(label="Your Message", placeholder="Enter your prompt here..."),
        gr.Slider(minimum=1, maximum=512, step=1, value=128, label="Max New Tokens"),
        gr.Slider(minimum=0.1, maximum=2.0, step=0.1, value=1.0, label="Temperature"),
        gr.Textbox(label="System Message", placeholder="Optional system instructions."),
    ],
    outputs="text",
    title="Llama-based Chatbot",
    description="Interact with the model. Enter a prompt and receive a response.",
)
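# NOTE: the four inputs above are passed positionally to respond(message, max_new_tokens,
# temperature, system_message); reordering either list requires updating the other.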
if __name__ == "__main__":
    demo.launch(share=True)  # share=True serves the app through a temporary public Gradio link