Spaces:
Sleeping
Sleeping
import os | |
import uuid | |
import time | |
import json | |
from transformers import AutoTokenizer, AutoModelForCausalLM | |
import gradio as gr | |
import modelscope_studio.components.antd as antd | |
import modelscope_studio.components.antdx as antdx | |
import modelscope_studio.components.base as ms | |
import modelscope_studio.components.pro as pro | |
# Display name -> Hugging Face Hub repository id for each selectable LeCarnet
# checkpoint. Insertion order (3M, 8M, 21M) is preserved.
MODEL_PATHS = {
    f"LeCarnet-{size}": f"MaxLSB/LeCarnet-{size}"
    for size in ("3M", "8M", "21M")
}
# Hugging Face access token forwarded to from_pretrained() when downloading the
# checkpoints. HUGGINGFACEHUB_API_TOKEN is checked first (backward compatible);
# HF_TOKEN, the variable huggingface_hub reads natively, is accepted as a
# fallback. An empty value counts as unset. The app refuses to start without a
# token.
hf_token = os.environ.get("HUGGINGFACEHUB_API_TOKEN") or os.environ.get("HF_TOKEN")
if not hf_token:
    raise ValueError(
        "HUGGINGFACEHUB_API_TOKEN (or HF_TOKEN) environment variable not set."
    )
# Shared handles populated by load_model() and read by generate_response().
# Both remain None until a checkpoint has been loaded.
tokenizer = model = None
def load_model(model_name: str):
    """Load the tokenizer and causal-LM weights for ``model_name`` into the module globals.

    Args:
        model_name: A key of ``MODEL_PATHS`` (e.g. ``"LeCarnet-3M"``).

    Raises:
        ValueError: If ``model_name`` is not a key of ``MODEL_PATHS``.
        Any exception raised by ``from_pretrained`` (network, auth, missing
        files) propagates unchanged. In that case the previously loaded
        ``tokenizer``/``model`` globals are left untouched.
    """
    global tokenizer, model
    if model_name not in MODEL_PATHS:
        raise ValueError(f"Unknown model: {model_name}")
    print(f"Loading {model_name}...")
    repo = MODEL_PATHS[model_name]
    # `token=` replaces the deprecated `use_auth_token=` keyword in transformers.
    new_tokenizer = AutoTokenizer.from_pretrained(repo, token=hf_token)
    new_model = AutoModelForCausalLM.from_pretrained(repo, token=hf_token)
    new_model.eval()
    # Publish both only after both loads succeeded, so the globals never pair a
    # tokenizer from one checkpoint with a model from another.
    tokenizer, model = new_tokenizer, new_model
    print(f"{model_name} loaded.")
