########################################
# app.py
########################################
import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
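# Assumed dependencies: streamlit, transformers, and torch must be installed;
# device_map="auto" below additionally relies on the accelerate package.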
# Cache the pipeline so each model is loaded from the Hub only once per session.
@st.cache_resource
def load_text_generation_pipeline(model_name: str):
    """
    Loads a text-generation pipeline from the Hugging Face Hub.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype="auto",   # or torch.float16 if a GPU is available
        device_map="auto"     # automatically map layers to available GPU(s)
    )
    text_generation = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer
    )
    return text_generation
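# Usage sketch (model id shown for illustration; any causal-LM id on the Hub works):
#   pipe = load_text_generation_pipeline("deepseek-ai/DeepSeek-R1")
#   print(pipe("Hello!", max_new_tokens=20)[0]["generated_text"])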
def generate_response(
    text_generation,
    system_prompt: str,
    conversation_history: list,
    user_query: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float
):
    """
    Generates a response from the language model given the system prompt,
    conversation history, and user query with the specified sampling parameters.
    """
    # Construct a prompt that includes the system role, the conversation history,
    # and the new user input. Adjust the format to match your model's expected
    # instruction format; here we use a simple turn-by-turn transcript.
    full_prompt = system_prompt.strip()
    for speaker, text in conversation_history:
        if speaker == "user":
            full_prompt += f"\nUser: {text}"
        else:
            full_prompt += f"\nAssistant: {text}"
    # Add the new user query
    full_prompt += f"\nUser: {user_query}\nAssistant:"
    # Use the pipeline to generate text
    outputs = text_generation(
        full_prompt,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=True
    )
    # The pipeline returns a list of generated sequences; take the text of the first one.
    generated_text = outputs[0]["generated_text"]
    # Extract just the new answer: since the prompt ends with "Assistant:",
    # the model's reply is everything after the last occurrence of that marker.
    answer = generated_text.split("Assistant:")[-1].strip()
    return answer
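# Note: chat-tuned models usually define their own prompt template; the
# tokenizer's apply_chat_template method (available in recent transformers
# versions) is a more robust alternative to the plain-text format above.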
def main():
    st.title("Streamlit Chatbot with Model Selection")
    st.markdown(
        """
        **System message**: You are a friendly Chatbot created by [ruslanmv.com](https://ruslanmv.com).
        Below you can select the model, adjust parameters, and begin chatting!
        """
    )
    # Sidebar for model selection and generation parameters
    st.sidebar.header("Select Model & Parameters")
    model_name = st.sidebar.selectbox(
        "Choose a model:",
        [
            "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B",
            "deepseek-ai/DeepSeek-R1",
            "deepseek-ai/DeepSeek-R1-Zero"
        ]
    )
    max_new_tokens = st.sidebar.slider(
        "Max new tokens",
        min_value=1,
        max_value=4000,
        value=1024,
        step=1
    )
    temperature = st.sidebar.slider(
        "Temperature",
        min_value=0.1,
        max_value=4.0,
        value=1.0,
        step=0.1
    )
    top_p = st.sidebar.slider(
        "Top-p (nucleus sampling)",
        min_value=0.1,
        max_value=1.0,
        value=0.9,
        step=0.05
    )
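    # Temperature rescales the token distribution (higher = more random), while
    # top-p samples from the smallest token set whose cumulative probability
    # exceeds the threshold; both take effect because generate_response sets
    # do_sample=True.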
    # The system "role" content
    system_message = (
        "You are a friendly Chatbot created by ruslanmv.com. "
        "You answer user questions in a concise and helpful way."
    )
    # Load the chosen model
    text_generation_pipeline = load_text_generation_pipeline(model_name)
    # Keep the conversation history in session_state so it survives reruns.
    if "conversation" not in st.session_state:
        st.session_state["conversation"] = []  # list of (speaker, text) tuples
    # Display the conversation so far.
    # Each element is ("user" or "assistant", message_text).
    for speaker, text in st.session_state["conversation"]:
        if speaker == "user":
            st.markdown(f"<div style='text-align:left; color:blue'><strong>User:</strong> {text}</div>", unsafe_allow_html=True)
        else:
            st.markdown(f"<div style='text-align:left; color:green'><strong>Assistant:</strong> {text}</div>", unsafe_allow_html=True)
    # User input text box
    user_input = st.text_input("Your message", "")
    # When the user hits "Send"
    if st.button("Send"):
        if user_input.strip():
            # 1) Add the user query to the conversation
            st.session_state["conversation"].append(("user", user_input.strip()))
            # 2) Generate a response. Pass the history *without* the turn we just
            #    appended, since generate_response adds user_query itself;
            #    otherwise the new query would appear twice in the prompt.
            with st.spinner("Thinking..."):
                answer = generate_response(
                    text_generation=text_generation_pipeline,
                    system_prompt=system_message,
                    conversation_history=st.session_state["conversation"][:-1],
                    user_query=user_input.strip(),
                    max_new_tokens=max_new_tokens,
                    temperature=temperature,
                    top_p=top_p
                )
            # 3) Add the assistant's answer to the conversation
            st.session_state["conversation"].append(("assistant", answer))
            # 4) Rerun to display the updated conversation
            st.rerun()
    # Optional: a button to clear the conversation
    if st.button("Clear Conversation"):
        st.session_state["conversation"] = []
        st.rerun()
if __name__ == "__main__":
    main()
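# To run the app locally (assuming the dependencies above are installed):
#   streamlit run app.py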