import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "kakaocorp/kanana-nano-2.1b-instruct"

# ๋ชจ๋ธ๊ณผ ํ† ํฌ๋‚˜์ด์ €๋ฅผ CPU ํ™˜๊ฒฝ์—์„œ ๋กœ๋“œํ•ฉ๋‹ˆ๋‹ค.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float32,  # CPU์—์„œ๋Š” bfloat16 ์ง€์›์ด ์ œํ•œ๋  ์ˆ˜ ์žˆ์œผ๋ฏ€๋กœ float32 ์‚ฌ์šฉ ๊ถŒ์žฅ
    trust_remote_code=True,
)
# CPU๋งŒ ์‚ฌ์šฉํ•˜๋ฏ€๋กœ .to("cuda") ๋ถ€๋ถ„์€ ์ƒ๋žตํ•ฉ๋‹ˆ๋‹ค.
tokenizer = AutoTokenizer.from_pretrained(model_name)

def generate_response(prompt):
    messages = [
        {"role": "system", "content": "You are a helpful AI assistant developed by Kakao."},
        {"role": "user", "content": prompt}
    ]
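    # Render the chat with the model's chat template; with tokenize=True and
    # return_tensors="pt", this returns the prompt's token ids directly.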
    input_ids = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt"
    )
    
    model.eval()
    with torch.no_grad():
        output = model.generate(
            input_ids,
            max_new_tokens=72,
            do_sample=False,  # greedy decoding for deterministic output
        )
    # Decode only the newly generated tokens so the prompt is not echoed back.
    return tokenizer.decode(output[0][input_ids.shape[-1]:], skip_special_tokens=True)
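
# A minimal sketch (not wired into the UI below): assuming the same model and
# tokenizer, sampling-based decoding gives more varied output than the greedy
# decoding above. The temperature/top_p values are illustrative, not tuned.
def generate_response_sampled(prompt, temperature=0.7, top_p=0.9):
    messages = [
        {"role": "system", "content": "You are a helpful AI assistant developed by Kakao."},
        {"role": "user", "content": prompt}
    ]
    input_ids = tokenizer.apply_chat_template(
        messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
    )
    with torch.no_grad():
        output = model.generate(
            input_ids,
            max_new_tokens=72,
            do_sample=True,           # enable stochastic sampling
            temperature=temperature,  # soften the token distribution
            top_p=top_p,              # nucleus sampling cutoff
        )
    # Decode only the newly generated tokens, as in generate_response.
    return tokenizer.decode(output[0][input_ids.shape[-1]:], skip_special_tokens=True)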

with gr.Blocks() as demo:
    with gr.Tab("About"):
        gr.Markdown("# Inference Provider")
        gr.Markdown("์ด Space๋Š” kakaocorp/kanana-nano-2.1b-instruct ๋ชจ๋ธ์„ CPU์—์„œ ์ถ”๋ก ํ•ฉ๋‹ˆ๋‹ค.")
    
    with gr.Tab("Generate"):
        prompt_input = gr.Textbox(
            label="Prompt",
            placeholder="Enter your prompt here...",
            lines=5
        )
        generate_btn = gr.Button("Generate")
        output_text = gr.Textbox(
            label="Model output",
            lines=10
        )
        generate_btn.click(fn=generate_response, inputs=prompt_input, outputs=output_text)

demo.launch()