Kims12 commited on
Commit
b34f0d5
·
verified ·
1 Parent(s): 2dc775e

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +105 -0
app.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
3
+ import torch
4
+
5
+ # 모델 및 토크나이저 로드
6
+ model_id = "meta-llama/Llama-3.3-70B-Instruct" # 사용하려는 LLaMA 모델 ID
7
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
8
+ model = AutoModelForCausalLM.from_pretrained(
9
+ model_id,
10
+ torch_dtype=torch.bfloat16,
11
+ device_map="auto",
12
+ load_in_8bit=False # 메모리 절약을 위해 8-bit 로드 사용 가능
13
+ )
14
+
15
+ # 텍스트 생성 파이프라인 설정
16
+ text_generator = pipeline(
17
+ "text-generation",
18
+ model=model,
19
+ tokenizer=tokenizer,
20
+ device_map="auto",
21
+ torch_dtype=torch.bfloat16,
22
+ max_length=2048, # 필요에 따라 조정
23
+ )
24
+
25
+ def generate_response(
26
+ user_input,
27
+ system_prompt,
28
+ max_new_tokens,
29
+ temperature,
30
+ top_p
31
+ ):
32
+ """
33
+ 사용자 입력과 옵션을 받아 모델의 응답을 생성하는 함수
34
+ """
35
+ # 시스템 프롬프트와 사용자 입력을 결합
36
+ full_prompt = system_prompt + "\n" + user_input
37
+
38
+ # 텍스트 생성
39
+ outputs = text_generator(
40
+ full_prompt,
41
+ max_new_tokens=max_new_tokens,
42
+ temperature=temperature,
43
+ top_p=top_p,
44
+ eos_token_id=tokenizer.eos_token_id,
45
+ pad_token_id=tokenizer.eos_token_id,
46
+ )
47
+
48
+ # 생성된 텍스트 반환
49
+ return outputs[0]['generated_text'][len(full_prompt):].strip()
50
+
51
+ # Gradio 인터페이스 구성
52
+ with gr.Blocks() as demo:
53
+ gr.Markdown("# LLaMA 기반 대화형 챗봇")
54
+
55
+ with gr.Row():
56
+ with gr.Column():
57
+ system_prompt = gr.Textbox(
58
+ label="시스템 프롬프트",
59
+ value="You are a helpful assistant.",
60
+ lines=2
61
+ )
62
+ user_input = gr.Textbox(
63
+ label="사용자 입력",
64
+ placeholder="질문을 입력하세요...",
65
+ lines=4
66
+ )
67
+ with gr.Column():
68
+ max_new_tokens = gr.Slider(
69
+ label="Max New Tokens",
70
+ minimum=16,
71
+ maximum=2048,
72
+ step=16,
73
+ value=256
74
+ )
75
+ temperature = gr.Slider(
76
+ label="Temperature",
77
+ minimum=0.1,
78
+ maximum=1.0,
79
+ step=0.1,
80
+ value=0.7
81
+ )
82
+ top_p = gr.Slider(
83
+ label="Top-p (nucleus sampling)",
84
+ minimum=0.1,
85
+ maximum=1.0,
86
+ step=0.1,
87
+ value=0.9
88
+ )
89
+
90
+ generate_button = gr.Button("생성")
91
+ output = gr.Textbox(
92
+ label="응답",
93
+ lines=10
94
+ )
95
+
96
+ # 버튼 클릭 시 응답 생성
97
+ generate_button.click(
98
+ fn=generate_response,
99
+ inputs=[user_input, system_prompt, max_new_tokens, temperature, top_p],
100
+ outputs=output
101
+ )
102
+
103
+ # Gradio 앱 실행
104
+ if __name__ == "__main__":
105
+ demo.launch()