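# Streamlit chatbot with selectable DeepSeek models from the Hugging Face Hub.
#
# Assumed (not pinned in the original) dependencies:
#   pip install streamlit transformers torch accelerate
# `accelerate` is required for the device_map="auto" call used below.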
import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline


@st.cache_resource
def load_text_generation_pipeline(model_name: str):
    """
    Load a text-generation pipeline from the Hugging Face Hub.

    Cached with st.cache_resource, so each model is downloaded and
    initialised only once; picking a different model in the sidebar
    triggers a fresh load.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype="auto",   # use the checkpoint's native dtype
        device_map="auto"     # spread weights across available devices
    )
    text_generation = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer
    )
    return text_generation


def generate_response(
    text_generation,
    system_prompt: str,
    conversation_history: list,
    user_query: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float
):
    """
    Generate a response from the language model given the system prompt,
    the conversation so far, and the latest user query.
    """
    # Stitch the system prompt and all prior turns into one plain-text prompt.
    full_prompt = system_prompt.strip()
    for speaker, text in conversation_history:
        if speaker == "user":
            full_prompt += f"\nUser: {text}"
        else:
            full_prompt += f"\nAssistant: {text}"

    full_prompt += f"\nUser: {user_query}\nAssistant:"

    outputs = text_generation(
        full_prompt,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=True
    )

    # The pipeline returns the prompt plus the continuation; keep only the
    # text after the final "Assistant:" marker.
    generated_text = outputs[0]["generated_text"]
    answer = generated_text.split("Assistant:")[-1].strip()
    return answer
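

# Note: the prompt above is stitched together manually with "User:" /
# "Assistant:" markers. For chat-tuned checkpoints, the tokenizer's
# apply_chat_template() method is usually the more robust option; the
# plain-text framing is kept here for simplicity.
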
def main():
    st.title("Streamlit Chatbot with Model Selection")
    st.markdown(
        """
**System message**: You are a friendly Chatbot created by [ruslanmv.com](https://ruslanmv.com)

Below you can select the model, adjust parameters, and begin chatting!
"""
    )

    st.sidebar.header("Select Model & Parameters")
    model_name = st.sidebar.selectbox(
        "Choose a model:",
        [
            "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B",
            "deepseek-ai/DeepSeek-R1",
            "deepseek-ai/DeepSeek-R1-Zero"
        ]
    )
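
    # Note: these checkpoints are very large. Even the distilled 32B model
    # needs a high-memory GPU, and the full DeepSeek-R1 / R1-Zero checkpoints
    # are far larger, so loading them locally is rarely practical.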

    max_new_tokens = st.sidebar.slider(
        "Max new tokens",
        min_value=1,
        max_value=4000,
        value=1024,
        step=1
    )

    temperature = st.sidebar.slider(
        "Temperature",
        min_value=0.1,
        max_value=4.0,
        value=1.0,
        step=0.1
    )

    top_p = st.sidebar.slider(
        "Top-p (nucleus sampling)",
        min_value=0.1,
        max_value=1.0,
        value=0.9,
        step=0.05
    )

    system_message = (
        "You are a friendly Chatbot created by ruslanmv.com. "
        "You answer user questions in a concise and helpful way."
    )

    # Load (or fetch from cache) the pipeline for the selected model.
    text_generation_pipeline = load_text_generation_pipeline(model_name)

    # The conversation is stored as a list of (speaker, text) tuples.
    if "conversation" not in st.session_state:
        st.session_state["conversation"] = []

    # Render the conversation so far.
    for speaker, text in st.session_state["conversation"]:
        if speaker == "user":
            st.markdown(
                f"<div style='text-align:left; color:blue'><strong>User:</strong> {text}</div>",
                unsafe_allow_html=True
            )
        else:
            st.markdown(
                f"<div style='text-align:left; color:green'><strong>Assistant:</strong> {text}</div>",
                unsafe_allow_html=True
            )

    user_input = st.text_input("Your message", "")

    if st.button("Send"):
        if user_input.strip():
            st.session_state["conversation"].append(("user", user_input.strip()))

            with st.spinner("Thinking..."):
                # Pass the history up to (but not including) the turn just
                # appended; generate_response adds the new user query itself,
                # so passing the full list would duplicate it in the prompt.
                answer = generate_response(
                    text_generation=text_generation_pipeline,
                    system_prompt=system_message,
                    conversation_history=st.session_state["conversation"][:-1],
                    user_query=user_input.strip(),
                    max_new_tokens=max_new_tokens,
                    temperature=temperature,
                    top_p=top_p
                )

            st.session_state["conversation"].append(("assistant", answer))
            st.rerun()

    if st.button("Clear Conversation"):
        st.session_state["conversation"] = []
        st.rerun()


if __name__ == "__main__":
    main()
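

# To try the app locally (the script name is an assumption):
#   streamlit run app.py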