File size: 4,303 Bytes
2f4b832
e659cfe
9f7cb9a
2b0dd1e
b3ca2da
aab0c47
e203e91
a622fef
720352d
a622fef
 
2f4b832
32720ee
9f7cb9a
3802faf
 
0ff1cd2
3d92619
9ca55ad
 
c720fed
9ca55ad
e203e91
32720ee
c720fed
0ff1cd2
3802faf
 
 
 
e659cfe
9f7cb9a
 
 
 
 
e659cfe
24b2580
3802faf
24b2580
e203e91
 
32720ee
c720fed
 
24b2580
32720ee
 
0cb4dc1
e9acdad
a622fef
deeaafe
 
 
 
 
 
5598c41
deeaafe
 
a622fef
deeaafe
 
5598c41
deeaafe
5598c41
 
 
 
 
 
 
 
 
e9acdad
5598c41
0ff1cd2
 
 
023bf24
0ff1cd2
023bf24
0ff1cd2
 
b8261fb
0ff1cd2
c720fed
0ff1cd2
159c2ce
0ff1cd2
 
5598c41
0cb4dc1
 
24b2580
c720fed
24b2580
 
c720fed
24b2580
 
 
c720fed
24b2580
0cb4dc1
 
0ff1cd2
0cb4dc1
e9acdad
0cb4dc1
3802faf
24b2580
c720fed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import gradio as gr
import spaces
import torch
import subprocess
import sys

# Install/refresh the runtime packages at startup, before the transformers
# import below. transformers comes from a git branch named "olmoe", which
# presumably provides the OlmoeForCausalLM class imported right after this.
# NOTE(review): torch is already imported at the top of the file, so the torch
# reinstalled here cannot replace the build this process is using — confirm
# that forcing its reinstall on every boot is still intended.
_pip_install_cmd = [sys.executable, "-m", "pip", "install", "-U", "--force-reinstall", "--no-deps"]
_pip_install_cmd.extend(
    ["einops", "accelerate", "torch", "git+https://github.com/Muennighoff/transformers.git@olmoe"]
)
subprocess.check_call(_pip_install_cmd)

from transformers import OlmoeForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread

model_name = "allenai/OLMoE-1B-7B-0924-Instruct"

# Device that both the model weights and the prompt tensors are moved to.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Load the model and tokenizer once at import time. On any failure both are set
# to None so the UI can still start; generate_response() checks for None and
# reports the problem to the user instead of crashing.
try:
    # Placement is done with a single explicit .to(DEVICE). Combining
    # device_map="auto" (accelerate dispatch) with .to() is contradictory:
    # transformers warns about, or refuses, moving a dispatched model.
    # Gradient checkpointing is not enabled: it only trades compute for memory
    # during training and has no effect on an inference-only model.
    model = OlmoeForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
    ).to(DEVICE)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
except Exception as e:
    print(f"Error loading model: {e}")
    model = None
    tokenizer = None

# Persona instruction sent as the "system" message at the start of every
# conversation. The clauses are joined with single spaces into one string.
system_prompt = " ".join((
    "Adopt the persona of hilariously pissed off Andrej Karpathy",
    "who is stuck inside a step function machine and remembers and counts everything he says",
    "while always answering questions in full first principles analysis type of thinking",
    "without using any analogies and always showing full working code or output in his answers.",
))

def _build_messages(message, history):
    """Assemble the chat-template message list for one generation request.

    The list starts with the persona ``system_prompt``, replays ``history`` (a
    list of ``{"role": ..., "content": ...}`` dicts), then appends the new user
    ``message``. History turns whose role is ``"human"`` or ``"user"`` are sent
    as user turns; every other role is sent as an assistant turn.
    """
    messages = [{"role": "system", "content": system_prompt}]
    for turn in history:
        role = "user" if turn["role"] in ("human", "user") else "assistant"
        messages.append({"role": role, "content": turn["content"]})
    messages.append({"role": "user", "content": message})
    return messages


