import gradio as gr
from transformers import AutoTokenizer
import ctranslate2
import torch
# Pick the device; CTranslate2 accepts "cuda" or "cpu" directly.
device = "cuda" if torch.cuda.is_available() else "cpu"
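# NOTE: CTranslate2 cannot read GGUF checkpoints directly. For this script to run,
# model_path is assumed to point to a directory produced by the ct2-transformers-converter
# tool, e.g.: ct2-transformers-converter --model <hf_repo_id> --output_dir <model_dir> --quantization int8
# (the repo id and output directory here are placeholders, not values from the original Space).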
model_path = "mradermacher/TinyLlama-Friendly-Psychotherapist-GGUF/TinyLlama-Friendly-Psychotherapist.Q4_K_S.gguf"
try:
    # 1. Load the tokenizer (it must match the converted model)
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.model_max_length = 4096

    # 2. Load the CTranslate2 model.
    # TinyLlama is a decoder-only model, so ctranslate2.Generator is the right class
    # (Translator is for encoder-decoder models). CTranslate2 models have no eval()
    # method; they are always in inference mode.
    ct_model = ctranslate2.Generator(model_path, device=device)
except Exception as e:
    print(f"Error loading model: {e}")
    raise SystemExit(1)
def generate_text_streaming(prompt, max_new_tokens=128):
    # CTranslate2 works with token strings rather than PyTorch tensors.
    tokens = tokenizer.convert_ids_to_tokens(
        tokenizer.encode(prompt, truncation=True, max_length=4096)
    )
    generated_ids = []
    for _ in range(max_new_tokens):
        # Generate one token per step and feed it back as context for the next step.
        results = ct_model.generate_batch(
            [tokens],
            max_length=1,                    # generate one token at a time
            sampling_topk=1,                 # greedy decoding
            include_prompt_in_result=False,  # return only the new token
        )
        step_ids = results[0].sequences_ids[0]
        if not step_ids:
            break
        new_token_id = step_ids[-1]
        if new_token_id == tokenizer.eos_token_id:
            break
        generated_ids.append(new_token_id)
        current_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
        yield current_text
        tokens.append(tokenizer.convert_ids_to_tokens(new_token_id))
def respond(message, history, system_message, max_tokens):
    # Build prompt with full history
    prompt = f"{system_message}\n"
    for user_msg, bot_msg in history:
        prompt += f"User: {user_msg}\nAssistant: {bot_msg}\n"
    prompt += f"User: {message}\nAssistant:"

    # Keep track of the full response
    full_response = ""
    try:
        for token_chunk in generate_text_streaming(prompt, max_tokens):
            # Update the full response and yield incremental changes
            full_response = token_chunk
            yield full_response
    except Exception as e:
        print(f"Error during generation: {e}")
        yield "An error occurred."
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(
            value="You are a friendly and helpful mental health chatbot.",
            label="System message",
        ),
        gr.Slider(minimum=1, maximum=512, value=128, step=1, label="Max new tokens"),
    ],
)
if __name__ == "__main__":
    demo.launch()