# LeCarnet-Demo / app.py
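"""Gradio demo for the LeCarnet models.

Loads a MaxLSB/LeCarnet checkpoint from the Hugging Face Hub and serves a
minimal chat UI built with modelscope_studio components.
"""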
import os
import uuid

import gradio as gr
import modelscope_studio.components.antd as antd
import modelscope_studio.components.antdx as antdx
import modelscope_studio.components.base as ms
import modelscope_studio.components.pro as pro
from transformers import AutoTokenizer, AutoModelForCausalLM

MODEL_PATHS = {
    "LeCarnet-3M": "MaxLSB/LeCarnet-3M",
    "LeCarnet-8M": "MaxLSB/LeCarnet-8M",
    "LeCarnet-21M": "MaxLSB/LeCarnet-21M",
}

hf_token = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
if not hf_token:
    raise ValueError("HUGGINGFACEHUB_API_TOKEN environment variable not set.")

tokenizer = None
model = None

def load_model(model_name: str):
    """Load the requested checkpoint into the module-level tokenizer/model."""
    global tokenizer, model
    if model_name not in MODEL_PATHS:
        raise ValueError(f"Unknown model: {model_name}")
    print(f"Loading {model_name}...")
    repo = MODEL_PATHS[model_name]
    # `use_auth_token` is deprecated in recent transformers releases; use `token`.
    tokenizer = AutoTokenizer.from_pretrained(repo, token=hf_token)
    model = AutoModelForCausalLM.from_pretrained(repo, token=hf_token)
    model.eval()
    print(f"{model_name} loaded.")

def generate_response(prompt, max_new_tokens=200):
    """Generate a completion for `prompt` and return only the newly generated text."""
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return response[len(prompt):].strip()
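
# Hypothetical standalone usage sketch (the prompt text is illustrative only):
#   load_model("LeCarnet-3M")
#   print(generate_response("Il était une fois"))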

DEFAULT_SETTINGS = {
    "model": "LeCarnet-3M",
    "sys_prompt": "",
}

# Initial state with one fixed conversation
initial_state = {
    "conversation_id": "default",
    "conversation_contexts": {
        "default": {
            "history": [],
            "settings": DEFAULT_SETTINGS,
        }
    },
}

# Minimal placeholder CSS (`css` was referenced but never defined); it only
# sizes the chat column, so adjust as needed.
css = """
#chatbot {
    height: calc(100vh - 32px);
}
.chatbot-chat {
    height: 100%;
}
.chatbot-chat-messages {
    flex: 1;
    overflow-y: auto;
}
"""

with gr.Blocks(css=css) as demo:
    # gr.State must be created inside the Blocks context to be usable in events.
    state = gr.State(initial_state)
    with ms.Application(), antd.Row(gutter=[20, 20], wrap=False, elem_id="chatbot"):
        # Chat interface column
        with antd.Col(flex=1, elem_style=dict(height="100%")):
            with antd.Flex(vertical=True, gap="small", elem_classes="chatbot-chat"):
                chatbot = pro.Chatbot(elem_classes="chatbot-chat-messages", height=0)
                with antdx.Suggestion(items=["Hello", "How are you?", "Tell me something"]) as suggestion:
                    with ms.Slot("children"):
                        input = antdx.Sender(placeholder="Type your message here...")

    def add_message(user_input, state_value):
        history = state_value["conversation_contexts"]["default"]["history"]
        settings = state_value["conversation_contexts"]["default"]["settings"]
        selected_model = settings["model"]
        # Add the user message and push an intermediate update to the client.
        history.append({"role": "user", "content": user_input, "key": str(uuid.uuid4())})
        yield gr.update(value=history)
        # Start the assistant response in a loading state.
        history.append({"role": "assistant", "content": [], "key": str(uuid.uuid4()), "loading": True})
        yield gr.update(value=history)
        try:
            # Build the prompt from the concatenated user turns and generate.
            prompt = "\n".join([msg["content"] for msg in history if msg["role"] == "user"])
            response = generate_response(prompt)
            # Fill in the assistant message.
            history[-1]["content"] = [{"type": "text", "content": response}]
            history[-1]["loading"] = False
            yield gr.update(value=history)
        except Exception as e:
            history[-1]["content"] = [{
                "type": "text",
                "content": f'<span style="color: red;">{e}</span>'
            }]
            history[-1]["loading"] = False
            yield gr.update(value=history)

    input.submit(fn=add_message, inputs=[input, state], outputs=[chatbot])

# Load default model on startup
load_model(DEFAULT_SETTINGS["model"])

if __name__ == "__main__":
    demo.queue(default_concurrency_limit=10).launch()
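
# To run locally (assumed dependencies, versions not pinned here):
#   pip install torch transformers gradio modelscope_studio
#   HUGGINGFACEHUB_API_TOKEN=... python app.py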