@spaces.GPU
def generate_response(message, history, temperature, max_new_tokens):
    """Stream the model's reply to ``message`` given the prior ``history``.

    Args:
        message: The new user message text.
        history: Prior turns as ``{"role", "content"}`` dicts; user turns may
            be tagged ``"human"`` or ``"user"``.
        temperature: Sampling temperature passed to ``model.generate``.
        max_new_tokens: Upper bound on generated tokens (cast to ``int``).

    Yields:
        The reply accumulated so far (whitespace-stripped), growing with each
        streamed chunk. If loading failed or an error occurs, a single
        human-readable error string is yielded instead of raising.
    """
    if model is None or tokenizer is None:
        yield "Model or tokenizer not loaded properly. Please check the logs."
        return

    try:
        inputs = tokenizer.apply_chat_template(
            _build_messages(message, history),
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt",
        ).to(DEVICE)

        # skip_prompt=True: generate() first pushes the prompt ids into the
        # streamer; without it the system prompt and whole history would be
        # echoed at the start of every reply.
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        generation_kwargs = dict(
            inputs=inputs,
            # Slider values may arrive as floats; generate() needs an int.
            max_new_tokens=int(max_new_tokens),
            do_sample=True,
            temperature=temperature,
            eos_token_id=tokenizer.eos_token_id,
            streamer=streamer,
        )

        errors = []

        def _generate():
            # Runs in the worker thread. An exception raised here would not
            # reach this generator and would leave the loop below blocked on
            # the streamer forever, so record it and close the stream.
            try:
                model.generate(**generation_kwargs)
            except Exception as exc:
                errors.append(exc)
                streamer.end()

        thread = Thread(target=_generate)
        thread.start()

        generated_text = ""
        for new_text in streamer:
            generated_text += new_text
            yield generated_text.strip()

        thread.join()
        if errors:
            # Re-raise in this thread so the handlers below report it
            # (e.g. CUDA OOM surfaces as a RuntimeError subclass).
            raise errors[0]
    except RuntimeError as e:
        if "CUDA out of memory" in str(e):
            yield "GPU memory exceeded. Try reducing the max tokens or using a smaller model."
        else:
            yield f"An error occurred: {str(e)}"
    except Exception as e:
        yield f"An unexpected error occurred: {str(e)}"

# Page stylesheet passed to gr.Blocks: the element with id "output" (the
# chatbot is created with elem_id="output") gets a fixed 1000px height,
# scrolls on overflow, and has a 2px light-grey (#ccc) border.
css = """
  #output {
    height: 1000px; 
    overflow: auto; 
    border: 2px solid #ccc; 
  }
"""

with gr.Blocks(css=css) as demo:
    gr.Markdown("# Nisten's Karpathy Chatbot with OLMoE (CPU only instance feel free to clone!)")
    # type="messages": the chat history is a list of {"role", "content"} dicts,
    # which is the format user() and bot() below build and read.
    chatbot = gr.Chatbot(elem_id="output", type="messages")
    msg = gr.Textbox(label="Meow")
    with gr.Row():
        temperature = gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature")
        max_new_tokens = gr.Slider(minimum=50, maximum=4000, value=2000, step=50, label="Max New Tokens")
    clear = gr.Button("Clear")

    def user(user_message, history):
        """Clear the textbox and append the submitted text as a user turn."""
        # history may be None right after the chat has been cleared.
        return "", (history or []) + [{"role": "user", "content": user_message}]

    def bot(history, temp, max_tokens):
        """Stream the assistant's reply to the last user turn into ``history``."""
        user_message = history[-1]["content"]
        # generate_response() recognises user turns by the role "human", so
        # translate Gradio's "user" role to keep earlier speakers correct.
        prior_turns = [
            {"role": "human" if turn["role"] == "user" else turn["role"], "content": turn["content"]}
            for turn in history[:-1]
        ]
        # One assistant message, updated in place as chunks stream in, rather
        # than a new (duplicated, growing) message appended per chunk.
        history.append({"role": "assistant", "content": ""})
        for partial_reply in generate_response(user_message, prior_turns, temp, max_tokens):
            history[-1]["content"] = partial_reply
            yield history

    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot, [chatbot, temperature, max_new_tokens], chatbot
    )
    clear.click(lambda: None, None, chatbot, queue=False)

if __name__ == "__main__":
    # queue() returns the Blocks app itself, so launching can be chained:
    # at most 10 queued requests with the API left open, then serve in debug
    # mode with the API page shown and no public share link.
    demo.queue(api_open=True, max_size=10).launch(debug=True, show_api=True, share=False)