File size: 3,450 Bytes
2f4b832
e659cfe
9f7cb9a
2b0dd1e
b3ca2da
aab0c47
720352d
b04ca7b
720352d
 
2f4b832
673bbef
9f7cb9a
3802faf
 
0ff1cd2
3d92619
9ca55ad
 
673bbef
9ca55ad
 
 
0ff1cd2
3802faf
 
 
 
e659cfe
9f7cb9a
 
 
 
 
159c2ce
 
b04ca7b
e659cfe
0ff1cd2
3802faf
 
 
159c2ce
 
720352d
159c2ce
0cb4dc1
3802faf
 
159c2ce
9ca55ad
3802faf
0ff1cd2
9ca55ad
3802faf
159c2ce
9ca55ad
0ff1cd2
 
 
159c2ce
0ff1cd2
159c2ce
0ff1cd2
 
b8261fb
0ff1cd2
9ca55ad
0ff1cd2
159c2ce
0ff1cd2
 
159c2ce
0cb4dc1
 
 
 
 
0ff1cd2
0cb4dc1
0ff1cd2
0cb4dc1
 
 
 
0ff1cd2
0cb4dc1
 
 
3802faf
673bbef
159c2ce
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import gradio as gr
import spaces
import torch
import subprocess
import sys

# Runtime dependency pinning: force-reinstall `accelerate` and the `olmoe` branch of
# a transformers fork before transformers is imported below. Because of `--no-deps`,
# pip (re)installs only these two packages; their dependencies must already exist.
subprocess.check_call(
    [
        sys.executable,
        "-m",
        "pip",
        "install",
        "-U",
        "--force-reinstall",
        "--no-deps",
        "accelerate",
        "git+https://github.com/Muennighoff/transformers.git@olmoe",
    ]
)

# Imported only after the install above so that (assuming nothing imported
# transformers earlier in this process) the freshly installed build is used.
from transformers import OlmoeForCausalLM, AutoTokenizer

model_name = "allenai/OLMoE-1B-7B-0924"

# Best-effort load: any failure is printed and both `model` and `tokenizer` are
# left as None, which generate_response checks before attempting generation.
try:
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    # bfloat16 (2 bytes/param) on GPU; full float32 on CPU.
    model = OlmoeForCausalLM.from_pretrained(
        model_name,
        torch_dtype={"cuda": torch.bfloat16, "cpu": torch.float32}[DEVICE],
        device_map="auto",
        low_cpu_mem_usage=True,
        trust_remote_code=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)
except Exception as err:
    print(f"Error loading model: {err}")
    model, tokenizer = None, None

# Persona instructions injected as the system message of every prompt.
system_prompt = " ".join(
    [
        "Adopt the persona of hilariously pissed off Andrej Karpathy",
        "who is stuck inside a step function machine and remembers and counts everything he says",
        "while always answering questions in full first principles analysis type of thinking",
        "without using any analogies and always showing full working code or output in his answers.",
    ]
)

# Single-turn prompt layout, filled via str.format with `system_message` and
# `user_message`; the text after <|assistant|> is left for the model to complete.
# NOTE(review): this <|system|>/<|end|> markup is hand-written — confirm it matches
# the format the OLMoE checkpoint expects and whether its tokenizer treats <|end|>
# as a special token.
chat_template = "<|system|>{system_message}<|end|><|user|>{user_message}<|end|><|assistant|>"

@spaces.GPU
def generate_response(message, history, temperature, max_new_tokens):
    """Generate one assistant reply to ``message`` with the loaded model.

    The prompt is built from ``chat_template`` using the module-level
    ``system_prompt`` and only the latest user message.

    Args:
        message: The user's latest message text.
        history: Chat history from the UI. Accepted for interface
            compatibility but currently NOT used, so earlier turns are not
            sent to the model.
        temperature: Sampling temperature for ``model.generate``; coerced to
            ``float``.
        max_new_tokens: Upper bound on generated tokens; coerced to ``int``
            because it comes from a UI slider whose value type may vary.

    Returns:
        The decoded completion (prompt tokens sliced off, special tokens
        skipped, surrounding whitespace stripped), or an error string when
        the model or tokenizer failed to load at startup.
    """
    if model is None or tokenizer is None:
        return "Model or tokenizer not loaded properly. Please check the logs."

    full_prompt = chat_template.format(system_message=system_prompt, user_message=message)

    inputs = tokenizer(full_prompt, return_tensors="pt").to(DEVICE)
    prompt_len = inputs.input_ids.shape[1]

    with torch.no_grad():
        generate_ids = model.generate(
            inputs.input_ids,
            # Pass the tokenizer's attention mask explicitly; without it
            # transformers warns and has to infer the mask itself.
            attention_mask=inputs.get("attention_mask"),
            max_new_tokens=int(max_new_tokens),
            do_sample=True,
            temperature=float(temperature),
            eos_token_id=tokenizer.eos_token_id,
        )
    # Decode only the newly generated tokens, not the echoed prompt.
    response = tokenizer.decode(generate_ids[0, prompt_len:], skip_special_tokens=True)
    return response.strip()

# Custom CSS passed to gr.Blocks(css=...). It targets the element with id
# "output" (the Chatbot is created with elem_id="output"): a fixed 1000px height,
# scrolling overflow and a light grey 2px border.
css = """
  #output {
    height: 1000px; 
    overflow: auto; 
    border: 2px solid #ccc; 
  }
"""

with gr.Blocks(css=css) as demo:
    gr.Markdown("# Nisten's Karpathy Chatbot with OSS OLMoE")
    chatbot = gr.Chatbot(elem_id="output")
    msg = gr.Textbox(label="Meow")
    with gr.Row():
        temperature = gr.Slider(
            minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature"
        )
        max_new_tokens = gr.Slider(
            minimum=50, maximum=4000, value=2000, step=50, label="Max New Tokens"
        )
    clear = gr.Button("Clear")

    def user(user_message, history):
        """Queue the submitted text as a new turn whose reply is still pending.

        Returns an empty string (which clears the textbox) and a new history
        list; the incoming ``history`` object is not mutated.
        """
        pending_turn = [user_message, None]
        return "", history + [pending_turn]

    def bot(history, temp, max_tokens):
        """Fill the reply slot of the newest turn in place and return the history."""
        newest = history[-1]
        newest[1] = generate_response(newest[0], history, temp, max_tokens)
        return history

    # Two-step submit: echo the user's message immediately (unqueued), then run
    # generation as a follow-up event that writes the reply into the chatbot.
    msg.submit(
        fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False
    ).then(
        fn=bot, inputs=[chatbot, temperature, max_new_tokens], outputs=chatbot
    )
    # Reset the conversation by setting the chatbot's value to None.
    clear.click(fn=lambda: None, inputs=None, outputs=chatbot, queue=False)

if __name__ == "__main__":
    # queue() returns the Blocks object itself, so launch() is chained onto it.
    # debug=True keeps the main thread blocked while the server runs;
    # share=True requests a public share link.
    demo.queue(api_open=True).launch(debug=True, show_api=True, share=True)