# Mental_chat / app.py
import gradio as gr
from transformers import AutoTokenizer
from llama_cpp import Llama
import torch
# Configuration
MODEL_PATH = "./TinyLlama-Friendly-Psychotherapist.Q4_K_S.gguf"
MODEL_REPO = "thrishala/mental_health_chatbot"
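# The GGUF file is assumed to already exist locally at MODEL_PATH.
# A minimal sketch for fetching it with huggingface_hub (the repo id below is
# a placeholder guess, not confirmed by this app):
#
#     from huggingface_hub import hf_hub_download
#     MODEL_PATH = hf_hub_download(
#         repo_id="some-user/TinyLlama-Friendly-Psychotherapist-GGUF",  # hypothetical repo
#         filename="TinyLlama-Friendly-Psychotherapist.Q4_K_S.gguf",
#     )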
try:
    # 1. Load the tokenizer from the original model repo
    tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.model_max_length = 4096  # generation below truncates explicitly to fit n_ctx

    # 2. Load the GGUF model with llama-cpp-python
    llm = Llama(
        model_path=MODEL_PATH,
        n_ctx=2048,  # context window size (tokens)
        n_threads=4,  # CPU threads
        n_gpu_layers=33 if torch.cuda.is_available() else 0,  # offload layers to GPU when available
    )
except Exception as e:
    print(f"Error loading model: {e}")
    raise SystemExit(1)

def generate_text_streaming(prompt, max_new_tokens=128):
    # Tokenize with the HF tokenizer purely to truncate the prompt, so that
    # prompt + completion fits inside llama.cpp's n_ctx window (2048 tokens).
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=2048 - max_new_tokens,  # leave room for the completion
    )

    # llama.cpp expects a plain string, so decode back after truncation
    # (skip_special_tokens strips any special tokens the tokenizer added)
    full_prompt = tokenizer.decode(inputs.input_ids[0], skip_special_tokens=True)

    # Stream the completion chunk by chunk
    stream = llm.create_completion(
        prompt=full_prompt,
        max_tokens=max_new_tokens,
        temperature=0.7,
        stream=True,
        stop=["User:", "###"],  # cut generation at the next turn marker
    )

    generated_text = ""
    for output in stream:
        chunk = output["choices"][0]["text"]
        generated_text += chunk
        yield generated_text  # yield accumulated text so Gradio can stream it
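
# Standalone usage sketch for the streaming generator above (outside Gradio).
# Since each yield is the accumulated text, print only the new suffix:
#
#     last = ""
#     for partial in generate_text_streaming("User: Hello\nAssistant:"):
#         print(partial[len(last):], end="", flush=True)
#         last = partial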
def respond(message, history, system_message, max_tokens):
    # Rebuild the full conversation as a plain-text prompt
    prompt = f"{system_message}\n"
    for user_msg, bot_msg in history:
        prompt += f"User: {user_msg}\nAssistant: {bot_msg}\n"
    prompt += f"User: {message}\nAssistant:"

    try:
        for chunk in generate_text_streaming(prompt, max_tokens):
            yield chunk
    except Exception as e:
        print(f"Error: {e}")
        yield "An error occurred during generation."
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(
            value="You are a friendly and helpful mental health chatbot.",
            label="System message",
        ),
        gr.Slider(minimum=1, maximum=512, value=128, step=1, label="Max new tokens"),
    ],
)
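# On older Gradio releases, streaming generator outputs require the request
# queue; uncomment if chunks are not streamed incrementally (version-dependent):
# demo.queue()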
if __name__ == "__main__":
    demo.launch()