import re
import torch
from threading import Thread
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

MODEL_ID = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
CONTEXT_LENGTH = 4096  # prompt token budget; longer prompts are truncated at tokenization time

# Register the thinking-process tags as special tokens. In the DeepSeek-R1
# distill tokenizers these strings should already exist in the vocabulary, so
# this mostly just marks them as special; resize_token_embeddings below is
# then a no-op, kept as a safeguard in case the tokens were genuinely new.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
tokenizer.add_special_tokens({
    "additional_special_tokens": ["<think>", "</think>"]
})

device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModelForCausalLM.from_pretrained(MODEL_ID).to(device)
model.resize_token_embeddings(len(tokenizer))
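# On a GPU, passing torch_dtype=torch.float16 (or torch_dtype="auto") to
# from_pretrained roughly halves memory use; the full-precision default is
# kept here for simplicity.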

def predict(message, history, show_thinking, system_prompt, temperature, max_new_tokens, top_k, repetition_penalty, top_p):
    # Stop on end-of-turn markers only. "</think>" must not be a stop string:
    # it would halt generation before the answer that follows the reasoning.
    stop_strings = ["<|endoftext|>", "<|im_end|>"]
    instruction = f'<|im_start|>system\n{system_prompt}\n<|im_end|>\n'
    
    # Replay prior turns in ChatML form (classic tuple-format history: one
    # [user, assistant] pair per turn), then open the assistant turn for the
    # new message.
    for user, assistant in history:
        instruction += f'<|im_start|>user\n{user}\n<|im_end|>\n<|im_start|>assistant\n{assistant}\n<|im_end|>\n'
    instruction += f'<|im_start|>user\n{message}\n<|im_end|>\n<|im_start|>assistant\n'
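    # NOTE: this hand-rolls a ChatML-style prompt. The DeepSeek-R1 distill
    # checkpoints ship their own chat template, so in general
    # tokenizer.apply_chat_template(messages, add_generation_prompt=True) is
    # the safer route; the manual format is kept to make the prompt layout
    # explicit.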

    # Keep special tokens in the decoded stream: <think>/</think> were marked
    # special above, and skip_special_tokens=True would strip them before the
    # tag parser below ever saw them. Leftover chat markup is scrubbed in
    # _clean_output instead.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=False)
    enc = tokenizer(instruction, return_tensors="pt", truncation=True, max_length=CONTEXT_LENGTH).to(device)
    input_ids, attention_mask = enc.input_ids, enc.attention_mask

    generate_kwargs = dict(
        input_ids=input_ids,
        attention_mask=attention_mask,
        streamer=streamer,
        do_sample=True,
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        top_p=top_p,
        # stop_strings requires passing the tokenizer and a transformers
        # release that ships StopStringCriteria; drop both kwargs on older
        # versions.
        stop_strings=stop_strings,
        tokenizer=tokenizer,
    )
    
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
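    # Generation now runs on the background thread; consume the streamer as
    # tokens arrive and route text inside <think>...</think> according to the
    # show_thinking toggle.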

    outputs = []
    in_thinking = False
    current_chunk = ""

    for new_token in streamer:
        current_chunk += new_token

        # Opening tag: flush any text before it, then enter thinking mode.
        # The tag itself is kept when thinking is shown so _clean_output can
        # turn it into the "*Thinking:*" label.
        if "<think>" in current_chunk and not in_thinking:
            in_thinking = True
            pre, tag, post = current_chunk.partition("<think>")
            if pre:
                outputs.append(pre)
            if show_thinking:
                outputs.append(tag)
            yield _clean_output("".join(outputs))
            current_chunk = post

        # Closing tag: emit the rest of the reasoning (if shown) and switch
        # back to answer mode for whatever follows. Appending only the text
        # up to the tag avoids re-emitting chunks already streamed below.
        if "</think>" in current_chunk and in_thinking:
            in_thinking = False
            pre, tag, post = current_chunk.partition("</think>")
            if show_thinking:
                outputs.append(pre + tag)
                yield _clean_output("".join(outputs))
            current_chunk = post

        if current_chunk:
            # Thinking text is streamed only when the toggle is on; answer
            # text is always streamed.
            if not in_thinking or show_thinking:
                outputs.append(current_chunk)
                yield _clean_output("".join(outputs))
            current_chunk = ""

def _clean_output(text: str) -> str:
    # Render the thinking tags as a styled label, separate the answer from
    # the reasoning, and scrub any chat-markup tokens that leaked past the
    # stop strings.
    text = re.sub(r'<\|im_start\|>|<\|im_end\|>|<\|endoftext\|>', '', text)
    text = re.sub(r'\s*<think>\s*', '\n\n*Thinking:* ', text)
    text = re.sub(r'\s*</think>\s*', '\n\n', text)
    text = re.sub(r'(\*Thinking:\*)(?! )', r'\1 ', text)
    return text.strip()
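# For example, _clean_output("<think>check the units</think>42") returns
# "*Thinking:* check the units\n\n42".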

# Create interface with toggle
gr.ChatInterface(
    predict,
    additional_inputs=[
        gr.Checkbox(value=True, label="🔍 Show Thinking Process"),
        gr.Textbox(
            "You are an AI assistant. First analyze requests using <think> tags, then provide answers. "
            "Put all reasoning between <think> and </think> tags.",
            label="System Prompt"
        ),
        # temperature must be strictly positive when do_sample=True
        gr.Slider(0.1, 1, 0.6, label="🌡️ Temperature"),
        gr.Slider(1, 4096, 512, label="📏 Max New Tokens"),
        gr.Slider(1, 80, 40, label="🎛️ Top K"),
        gr.Slider(0.1, 2.0, 1.1, label="🔄 Repetition Penalty"),
        gr.Slider(0, 1, 0.95, label="🧮 Top P"),
    ],
    css="""
    .thinking { 
        color: #666;
        font-style: italic;
        border-left: 3px solid #ddd;
        padding-left: 1em;
        margin: 0.5em 0;
    }
    """,
    title="DeepSeek AI Assistant with Reasoning",
    description="Toggle the 'Show Thinking Process' checkbox to view/hide the model's internal reasoning"
).queue().launch()
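# .queue() enables the request queue that streaming generator outputs rely
# on; pass share=True to launch() if a public Gradio link is needed.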