import os
import uuid
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr
import modelscope_studio.components.antd as antd
import modelscope_studio.components.antdx as antdx
import modelscope_studio.components.base as ms
import modelscope_studio.components.pro as pro

# Define model paths
MODEL_PATHS = {
    "LeCarnet-3M": "MaxLSB/LeCarnet-3M",
    "LeCarnet-8M": "MaxLSB/LeCarnet-8M",
    "LeCarnet-21M": "MaxLSB/LeCarnet-21M",
}

# Set HF token
hf_token = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
if not hf_token:
    raise ValueError("HUGGINGFACEHUB_API_TOKEN environment variable not set.")

# Load tokenizer and model globally
tokenizer = None
model = None

def load_model(model_name: str):
    global tokenizer, model
    if model_name not in MODEL_PATHS:
        raise ValueError(f"Unknown model: {model_name}")
    print(f"Loading {model_name}...")
    repo = MODEL_PATHS[model_name]
    # `token` replaces the deprecated `use_auth_token` kwarg
    tokenizer = AutoTokenizer.from_pretrained(repo, token=hf_token)
    model = AutoModelForCausalLM.from_pretrained(repo, token=hf_token)
    model.eval()
    print(f"{model_name} loaded.")

def generate_response(prompt, max_new_tokens=200):
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
    # Decode only the newly generated tokens; slicing the decoded string by
    # len(prompt) can misalign when tokenization normalizes the prompt text.
    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
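
# Optional streaming variant: a minimal sketch using transformers'
# TextIteratorStreamer (the standard generate-in-a-thread pattern). It is
# not wired into the UI below; settings mirror generate_response above.
def generate_response_stream(prompt, max_new_tokens=200):
    from threading import Thread
    from transformers import TextIteratorStreamer
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(**inputs, max_new_tokens=max_new_tokens, streamer=streamer)
    Thread(target=model.generate, kwargs=generation_kwargs).start()
    # Yield progressively longer partial responses for a typing effect.
    partial = ""
    for chunk in streamer:
        partial += chunk
        yield partial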

# CSS for styling chatbot header with avatar
css = """
.chatbot-chat-messages .ant-pro-chat-message .ant-pro-chat-message-header {
  display: flex;
  align-items: center;
}
.chatbot-chat-messages .ant-pro-chat-message .ant-pro-chat-message-header img {
  width: 20px;
  height: 20px;
  margin-right: 8px;
  vertical-align: middle;
}
"""

# Default settings
DEFAULT_SETTINGS = {
    "model": "LeCarnet-3M",
    "sys_prompt": "",
}

# Initial state with one fixed conversation; the gr.State component itself is
# created inside the Blocks context below, where Gradio components belong.
INITIAL_STATE = {
    "conversation_id": "default",
    "conversation_contexts": {
        "default": {
            "history": [],
            "settings": DEFAULT_SETTINGS,
        }
    },
}

# Welcome message (optional)
def welcome_config():
    return {
        "title": "LeCarnet Chatbot",
        "description": "Start chatting below!",
        "promptSuggestions": ["Hello", "Tell me a story", "How are you?"]
    }

with gr.Blocks(css=css) as demo:
    state = gr.State(INITIAL_STATE)
    with ms.Application(), antd.Row(gutter=[20, 20], wrap=False, elem_id="chatbot"):
        # Right Column - Chat Interface
        with antd.Col(flex=1, elem_style=dict(height="100%")):
            with antd.Flex(vertical=True, gap="small", elem_classes="chatbot-chat"):
                chatbot = pro.Chatbot(
                    elem_classes="chatbot-chat-messages",
                    height=0,
                    welcome_config=welcome_config()
                )
                with antdx.Suggestion(items=["Hello", "How are you?", "Tell me something"]) as suggestion:
                    with ms.Slot("children"):
                        input = antdx.Sender(placeholder="Type your message here...")

        def add_message(user_input, state_value):
            history = state_value["conversation_contexts"]["default"]["history"]
            settings = state_value["conversation_contexts"]["default"]["settings"]
            selected_model = settings["model"]

            # Add user message
            history.append({"role": "user", "content": user_input, "key": str(uuid.uuid4())})
            yield {"chatbot": gr.update(value=history)}

            # Start assistant response
            history.append({
                "role": "assistant",
                "content": [],
                "key": str(uuid.uuid4()),
                "header": f'<img src="/file=media/le-carnet.png" style="width:20px;height:20px;margin-right:8px;"> <span>{selected_model}</span>',
                "loading": True
            })
            yield {"chatbot": gr.update(value=history)}

            try:
                # Generate model response
                prompt = "\n".join([msg["content"] for msg in history if msg["role"] == "user"])
                response = generate_response(prompt)

                # Update assistant message
                history[-1]["content"] = [{"type": "text", "content": response}]
                history[-1]["loading"] = False
                yield {"chatbot": gr.update(value=history)}
            except Exception as e:
                history[-1]["content"] = [{
                    "type": "text",
                    "content": f'<span style="color: red;">{str(e)}</span>'
                }]
                history[-1]["loading"] = False
                yield {"chatbot": gr.update(value=history)}

        input.submit(fn=add_message, inputs=[input, state], outputs=[chatbot])

# Load default model on startup
load_model(DEFAULT_SETTINGS["model"])

if __name__ == "__main__":
    demo.queue(default_concurrency_limit=10).launch()