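"""Gradio chat app for a mental health chatbot.

Loads a quantized TinyLlama GGUF model with llama-cpp-python, tokenizes
prompts with the Hugging Face tokenizer from the original model repo, and
streams responses through gr.ChatInterface.
"""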
import gradio as gr
from transformers import AutoTokenizer
from llama_cpp import Llama
import torch
# Configuration
MODEL_PATH = "./TinyLlama-Friendly-Psychotherapist.Q4_K_S.gguf"
MODEL_REPO = "thrishala/mental_health_chatbot"
try:
    # 1. Load the tokenizer from the original model repo
    tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.model_max_length = 4096

    # 2. Load the GGUF model with llama-cpp-python
    llm = Llama(
        model_path=MODEL_PATH,
        n_ctx=2048,  # Context window size
        n_threads=4,  # CPU threads
        n_gpu_layers=33 if torch.cuda.is_available() else 0,  # GPU layers
    )
except Exception as e:
print(f"Error loading model: {e}")
exit()
def generate_text_streaming(prompt, max_new_tokens=128):
    # Tokenize using the HF tokenizer, truncating so the prompt stays within
    # the llama.cpp context window (n_ctx=2048)
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=2048,
    )
    # Convert to string for llama.cpp
    full_prompt = tokenizer.decode(inputs.input_ids[0], skip_special_tokens=True)

    # Create generator
    stream = llm.create_completion(
        prompt=full_prompt,
        max_tokens=max_new_tokens,
        temperature=0.7,
        stream=True,
        stop=["User:", "###"],  # Stop sequences
    )

    # Accumulate chunks and yield the running text so the UI streams it
    generated_text = ""
    for output in stream:
        chunk = output["choices"][0]["text"]
        generated_text += chunk
        yield generated_text
def respond(message, history, system_message, max_tokens):
    # Build prompt with history
    prompt = f"{system_message}\n"
    for user_msg, bot_msg in history:
        prompt += f"User: {user_msg}\nAssistant: {bot_msg}\n"
    prompt += f"User: {message}\nAssistant:"

    try:
        for chunk in generate_text_streaming(prompt, max_tokens):
            yield chunk
    except Exception as e:
        print(f"Error: {e}")
        yield "An error occurred during generation."
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(
            value="You are a friendly and helpful mental health chatbot.",
            label="System message",
        ),
        gr.Slider(minimum=1, maximum=512, value=128, step=1, label="Max new tokens"),
    ],
)
if __name__ == "__main__":
    demo.launch()