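"""Gradio chat demo for Hermes-2-Theta-Llama-3-8B quantized to 1 bit via an
xMAD codebook, logging per-request latency and peak GPU memory."""
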
import os
import time

import gradio as gr
import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

os.environ["TOKENIZERS_PARALLELISM"] = "0"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"


def get_gpu_memory(device=None):
    """Return the CUDA memory currently allocated on `device`, in MiB."""
    return torch.cuda.memory_allocated(device) / 1024 / 1024


class TorchTracemalloc:
    """Context manager that measures peak CUDA memory used inside a block."""

    def __init__(self, device=None):
        # Track stats on an explicit device; the default (current device)
        # may not be the GPU the model actually lives on.
        self.device = device
        self.begin = 0
        self.peak = 0

    def __enter__(self):
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats(self.device)
        torch.cuda.synchronize(self.device)
        self.begin = get_gpu_memory(self.device)
        return self

    def __exit__(self, *exc):
        torch.cuda.synchronize(self.device)
        self.peak = (
            torch.cuda.max_memory_allocated(self.device) / 1024 / 1024
        )  # Convert to MiB

    def consumed(self):
        """Peak memory consumed inside the block, relative to entry (MiB)."""
        return self.peak - self.begin


def load_model_and_tokenizer():
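    """Load the tokenizer and the 1-bit quantized model onto its target GPU."""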
    model_name = "NousResearch/Hermes-2-Theta-Llama-3-8B"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Llama 3 checkpoints ship without a pad token, so register one.
    tokenizer.add_special_tokens({"pad_token": "<PAD>"})
    config = AutoConfig.from_pretrained(model_name)
    # Point the config at the 1-bit xMAD codebook used for quantized inference.
    config.quantizer_path = "codebooks/Hermes-2-Theta-Llama-3-8B_1bit.xmad"
    config.window_length = 32
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        config=config,
        torch_dtype=torch.float16,
        device_map="cuda:2",  # pin all weights to GPU 2
    )
    if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
        print(
            "WARNING: Resizing the embedding matrix to match the tokenizer vocab size."
        )
        model.resize_token_embeddings(len(tokenizer))
    model.config.pad_token_id = tokenizer.pad_token_id
    return model, tokenizer


def process_dialog(dialog, model, tokenizer):
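    """Generate a reply to `dialog`, logging latency and peak GPU memory."""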
    prompt = tokenizer.apply_chat_template(
        dialog, tokenize=False, add_generation_prompt=True
    )
    tokenized_input_prompt_ids = tokenizer(
        prompt, return_tensors="pt"
    ).input_ids.to(model.device)

    # Track latency and peak memory on the model's device; TorchTracemalloc
    # already clears the cache and resets the peak-memory counters on entry.
    with TorchTracemalloc(model.device) as tt:
        start_time = time.time()
        with torch.no_grad():
            token_ids_for_each_answer = model.generate(
                tokenized_input_prompt_ids,
                max_new_tokens=512,
                temperature=0.7,
                do_sample=True,
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=tokenizer.pad_token_id,
            )
        torch.cuda.synchronize()
        end_time = time.time()
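
    # Surface the measurements collected above: wall-clock latency and the
    # peak-minus-baseline GPU memory delta from TorchTracemalloc.
    print(
        f"Generation took {end_time - start_time:.2f} s; "
        f"peak GPU memory delta: {tt.consumed():.2f} MiB"
    )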

    # Keep only the newly generated tokens (drop the echoed prompt).
    response = token_ids_for_each_answer[0][tokenized_input_prompt_ids.shape[-1]:]
    cleaned_response = tokenizer.decode(
        response,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )

    return cleaned_response


# Load the model once at startup so every request reuses the same instance.
model, tokenizer = load_model_and_tokenizer()


def chatbot_interface(user_input, chat_history):
    # Replay prior turns so the model keeps conversational context.
    dialog = []
    for user_msg, bot_msg in chat_history:
        dialog += [{"role": "user", "content": user_msg},
                   {"role": "assistant", "content": bot_msg}]
    dialog.append({"role": "user", "content": user_input})
    response = process_dialog(dialog, model, tokenizer)
    chat_history.append((user_input, response))
    return "", chat_history  # clear the textbox, refresh the chat display


def main():
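    """Build and launch the Gradio Blocks chat UI."""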
    with gr.Blocks() as demo:
        chatbot = gr.Chatbot()
        user_input = gr.Textbox(placeholder="Type your message here...")
        clear = gr.Button("Clear")

        user_input.submit(
            chatbot_interface, [user_input, chatbot], [user_input, chatbot]
        )
        clear.click(lambda: None, None, chatbot)

    demo.launch()


if __name__ == "__main__":
    main()