import gradio as gr
from transformers import AutoTokenizer
import ctranslate2
import torch

# Pick the device; ctranslate2 accepts "cuda" or "cpu" directly.
device = "cuda" if torch.cuda.is_available() else "cpu"

# ctranslate2 cannot load GGUF files: the base model must first be converted
# to the CTranslate2 format (see the conversion sketch below). ct2_model_dir
# is an assumed local path to that converted directory; the GGUF repo/file
# from the original script are kept for loading the tokenizer.
gguf_repo = "mradermacher/TinyLlama-Friendly-Psychotherapist-GGUF"
gguf_file = "TinyLlama-Friendly-Psychotherapist.Q4_K_S.gguf"
ct2_model_dir = "TinyLlama-Friendly-Psychotherapist-ct2"

try:
    # 1. Load the tokenizer. transformers can read tokenizer data out of a
    #    GGUF checkpoint via the gguf_file argument.
    tokenizer = AutoTokenizer.from_pretrained(gguf_repo, gguf_file=gguf_file)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.model_max_length = 4096

    # 2. Load the CTranslate2 model. Decoder-only LLMs use Generator (not
    #    Translator), and there is no eval(): ctranslate2 models are
    #    inference-only by construction.
    ct_model = ctranslate2.Generator(ct2_model_dir, device=device)
except Exception as e:
    print(f"Error loading model: {e}")
    raise SystemExit(1)
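
# One-time conversion sketch (run in a shell, not in this script). BASE_MODEL
# stands for the original Hugging Face repo the GGUF was quantized from; its
# exact id is an assumption here, not something this script specifies:
#
#   ct2-transformers-converter --model BASE_MODEL \
#       --output_dir TinyLlama-Friendly-Psychotherapist-ct2 --quantization int8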

def generate_text_streaming(prompt, max_new_tokens=128):
    # ctranslate2 consumes string tokens rather than tensors, so encode the
    # prompt and map the ids back to token strings.
    prompt_tokens = tokenizer.convert_ids_to_tokens(
        tokenizer.encode(prompt, truncation=True, max_length=4096)
    )

    generated_ids = []

    # generate_tokens streams one result per decoded step, so the prompt does
    # not need to be re-fed on every iteration.
    for step in ct_model.generate_tokens(
        prompt_tokens,
        max_length=max_new_tokens,
        sampling_topk=1,  # greedy decoding
    ):
        if step.token_id == tokenizer.eos_token_id:
            break
        generated_ids.append(step.token_id)
        # Yield the cumulative text decoded so far.
        yield tokenizer.decode(generated_ids, skip_special_tokens=True)

def respond(message, history, system_message, max_tokens):
    # Build the prompt from the system message plus the full chat history.
    prompt = f"{system_message}\n"
    for user_msg, bot_msg in history:
        prompt += f"User: {user_msg}\nAssistant: {bot_msg}\n"
    prompt += f"User: {message}\nAssistant:"

    try:
        # The generator yields the cumulative response text, which is what
        # gr.ChatInterface expects from a streaming callback.
        for partial_text in generate_text_streaming(prompt, max_tokens):
            yield partial_text
    except Exception as e:
        print(f"Error during generation: {e}")
        yield "An error occurred."

demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(
            value="You are a friendly and helpful mental health chatbot.",
            label="System message",
        ),
        gr.Slider(minimum=1, maximum=512, value=128, step=1, label="Max new tokens"),
    ],
)
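
# Note: older Gradio releases require enabling the queue for generator-based
# (streaming) callbacks, e.g. demo.queue(); recent versions enable it by default.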

if __name__ == "__main__":
    demo.launch()