def generate_response(prompt, max_new_tokens=200):
    """Generate a continuation of ``prompt`` with the currently loaded model.

    Args:
        prompt: Text fed to the tokenizer as the generation context.
        max_new_tokens: Upper bound on the number of tokens generated.

    Returns:
        Only the newly generated text (the prompt is not echoed back), decoded
        with special tokens skipped and surrounding whitespace stripped.

    Raises:
        RuntimeError: If ``load_model`` has not populated ``tokenizer``/``model``.
    """
    if tokenizer is None or model is None:
        raise RuntimeError("No model loaded; call load_model() first.")
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_length = inputs["input_ids"].shape[-1]
    outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
    # Drop the prompt by token position rather than by character count of the
    # decoded text: decoding need not reproduce the prompt exactly (whitespace
    # normalisation, special tokens), which made `text[len(prompt):]` cut into
    # or past the real reply.
    new_tokens = outputs[0][prompt_length:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
# Stylesheet handed to gr.Blocks: renders chat message headers as a centred
# flex row and sizes any <img> inside them (the model avatar that add_message
# puts in each assistant header) to 20x20 px with a small right gap.
css = """
.chatbot-chat-messages .ant-pro-chat-message .ant-pro-chat-message-header {
    display: flex;
    align-items: center;
}
.chatbot-chat-messages .ant-pro-chat-message .ant-pro-chat-message-header img {
    width: 20px;
    height: 20px;
    margin-right: 8px;
    vertical-align: middle;
}
"""
# Settings attached to a fresh conversation: the checkpoint that answers
# (a key of MODEL_PATHS) and an empty `sys_prompt` placeholder.
DEFAULT_SETTINGS = dict(
    model="LeCarnet-3M",
    sys_prompt="",
)
# Session chat state: one fixed conversation whose context holds the message
# history (initially empty) and the settings it runs with.
_DEFAULT_CONVERSATION_ID = "default"
state = gr.State(
    {
        "conversation_id": _DEFAULT_CONVERSATION_ID,
        "conversation_contexts": {
            _DEFAULT_CONVERSATION_ID: {
                "history": [],
                "settings": DEFAULT_SETTINGS,
            },
        },
    }
)
def welcome_config():
    """Build the ``welcome_config`` dict handed to ``pro.Chatbot``.

    Returns:
        A new dict with a title, a one-line description and a list of three
        prompt suggestions.
    """
    suggestions = ["Hello", "Tell me a story", "How are you?"]
    return dict(
        title="LeCarnet Chatbot",
        description="Start chatting below!",
        promptSuggestions=suggestions,
    )
with gr.Blocks(css=css) as demo:
    # `state` is constructed at module level, before this Blocks context opens,
    # so it is not attached to the app yet. Components created outside a Blocks
    # must be rendered into it before they can serve as event inputs/outputs.
    state.render()
    with ms.Application(), antd.Row(gutter=[20, 20], wrap=False, elem_id="chatbot"):
        # Chat interface column
        with antd.Col(flex=1, elem_style=dict(height="100%")):
            with antd.Flex(vertical=True, gap="small", elem_classes="chatbot-chat"):
                chatbot = pro.Chatbot(
                    elem_classes="chatbot-chat-messages",
                    height=0,
                    welcome_config=welcome_config(),
                )
                with antdx.Suggestion(
                    items=["Hello", "How are you?", "Tell me something"]
                ) as suggestion:
                    with ms.Slot("children"):
                        # Named `sender` (not `input`) so the builtin is not shadowed.
                        sender = antdx.Sender(placeholder="Type your message here...")

    def add_message(user_input, state_value):
        """Record the user's message, then show the assistant's reply in the chatbot.

        Args:
            user_input: Text submitted through ``sender``.
            state_value: This session's value of ``state``. The history list of
                its "default" conversation is mutated in place.

        Yields:
            Dicts keyed by output *component objects* (``chatbot``, ``state``).
            Gradio routes partial updates by component, not by variable name
            strings.
        """
        # Ignore blank submissions instead of recording an empty user turn.
        if not user_input or not user_input.strip():
            yield {chatbot: gr.update()}
            return

        conversation = state_value["conversation_contexts"]["default"]
        history = conversation["history"]
        selected_model = conversation["settings"]["model"]

        # Show the user's message immediately.
        history.append({"role": "user", "content": user_input, "key": str(uuid.uuid4())})
        yield {chatbot: gr.update(value=history)}

        # Placeholder assistant message, displayed in its loading state.
        history.append({
            "role": "assistant",
            "content": [],
            "key": str(uuid.uuid4()),
            "header": f'<img src="/file=media/le-carnet.png" style="width:20px;height:20px;margin-right:8px;"> <span>{selected_model}</span>',
            "loading": True,
        })
        yield {chatbot: gr.update(value=history)}

        try:
            # The prompt is every user turn so far, newline-joined; assistant
            # turns are not included.
            prompt = "\n".join(msg["content"] for msg in history if msg["role"] == "user")
            content = generate_response(prompt)
        except Exception as e:
            # UI boundary: surface the failure in the chat instead of failing the event.
            content = f'<span style="color: red;">{str(e)}</span>'

        history[-1]["content"] = [{"type": "text", "content": content}]
        history[-1]["loading"] = False
        yield {chatbot: gr.update(value=history), state: state_value}

    sender.submit(fn=add_message, inputs=[sender, state], outputs=[chatbot, state])
# Load the default checkpoint once, at import time, so the globals used by
# generate_response() are populated before the UI can be launched below.
load_model(model_name=DEFAULT_SETTINGS["model"])
if __name__ == "__main__":
    # Enable Gradio's request queue with a default concurrency limit of 10,
    # then start the server.
    queued_demo = demo.queue(default_concurrency_limit=10)
    queued_demo.launch()