import os
import uuid
import time
import json

import gradio as gr
import modelscope_studio.components.antd as antd
import modelscope_studio.components.antdx as antdx
import modelscope_studio.components.base as ms
import modelscope_studio.components.pro as pro
from transformers import AutoTokenizer, AutoModelForCausalLM

# Hugging Face Hub repo ids for the available checkpoints.
MODEL_PATHS = {
    "LeCarnet-3M": "MaxLSB/LeCarnet-3M",
    "LeCarnet-8M": "MaxLSB/LeCarnet-8M",
    "LeCarnet-21M": "MaxLSB/LeCarnet-21M",
}

hf_token = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
if not hf_token:
    raise ValueError("HUGGINGFACEHUB_API_TOKEN environment variable not set.")

tokenizer = None
model = None

def load_model(model_name: str):
    """Load the selected checkpoint into the module-level tokenizer/model."""
    global tokenizer, model
    if model_name not in MODEL_PATHS:
        raise ValueError(f"Unknown model: {model_name}")
    print(f"Loading {model_name}...")
    repo = MODEL_PATHS[model_name]
    # `token=` supersedes the deprecated `use_auth_token=` argument.
    tokenizer = AutoTokenizer.from_pretrained(repo, token=hf_token)
    model = AutoModelForCausalLM.from_pretrained(repo, token=hf_token)
    model.eval()
    print(f"{model_name} loaded.")

def generate_response(prompt, max_new_tokens=200):
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Drop the echoed prompt so only the continuation is returned.
    return response[len(prompt):].strip()
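
# Quick local smoke test, kept commented out (a sketch, not part of the app):
# it assumes the 3M checkpoint downloads successfully with the token above.
# load_model("LeCarnet-3M")
# print(generate_response("Hello", max_new_tokens=50))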

DEFAULT_SETTINGS = {
    "model": "LeCarnet-3M",
    "sys_prompt": "",
}

# Initial state with one fixed conversation. gr.State must be created inside
# the Blocks context, so only the plain dict is defined at module level.
INITIAL_STATE = {
    "conversation_id": "default",
    "conversation_contexts": {
        "default": {
            "history": [],
            "settings": DEFAULT_SETTINGS,
        }
    },
}

# Minimal placeholder stylesheet; the original CSS is not shown in this snippet.
css = """
#chatbot { height: 100%; }
.chatbot-chat { height: 100%; }
"""

with gr.Blocks(css=css) as demo:
    with ms.Application(), antd.Row(gutter=[20, 20], wrap=False, elem_id="chatbot"):
        # Right column: chat interface
        with antd.Col(flex=1, elem_style=dict(height="100%")):
            with antd.Flex(vertical=True, gap="small", elem_classes="chatbot-chat"):
                chatbot = pro.Chatbot(elem_classes="chatbot-chat-messages", height=0)
                with antdx.Suggestion(items=["Hello", "How are you?", "Tell me something"]) as suggestion:
                    with ms.Slot("children"):
                        # Named `sender` rather than `input` to avoid shadowing the builtin.
                        sender = antdx.Sender(placeholder="Type your message here...")

    # Internal state: gr.State must live inside the Blocks context.
    state = gr.State(INITIAL_STATE)

    def add_message(user_input, state_value):
        history = state_value["conversation_contexts"]["default"]["history"]
        settings = state_value["conversation_contexts"]["default"]["settings"]
        selected_model = settings["model"]

        # Append the user message and refresh the chat display.
        history.append({"role": "user", "content": user_input, "key": str(uuid.uuid4())})
        yield gr.update(value=history)

        # Append a loading placeholder for the assistant response.
        history.append({"role": "assistant", "content": [], "key": str(uuid.uuid4()), "loading": True})
        yield gr.update(value=history)

        try:
            # Concatenate all user turns into a single prompt and generate.
            prompt = "\n".join(msg["content"] for msg in history if msg["role"] == "user")
            response = generate_response(prompt)

            # Replace the placeholder with the generated text.
            history[-1]["content"] = [{"type": "text", "content": response}]
            history[-1]["loading"] = False
            yield gr.update(value=history)
        except Exception as e:
            # Surface the error inline in the chat.
            history[-1]["content"] = [{
                "type": "text",
                "content": f'<span style="color: red;">{e}</span>',
            }]
            history[-1]["loading"] = False
            yield gr.update(value=history)

    sender.submit(fn=add_message, inputs=[sender, state], outputs=[chatbot])
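
    # DEFAULT_SETTINGS["model"] implies a model switcher, but no control for it
    # is wired up in this snippet. A minimal sketch of one, kept commented out
    # (the component name and `.change` wiring are assumptions, not part of the
    # original UI):
    #
    #     model_select = antd.Select(
    #         options=[{"label": name, "value": name} for name in MODEL_PATHS],
    #         value=DEFAULT_SETTINGS["model"],
    #     )
    #     model_select.change(fn=load_model, inputs=[model_select])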

# Load the default model once at startup.
load_model(DEFAULT_SETTINGS["model"])

if __name__ == "__main__":
    demo.queue(default_concurrency_limit=10).launch()
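
# To run locally (assuming gradio, transformers, and modelscope_studio are
# installed, and this file is saved as app.py; the token value is a placeholder):
#   HUGGINGFACEHUB_API_TOKEN=hf_xxx python app.py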