# LeCarnet-Demo / app.py
import os
import uuid

import gradio as gr
import modelscope_studio.components.antd as antd
import modelscope_studio.components.antdx as antdx
import modelscope_studio.components.base as ms
import modelscope_studio.components.pro as pro
from transformers import AutoTokenizer, AutoModelForCausalLM
# Define model paths
MODEL_PATHS = {
    "LeCarnet-3M": "MaxLSB/LeCarnet-3M",
    "LeCarnet-8M": "MaxLSB/LeCarnet-8M",
    "LeCarnet-21M": "MaxLSB/LeCarnet-21M",
}
# Set HF token
hf_token = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
if not hf_token:
    raise ValueError("HUGGINGFACEHUB_API_TOKEN environment variable not set.")
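# The token is forwarded to from_pretrained() below. Public checkpoints load
# without one, so this strict check is presumably here to make a misconfigured
# Space fail fast rather than partway through model loading.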
# Load tokenizer and model globally
tokenizer = None
model = None
def load_model(model_name: str):
    """Load the requested checkpoint into the module-level tokenizer/model."""
    global tokenizer, model
    if model_name not in MODEL_PATHS:
        raise ValueError(f"Unknown model: {model_name}")
    print(f"Loading {model_name}...")
    repo = MODEL_PATHS[model_name]
    # `token=` replaces the deprecated `use_auth_token=` argument in recent transformers.
    tokenizer = AutoTokenizer.from_pretrained(repo, token=hf_token)
    model = AutoModelForCausalLM.from_pretrained(repo, token=hf_token)
    model.eval()
    print(f"{model_name} loaded.")
def generate_response(prompt, max_new_tokens=200):
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
    # Decode only the newly generated tokens; slicing the decoded string by
    # len(prompt) breaks whenever detokenization does not round-trip exactly.
    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
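# Optional: greedy decoding (the default above) can get repetitive on very
# small models. A sketch of a sampling variant, using standard transformers
# generate kwargs; the values are illustrative, not tuned for LeCarnet:
#
#     outputs = model.generate(
#         **inputs,
#         max_new_tokens=max_new_tokens,
#         do_sample=True,
#         temperature=0.8,
#         top_p=0.95,
#     )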
# CSS for styling chatbot header with avatar
css = """
.chatbot-chat-messages .ant-pro-chat-message .ant-pro-chat-message-header {
    display: flex;
    align-items: center;
}
.chatbot-chat-messages .ant-pro-chat-message .ant-pro-chat-message-header img {
    width: 20px;
    height: 20px;
    margin-right: 8px;
    vertical-align: middle;
}
"""
# Default settings
DEFAULT_SETTINGS = {
    "model": "LeCarnet-3M",
    "sys_prompt": "",
}
# Initial state with one fixed conversation. The gr.State component itself is
# created inside the Blocks context below, where Gradio expects components.
DEFAULT_STATE = {
    "conversation_id": "default",
    "conversation_contexts": {
        "default": {
            "history": [],
            "settings": DEFAULT_SETTINGS,
        }
    },
}
# Welcome message (optional)
def welcome_config():
    return {
        "title": "LeCarnet Chatbot",
        "description": "Start chatting below!",
        "promptSuggestions": ["Hello", "Tell me a story", "How are you?"],
    }
with gr.Blocks(css=css) as demo:
    with ms.Application(), antd.Row(gutter=[20, 20], wrap=False, elem_id="chatbot"):
        # Chat interface column
        with antd.Col(flex=1, elem_style=dict(height="100%")):
            with antd.Flex(vertical=True, gap="small", elem_classes="chatbot-chat"):
                chatbot = pro.Chatbot(
                    elem_classes="chatbot-chat-messages",
                    height=0,
                    welcome_config=welcome_config(),
                )
                with antdx.Suggestion(items=["Hello", "How are you?", "Tell me something"]) as suggestion:
                    with ms.Slot("children"):
                        # `sender` avoids shadowing the built-in input().
                        sender = antdx.Sender(placeholder="Type your message here...")

    # Per-session state, created inside the Blocks context so Gradio tracks it.
    state = gr.State(DEFAULT_STATE)
    def add_message(user_input, state_value):
        history = state_value["conversation_contexts"]["default"]["history"]
        settings = state_value["conversation_contexts"]["default"]["settings"]
        selected_model = settings["model"]

        # Add the user message and refresh the chat window. With a single
        # output component, yield the update directly; the original string-keyed
        # dict ({"chatbot": ...}) is not a form Gradio can map to a component.
        history.append({"role": "user", "content": user_input, "key": str(uuid.uuid4())})
        yield gr.update(value=history)

        # Start the assistant response in a loading state.
        history.append({
            "role": "assistant",
            "content": [],
            "key": str(uuid.uuid4()),
            "header": f'<img src="/file=media/le-carnet.png" style="width:20px;height:20px;margin-right:8px;"> <span>{selected_model}</span>',
            "loading": True,
        })
        yield gr.update(value=history)

        try:
            # Build the prompt from the user turns and generate a completion.
            prompt = "\n".join(msg["content"] for msg in history if msg["role"] == "user")
            response = generate_response(prompt)
            # Fill in the assistant message.
            history[-1]["content"] = [{"type": "text", "content": response}]
            history[-1]["loading"] = False
            yield gr.update(value=history)
        except Exception as e:
            history[-1]["content"] = [{
                "type": "text",
                "content": f'<span style="color: red;">{str(e)}</span>',
            }]
            history[-1]["loading"] = False
            yield gr.update(value=history)
    sender.submit(fn=add_message, inputs=[sender, state], outputs=[chatbot])
# Load default model on startup
load_model(DEFAULT_SETTINGS["model"])
if __name__ == "__main__":
    demo.queue(default_concurrency_limit=10).launch()
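# To run locally (assumes the media/le-carnet.png avatar referenced above
# exists alongside this file):
#
#     export HUGGINGFACEHUB_API_TOKEN=hf_...
#     python app.py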