File size: 3,944 Bytes
829da7c
 
31a1ff8
e2b5fc2
829da7c
 
54fe16b
829da7c
 
 
 
 
 
 
54fe16b
829da7c
 
 
 
 
54fe16b
31a1ff8
54fe16b
829da7c
 
54fe16b
829da7c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53b40bf
54fe16b
 
829da7c
4856892
 
829da7c
 
54fe16b
 
4856892
 
54fe16b
829da7c
54fe16b
 
 
 
 
829da7c
 
54fe16b
 
994685c
 
eb95198
 
 
 
994685c
eb95198
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
994685c
54fe16b
53b40bf
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import argparse
import os
import spaces

import gradio as gr

import json
from threading import Thread
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Prompt-length cap in tokens: predict() keeps only the last MAX_LENGTH tokens
# of the rendered chat prompt, so the oldest context is dropped first.
MAX_LENGTH = 4096
# Default generation budget (max new tokens) for a single reply.
DEFAULT_MAX_NEW_TOKENS = 1024


def parse_args(argv=None):
    """Parse the demo's command-line options.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, matching the previous zero-argument behaviour;
            passing an explicit list makes the parser usable from tests.

    Returns:
        An ``argparse.Namespace`` with ``base_model`` (str, or None when not
        given) and ``n_gpus`` (int, default 1).
    """
    parser = argparse.ArgumentParser(description="StableLM 2 12B Chat Gradio demo")
    parser.add_argument("--base_model", type=str, help="model path")
    parser.add_argument("--n_gpus", type=int, default=1, help="number of GPUs")
    return parser.parse_args(argv)

def _build_messages(message, history, system_prompt):
    """Convert Gradio tuple-style chat history into chat-template messages.

    Args:
        message: The new user message.
        history: Iterable of ``(user, assistant)`` pairs from earlier turns.
        system_prompt: Text for the leading system message.

    Returns:
        A list of ``{"role": ..., "content": ...}`` dicts: the system prompt,
        each earlier turn as a user/assistant pair, then the new user message.
    """
    messages = [{"role": "system", "content": system_prompt}]
    for human, assistant in history:
        messages.append({"role": "user", "content": human})
        messages.append({"role": "assistant", "content": assistant})
    messages.append({"role": "user", "content": message})
    return messages


@spaces.GPU()
def predict(message, history, system_prompt, temperature, max_tokens):
    """Stream a chat reply from the module-level ``model``.

    Relies on the module-level ``model``, ``tokenizer`` and ``device`` that
    the ``__main__`` block binds before launching the UI.

    Args:
        message: Latest user message.
        history: Earlier ``(user, assistant)`` turns, as supplied by Gradio.
        system_prompt: System message placed at the start of the prompt.
        temperature: Sampling temperature; values <= 0 select greedy decoding.
        max_tokens: Maximum number of new tokens to generate; ``None`` falls
            back to ``DEFAULT_MAX_NEW_TOKENS``.

    Yields:
        The reply text accumulated so far, once per streamed chunk.
    """
    messages = _build_messages(message, history, system_prompt)
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    # No tokenizer-side truncation: it trims from the tokenizer's
    # truncation_side (right by default), which would drop the newest user
    # turn and the generation prompt. Over-long prompts are trimmed from the
    # left below instead.
    enc = tokenizer([prompt], return_tensors="pt")
    input_ids = enc.input_ids
    attention_mask = enc.attention_mask

    if input_ids.shape[1] > MAX_LENGTH:
        # Keep the most recent MAX_LENGTH tokens. The mask must be trimmed
        # identically, otherwise generate() receives mismatched shapes.
        input_ids = input_ids[:, -MAX_LENGTH:]
        attention_mask = attention_mask[:, -MAX_LENGTH:]

    streamer = TextIteratorStreamer(tokenizer, timeout=100.0, skip_prompt=True, skip_special_tokens=True)
    max_new_tokens = DEFAULT_MAX_NEW_TOKENS if max_tokens is None else int(max_tokens)
    generate_kwargs = dict(
        input_ids=input_ids.to(device),
        attention_mask=attention_mask.to(device),
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        use_cache=True,
        # <|im_end|>; NOTE(review): id assumed to match the StableLM 2 chat
        # tokenizer -- re-check if a different --base_model is used.
        eos_token_id=100278,
    )
    if temperature > 0:
        generate_kwargs.update(do_sample=True, top_p=0.95, temperature=float(temperature))
    else:
        # The Temperature slider allows 0, which transformers rejects when
        # sampling; treat it as a request for greedy decoding.
        generate_kwargs.update(do_sample=False)

    # generate() blocks, so run it in a worker thread and consume the streamer here.
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()
    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)
    thread.join()



if __name__ == "__main__":
    args = parse_args()
    # Honour --base_model; fall back to the StableLM 2 12B chat checkpoint
    # when it is omitted.
    # NOTE(review): args.n_gpus is parsed but not used anywhere in this script.
    model_id = args.base_model or "stabilityai/stablelm-2-12b-chat"
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True, torch_dtype=torch.bfloat16)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    # predict() reads the module-level model, tokenizer and device bound above;
    # additional_inputs map, in order, to its system_prompt, temperature and
    # max_tokens parameters.
    gr.ChatInterface(
        predict,
        title="StableLM 2 12B Chat - Demo",
        description="StableLM 2 12B Chat - StabilityAI",
        theme="soft",
        chatbot=gr.Chatbot(label="Chat History",),
        textbox=gr.Textbox(placeholder="input", container=False, scale=7),
        retry_btn=None,
        undo_btn="Delete Previous",
        clear_btn="Clear",
        additional_inputs=[
            gr.Textbox("You are a helpful assistant.", label="System Prompt"),
            gr.Slider(0, 1, 0.5, label="Temperature"),
            gr.Slider(100, 2048, 1024, label="Max Tokens"),
        ],
        examples=[
            ["Implement snake game using pygame"],
            ["Escribe un poema corto sobre la historia del Mediterráneo."],
            ["Scrivi un Haiku che celebri il gelato"],
            ["Schreibe ein Haiku über die Alpen."],
            ["Ecris une prose a propos de la mer du Nord."],
            ["Escreva um poema sobre a saudade."],
            ["""\
This is the error that is thrown for the given python code.
```python
import pandas as pd
df = pd.read_csv(PATH)
This throws an error,
```
```shell
pandas/io/common.py"", line 873, in get_handle
    handle = open(
             ^^^^^
FileNotFoundError: [Errno 2] No such file or directory: 'data.csv'```
How to solve this ?
"""
            ],
            ["Jane has 8 apples, out of which 2 are red and 3 are green. How many of them are white? Assuming there are only 3 colors, can you solve this with python?"],
        ],
        additional_inputs_accordion_name="Parameters",
    ).queue().launch()