File size: 2,937 Bytes
b34f0d5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import gradio as gr
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
import torch

# 모델 및 토크나이저 로드
model_id = "meta-llama/Llama-3.3-70B-Instruct"  # 사용하려는 LLaMA 모델 ID
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    load_in_8bit=False  # 메모리 절약을 위해 8-bit 로드 사용 가능
)

# 텍스트 생성 파이프라인 설정
text_generator = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    max_length=2048,  # 필요에 따라 조정
)

def generate_response(
    user_input,
    system_prompt,
    max_new_tokens,
    temperature,
    top_p
):
    """
    사용자 입력과 옵션을 받아 모델의 응답을 생성하는 함수
    """
    # 시스템 프롬프트와 사용자 입력을 결합
    full_prompt = system_prompt + "\n" + user_input

    # 텍스트 생성
    outputs = text_generator(
        full_prompt,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,
    )

    # 생성된 텍스트 반환
    return outputs[0]['generated_text'][len(full_prompt):].strip()

# Gradio 인터페이스 구성
with gr.Blocks() as demo:
    gr.Markdown("# LLaMA 기반 대화형 챗봇")
    
    with gr.Row():
        with gr.Column():
            system_prompt = gr.Textbox(
                label="시스템 프롬프트",
                value="You are a helpful assistant.",
                lines=2
            )
            user_input = gr.Textbox(
                label="사용자 입력",
                placeholder="질문을 입력하세요...",
                lines=4
            )
        with gr.Column():
            max_new_tokens = gr.Slider(
                label="Max New Tokens",
                minimum=16,
                maximum=2048,
                step=16,
                value=256
            )
            temperature = gr.Slider(
                label="Temperature",
                minimum=0.1,
                maximum=1.0,
                step=0.1,
                value=0.7
            )
            top_p = gr.Slider(
                label="Top-p (nucleus sampling)",
                minimum=0.1,
                maximum=1.0,
                step=0.1,
                value=0.9
            )
    
    generate_button = gr.Button("생성")
    output = gr.Textbox(
        label="응답",
        lines=10
    )
    
    # 버튼 클릭 시 응답 생성
    generate_button.click(
        fn=generate_response,
        inputs=[user_input, system_prompt, max_new_tokens, temperature, top_p],
        outputs=output
    )

# Gradio 앱 실행
if __name__ == "__main__":
    demo.launch()