Kims12 committed on
Commit
7dcc8af
·
verified ·
1 Parent(s): b34f0d5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +103 -90
app.py CHANGED
@@ -1,105 +1,118 @@
1
  import gradio as gr
2
- from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
3
- import torch
4
 
5
# Checkpoint to serve: the LLaMA model ID used for both the tokenizer and the
# weights.
model_id = "meta-llama/Llama-3.3-70B-Instruct"

# Load the tokenizer and the causal-LM weights for that checkpoint.
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    # 8-bit loading is switched off here; it can be enabled to save memory.
    load_in_8bit=False,
)

# Wrap the already-loaded model and tokenizer in a text-generation pipeline.
text_generator = pipeline(
    "text-generation",
    tokenizer=tokenizer,
    model=model,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    # Default overall length cap; adjust as needed.
    max_length=2048,
)
24
 
25
def generate_response(
    user_input,
    system_prompt,
    max_new_tokens,
    temperature,
    top_p
):
    """Generate the model's reply for a single prompt.

    The system prompt and the user input are joined with a newline and handed
    to the module-level ``text_generator`` pipeline. The prompt prefix is then
    cut off the generated text and the remainder is stripped of surrounding
    whitespace.

    Args:
        user_input: Text typed by the user.
        system_prompt: Text placed before the user input.
        max_new_tokens: Forwarded to the pipeline as ``max_new_tokens``.
        temperature: Forwarded to the pipeline as ``temperature``.
        top_p: Forwarded to the pipeline as ``top_p``.

    Returns:
        The generated text without the prompt prefix.
    """
    prompt = "\n".join((system_prompt, user_input))

    # The tokenizer's EOS id doubles as the padding id.
    generation_kwargs = {
        "max_new_tokens": max_new_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "eos_token_id": tokenizer.eos_token_id,
        "pad_token_id": tokenizer.eos_token_id,
    }
    first_candidate = text_generator(prompt, **generation_kwargs)[0]

    # Drop the echoed prompt so only the newly generated part is returned.
    completion = first_candidate["generated_text"][len(prompt):]
    return completion.strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
# Gradio interface layout.
# NOTE(review): the diff view this was recovered from carries no indentation,
# so the nesting below (button and output placed under the row) is inferred --
# confirm against the original file.
with gr.Blocks() as demo:
    gr.Markdown("# LLaMA 기반 대화형 챗봇")

    with gr.Row():
        # Left column: the two prompt text boxes.
        with gr.Column():
            system_prompt = gr.Textbox(
                value="You are a helpful assistant.",
                label="시스템 프롬프트",
                lines=2,
            )
            user_input = gr.Textbox(
                placeholder="질문을 입력하세요...",
                label="사용자 입력",
                lines=4,
            )
        # Right column: sampling controls.
        with gr.Column():
            max_new_tokens = gr.Slider(
                minimum=16, maximum=2048, step=16, value=256,
                label="Max New Tokens",
            )
            temperature = gr.Slider(
                minimum=0.1, maximum=1.0, step=0.1, value=0.7,
                label="Temperature",
            )
            top_p = gr.Slider(
                minimum=0.1, maximum=1.0, step=0.1, value=0.9,
                label="Top-p (nucleus sampling)",
            )

    generate_button = gr.Button("생성")
    output = gr.Textbox(lines=10, label="응답")

    # Clicking the button runs generate_response and shows its result.
    generation_inputs = [
        user_input,
        system_prompt,
        max_new_tokens,
        temperature,
        top_p,
    ]
    generate_button.click(
        fn=generate_response,
        inputs=generation_inputs,
        outputs=output,
    )

# Start the Gradio app when run as a script.
if __name__ == "__main__":
    demo.launch()
 
1
  import gradio as gr
2
+ from huggingface_hub import InferenceClient
3
+ import os
4
 
5
# Display label (shown in the UI) -> Hugging Face Hub repository id.
# Insertion order is the order in which the labels are listed.
MODELS = {
    "Zephyr 7B Beta": "HuggingFaceH4/zephyr-7b-beta",
    "DeepSeek Coder V2": "deepseek-ai/DeepSeek-Coder-V2-Instruct",
    # Meta Llama 3.1 family.
    "Meta Llama 3.1 8B": "meta-llama/Meta-Llama-3.1-8B-Instruct",
    "Meta-Llama 3.1 70B-Instruct": "meta-llama/Meta-Llama-3.1-70B-Instruct",
    "Microsoft": "microsoft/Phi-3-mini-4k-instruct",
    # NOTE(review): the label says "Mixtral 8x7B" but the repo id is
    # Mistral-7B-Instruct-v0.3 -- confirm which of the two is intended.
    "Mixtral 8x7B": "mistralai/Mistral-7B-Instruct-v0.3",
    "Mixtral Nous-Hermes": "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
    # Labels containing "Cohere" are matched by substring elsewhere in the
    # file; "Aya-23-35B" is a CohereForAI repo but its label does not match.
    "Cohere Command R+": "CohereForAI/c4ai-command-r-plus",
    "Aya-23-35B": "CohereForAI/aya-23-35B",
}
16
 
