########################################
# app.py
########################################
import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

# Cache the pipeline for each model so it is loaded only once per session.
@st.cache_resource
def load_text_generation_pipeline(model_name: str):
    """
    Loads a text-generation pipeline from the Hugging Face Hub.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype="auto",       # or torch.float16 (requires `import torch`) if a GPU is available
        device_map="auto"         # map layers across available GPU(s); requires the `accelerate` package
    )
    text_generation = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer
    )
    return text_generation
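
# Illustrative standalone check of the loader (a sketch, not part of the app;
# assumes a small Hub model such as "gpt2" can be downloaded):
#   pipe = load_text_generation_pipeline("gpt2")
#   print(pipe("Hello, world", max_new_tokens=5)[0]["generated_text"])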

def generate_response(
    text_generation,
    system_prompt: str,
    conversation_history: list,
    user_query: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float
):
    """
    Generates a response from the language model given the system prompt,
    conversation history, and user query with specified parameters.
    """
    # Construct a prompt that includes the system role, conversation history, and the new user input.
    # Adjust the format to match your model's expected instruction format.
    # Here we take a simple approach: system prompt + turn-by-turn conversation.
    full_prompt = system_prompt.strip()
    for (speaker, text) in conversation_history:
        if speaker == "user":
            full_prompt += f"\nUser: {text}"
        else:
            full_prompt += f"\nAssistant: {text}"
    # Add the new user query
    full_prompt += f"\nUser: {user_query}\nAssistant:"

    # Use the pipeline to generate text
    outputs = text_generation(
        full_prompt,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=True
    )
    # The pipeline returns a list of generated sequences; get the text from the first one
    generated_text = outputs[0]["generated_text"]
    
    # Extract just the new answer. The pipeline echoes the prompt by default,
    # so take the text generated after it, then cut off any follow-up turns
    # the model may have hallucinated (it sometimes continues with "User: ...").
    answer = generated_text[len(full_prompt):]
    answer = answer.split("\nUser:")[0].strip()
    return answer
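
# For reference, the prompt assembled above for a one-turn history looks like
# this (illustrative content):
#
#   You are a friendly Chatbot created by ruslanmv.com. ...
#   User: Hi
#   Assistant: Hello! How can I help?
#   User: What is Streamlit?
#   Assistant: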

def main():
    st.title("Streamlit Chatbot with Model Selection")
    st.markdown(
        """
        **System message**: You are a friendly Chatbot created by [ruslanmv.com](https://ruslanmv.com)  
        Below you can select the model, adjust parameters, and begin chatting!
        """
    )

    # Sidebar for model selection and parameters
    st.sidebar.header("Select Model & Parameters")
    model_name = st.sidebar.selectbox(
        "Choose a model:",
        [
            "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B",
            "deepseek-ai/DeepSeek-R1",
            "deepseek-ai/DeepSeek-R1-Zero"
        ]
    )

    max_new_tokens = st.sidebar.slider(
        "Max new tokens",
        min_value=1,
        max_value=4000,
        value=1024,
        step=1
    )

    temperature = st.sidebar.slider(
        "Temperature",
        min_value=0.1,
        max_value=4.0,
        value=1.0,
        step=0.1
    )

    top_p = st.sidebar.slider(
        "Top-p (nucleus sampling)",
        min_value=0.1,
        max_value=1.0,
        value=0.9,
        step=0.05
    )
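
    # With do_sample=True (set in generate_response), temperature rescales the
    # token distribution, and top-p restricts sampling to the smallest set of
    # tokens whose cumulative probability exceeds the threshold (nucleus sampling).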

    # The system "role" content
    system_message = (
        "You are a friendly Chatbot created by ruslanmv.com. "
        "You answer user questions in a concise and helpful way."
    )

    # Load the chosen model
    text_generation_pipeline = load_text_generation_pipeline(model_name)

    # We'll keep conversation history in session_state
    if "conversation" not in st.session_state:
        st.session_state["conversation"] = []  # List of tuples (speaker, text)

    # Display conversation so far
    # Each element in st.session_state["conversation"] is ("user" or "assistant", message_text)
    for speaker, text in st.session_state["conversation"]:
        if speaker == "user":
            st.markdown(f"<div style='text-align:left; color:blue'><strong>User:</strong> {text}</div>", unsafe_allow_html=True)
        else:
            st.markdown(f"<div style='text-align:left; color:green'><strong>Assistant:</strong> {text}</div>", unsafe_allow_html=True)

    # User input text box
    user_input = st.text_input("Your message", "")

    # When the user hits "Send"
    if st.button("Send"):
        if user_input.strip():
            # 1) Add user query to conversation
            st.session_state["conversation"].append(("user", user_input.strip()))
            # 2) Generate a response
            with st.spinner("Thinking..."):
                answer = generate_response(
                    text_generation=text_generation_pipeline,
                    system_prompt=system_message,
                    # Exclude the query we just appended; generate_response adds it itself
                    conversation_history=st.session_state["conversation"][:-1],
                    user_query=user_input.strip(),
                    max_new_tokens=max_new_tokens,
                    temperature=temperature,
                    top_p=top_p
                )
            # 3) Add assistant answer to conversation
            st.session_state["conversation"].append(("assistant", answer))
            # 4) Rerun to display
            st.rerun()

    # Optional: Provide a button to clear the conversation
    if st.button("Clear Conversation"):
        st.session_state["conversation"] = []
        st.rerun()


if __name__ == "__main__":
    main()
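
########################################
# Usage (sketch)
########################################
# Assuming the dependencies below are installed, the app can be launched with
# Streamlit's CLI. Note that the DeepSeek-R1 models listed above are very
# large; a machine with suitable GPU memory (or a smaller substitute model)
# is assumed.
#
#   pip install streamlit transformers accelerate torch
#   streamlit run app.py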