import gradio as gr
import spaces
import torch
import subprocess
import sys
import os
from threading import Thread

# Install required packages
subprocess.check_call([sys.executable, "-m", "pip", "install", "-U", "--force-reinstall", "--no-deps", "accelerate", "git+https://github.com/Muennighoff/transformers.git@olmoe"])
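# FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE skips compiling the CUDA extension at
# install time so the prebuilt flash-attn wheel is used (a common Spaces trick)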
subprocess.run('pip install flash-attn --no-build-isolation --no-deps',
               env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)

from transformers import OlmoeForCausalLM, AutoTokenizer, TextIteratorStreamer

model_name = "allenai/OLMoE-1B-7B-0924"

# Wrap model loading in a try-except block to handle potential errors
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

try:
    model = OlmoeForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16 if DEVICE == "cuda" else torch.float32,
        low_cpu_mem_usage=True,
        device_map="auto",
        # Flash Attention 2 kernels only run on CUDA; fall back to eager on CPU
        attn_implementation="flash_attention_2" if DEVICE == "cuda" else "eager",
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)
except Exception as e:
    print(f"Error loading model: {e}")
    model = None
    tokenizer = None

system_prompt = ("Adopt the persona of hilariously pissed off Andrej Karpathy "
                 "who is stuck inside a step function machine and remembers and counts everything he says "
                 "while always answering questions in full first principles analysis type of thinking "
                 "without using any analogies and always showing full working code or output in his answers.")

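# Hand-rolled prompt format used by this demo. allenai/OLMoE-1B-7B-0924 is the
# base (non-instruct) model, so these tags are a convention of this app rather
# than a template the model was trained on.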
chat_template = "<|system|>{system_message}<|end|><|user|>{user_message}<|end|><|assistant|>"

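# On Hugging Face ZeroGPU Spaces, @spaces.GPU attaches a GPU to this function
# for the duration of each call.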
@spaces.GPU
def generate_response(message, history, temperature, max_new_tokens):
    if model is None or tokenizer is None:
        yield "Model or tokenizer not loaded properly. Please check the logs."
        return

    full_prompt = chat_template.format(system_message=system_prompt, user_message=message)
    inputs = tokenizer(full_prompt, return_tensors="pt").to(DEVICE)
    
    # Stream the reply: model.generate runs in a background thread while
    # TextIteratorStreamer yields decoded text chunks as they are produced.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(**inputs, streamer=streamer, do_sample=True,
                             temperature=temperature, max_new_tokens=max_new_tokens)
    Thread(target=model.generate, kwargs=generation_kwargs).start()

    partial_text = ""
    for new_text in streamer:
        partial_text += new_text
        yield partial_text.strip()

css = """
  #output {
    height: 1000px; 
    overflow: auto; 
    border: 2px solid #ccc; 
  }
"""

with gr.Blocks(css=css) as demo:
    gr.Markdown("# Nisten's Karpathy Chatbot with OSS OLMoE (Now with Flash Attention 2!)")
    chatbot = gr.Chatbot(elem_id="output")
    msg = gr.Textbox(label="Meow")
    with gr.Row():
        temperature = gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature")
        max_new_tokens = gr.Slider(minimum=50, maximum=4000, value=2000, step=50, label="Max New Tokens")
    clear = gr.Button("Clear")

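    # Two-step handler: `user` echoes the message into the chat history right
    # away, then `bot` streams the model's reply into the last history slot.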
    def user(user_message, history):
        return "", history + [[user_message, None]]

    def bot(history, temp, max_tokens):
        user_message = history[-1][0]
        bot_message = ""
        for token in generate_response(user_message, history, temp, max_tokens):
            bot_message = token
            history[-1][1] = bot_message
            yield history

    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot, [chatbot, temperature, max_new_tokens], chatbot
    )
    clear.click(lambda: None, None, chatbot, queue=False)

if __name__ == "__main__":
    demo.queue(api_open=True)
    demo.launch(debug=True, show_api=True)