17
def get_client(model_name):
    """Build an ``InferenceClient`` for the model behind a display label.

    Args:
        model_name: A key of ``MODELS``.

    Returns:
        An ``InferenceClient`` for the mapped repository id, authenticated
        with the ``HF_TOKEN`` environment variable.

    Raises:
        KeyError: If ``model_name`` is not a key of ``MODELS``.
        ValueError: If ``HF_TOKEN`` is unset or empty.
    """
    repo_id = MODELS[model_name]
    token = os.getenv("HF_TOKEN")
    if token:
        return InferenceClient(repo_id, token=token)
    raise ValueError("HF_TOKEN environment variable is required")
 
 
 
23
 
24
def respond(
    message,
    chat_history,
    model_name,
    max_tokens,
    temperature,
    top_p,
    system_message,
):
    """Produce the assistant's reply as a stream of chat-history snapshots.

    This function is a generator, so every state meant for the caller has to
    be *yielded*; a value passed to ``return`` would never be seen by code
    that simply iterates over the generator.

    Args:
        message: Text the user just submitted.
        chat_history: List of ``(user, assistant)`` tuples for earlier turns.
            It is extended in place; ``None`` is treated as an empty history.
        model_name: Key of ``MODELS`` selecting the backend model.
        max_tokens: Forwarded to ``chat_completion`` as ``max_tokens``.
        temperature: Forwarded to ``chat_completion`` as ``temperature``.
        top_p: Forwarded to ``chat_completion`` as ``top_p``.
        system_message: Content of the leading system-role message.

    Yields:
        The chat history list after each update: once for the non-streaming
        branch or an error, and once per received chunk when streaming.
    """
    if chat_history is None:
        chat_history = []

    try:
        client = get_client(model_name)
    except ValueError as e:
        # Yield (not return) so the configuration error reaches the caller.
        chat_history.append((message, str(e)))
        yield chat_history
        return

    # Rebuild the whole conversation in role/content form.
    messages = [{"role": "system", "content": system_message}]
    for human, assistant in chat_history:
        messages.append({"role": "user", "content": human})
        messages.append({"role": "assistant", "content": assistant})
    messages.append({"role": "user", "content": message})

    try:
        if "Cohere" in model_name:
            # Non-streaming request for the Cohere models.
            response = client.chat_completion(
                messages,
                max_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
            )
            assistant_message = response.choices[0].message.content
            chat_history.append((message, assistant_message))
            yield chat_history
        else:
            # Streaming request for every other model.
            stream = client.chat_completion(
                messages,
                max_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                stream=True,
            )
            partial_message = ""
            # Track explicitly whether this turn already has a history row.
            # Comparing the last row's text with `message` would overwrite
            # the previous turn when the same message is sent twice in a row.
            turn_started = False
            for chunk in stream:
                # A chunk without choices would raise IndexError below.
                if not chunk.choices:
                    continue
                token = chunk.choices[0].delta.content
                if token is not None:
                    partial_message += token
                if turn_started:
                    chat_history[-1] = (message, partial_message)
                else:
                    chat_history.append((message, partial_message))
                    turn_started = True
                yield chat_history
    except Exception as e:
        # UI boundary: report any backend failure as a chat row instead of
        # letting the exception propagate.
        error_message = f"An error occurred: {str(e)}"
        chat_history.append((message, error_message))
        yield chat_history
79
+
80
def clear_conversation():
    """Return a fresh empty list, used to reset the chat history."""
    return list()
82
 
 
83
# NOTE(review): the diff view this was recovered from carries no indentation,
# so the nesting below is inferred from the statement order -- confirm against
# the original file.
with gr.Blocks() as demo:
    gr.Markdown("# Prompting AI Chatbot")
    gr.Markdown("언어모델별 프롬프트 테스트 챗봇입니다.")

    with gr.Row():
        # Narrow column: model choice, sampling controls, system message.
        with gr.Column(scale=1):
            model_name = gr.Radio(
                label="Language Model",
                choices=list(MODELS),
                value="Zephyr 7B Beta",
            )
            max_tokens = gr.Slider(
                label="Max Tokens",
                minimum=0, maximum=2000, step=100, value=500,
            )
            temperature = gr.Slider(
                label="Temperature",
                minimum=0.1, maximum=2.0, step=0.05, value=0.7,
            )
            top_p = gr.Slider(
                label="Top-p",
                minimum=0.1, maximum=1.0, step=0.05, value=0.95,
            )
            # NOTE(review): assumes the continuation lines of the original
            # multi-line string had no leading whitespace -- confirm.
            system_message = gr.Textbox(
                label="System Message",
                lines=3,
                value=(
                    "반드시 한글로 답변할 것.\n"
                    "너는 최고의 비서이다.\n"
                    "내가 요구하는것들을 최대한 자세하고 정확하게 답변하라.\n"
                ),
            )

        # Wide column: conversation view, message box, action buttons.
        with gr.Column(scale=2):
            chatbot = gr.Chatbot()
            msg = gr.Textbox(label="메세지를 입력하세요")
            with gr.Row():
                submit_button = gr.Button("전송")
                clear_button = gr.Button("대화 내역 지우기")

    # Pressing Enter in the message box and clicking the send button both run
    # `respond` with the same inputs and write to the chatbot.
    respond_inputs = [
        msg,
        chatbot,
        model_name,
        max_tokens,
        temperature,
        top_p,
        system_message,
    ]
    for register in (msg.submit, submit_button.click):
        register(respond, list(respond_inputs), chatbot)

    clear_button.click(fn=clear_conversation, outputs=chatbot, queue=False)

if __name__ == "__main__":
    demo.launch()