import os
import uuid

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr
import modelscope_studio.components.antd as antd
import modelscope_studio.components.antdx as antdx
import modelscope_studio.components.base as ms
import modelscope_studio.components.pro as pro

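# Checkpoints on the Hugging Face Hub, keyed by display name; the suffix in
# each name denotes the model size.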
MODEL_PATHS = {
    "LeCarnet-3M": "MaxLSB/LeCarnet-3M",
    "LeCarnet-8M": "MaxLSB/LeCarnet-8M",
    "LeCarnet-21M": "MaxLSB/LeCarnet-21M",
}

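# The Hub token is read from the environment at startup, e.g.
# (placeholder value): export HUGGINGFACEHUB_API_TOKEN=hf_xxxxxxxx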
hf_token = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
if not hf_token:
    raise ValueError("HUGGINGFACEHUB_API_TOKEN environment variable not set.")

tokenizer = None
model = None

def load_model(model_name: str):
    """Load the tokenizer and model for the given display name into the globals."""
    global tokenizer, model
    if model_name not in MODEL_PATHS:
        raise ValueError(f"Unknown model: {model_name}")
    print(f"Loading {model_name}...")
    repo = MODEL_PATHS[model_name]
    tokenizer = AutoTokenizer.from_pretrained(repo, token=hf_token)
    model = AutoModelForCausalLM.from_pretrained(repo, token=hf_token)
    model.eval()
    print(f"{model_name} loaded.")

def generate_response(prompt, max_new_tokens=200):
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
    # generate() returns the prompt followed by the continuation; return only
    # the newly generated text.
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return response[len(prompt):].strip()
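
# A possible streaming variant (a sketch, not wired into the app below):
# transformers' TextIteratorStreamer yields decoded text as tokens are
# produced, so the chat handler could stream partial replies instead of
# waiting for the full completion. The name `generate_response_stream` is ours.
from threading import Thread
from transformers import TextIteratorStreamer

def generate_response_stream(prompt, max_new_tokens=200):
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    # generate() blocks, so run it in a background thread and consume the
    # streamer here as new text arrives.
    Thread(
        target=model.generate,
        kwargs=dict(**inputs, streamer=streamer, max_new_tokens=max_new_tokens),
    ).start()
    for new_text in streamer:
        yield new_text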

DEFAULT_SETTINGS = {
    "model": "LeCarnet-3M",
    "sys_prompt": "",  # reserved for a system prompt; not used by the handler yet
}

# Initial app state with one fixed conversation. The gr.State component itself
# is created inside the Blocks context below so it is registered with the app.
initial_state = {
    "conversation_id": "default",
    "conversation_contexts": {
        "default": {
            "history": [],
            "settings": DEFAULT_SETTINGS,
        }
    },
}

with gr.Blocks() as demo:
    state = gr.State(initial_state)
    with ms.Application(), antd.Row(gutter=[20, 20], wrap=False, elem_id="chatbot"):
        # Chat interface column
        with antd.Col(flex=1, elem_style=dict(height="100%")):
            with antd.Flex(vertical=True, gap="small", elem_classes="chatbot-chat"):
                chatbot = pro.Chatbot(elem_classes="chatbot-chat-messages", height=0)
                with antdx.Suggestion(items=["Hello", "How are you?", "Tell me something"]) as suggestion:
                    with ms.Slot("children"):
                        sender = antdx.Sender(placeholder="Type your message here...")

        # Event handler: append the user's message, then fill in the
        # assistant's reply (a generator, so the UI updates incrementally)

        def add_message(user_input, state_value):
            history = state_value["conversation_contexts"]["default"]["history"]

            # Append the user message and refresh the chat
            history.append({"role": "user", "content": user_input, "key": str(uuid.uuid4())})
            yield gr.update(value=history)

            # Append a placeholder assistant message shown in the loading state
            history.append({"role": "assistant", "content": [], "key": str(uuid.uuid4()), "loading": True})
            yield gr.update(value=history)

            try:
                # Build the prompt from all user turns so far and generate
                prompt = "\n".join(msg["content"] for msg in history if msg["role"] == "user")
                response = generate_response(prompt)

                # Replace the loading placeholder with the generated text
                history[-1]["content"] = [{"type": "text", "content": response}]
                history[-1]["loading"] = False
                yield gr.update(value=history)
            except Exception as e:
                # Surface the error in the chat instead of failing silently
                history[-1]["content"] = [{
                    "type": "text",
                    "content": f'<span style="color: red;">{e}</span>'
                }]
                history[-1]["loading"] = False
                yield gr.update(value=history)

        sender.submit(fn=add_message, inputs=[sender, state], outputs=[chatbot])

# Load default model on startup
load_model(DEFAULT_SETTINGS["model"])
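
# Note: tokenizer/model are module-level globals shared by all requests. With
# the concurrency limit below, simultaneous generate() calls share one model
# instance; gradient-free forward passes are generally safe to run
# concurrently in PyTorch, but swapping models via load_model at runtime
# would need a lock to avoid races.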

if __name__ == "__main__":
    demo.queue(default_concurrency_limit=10).launch()