AIRider's picture
Update app.py
d0fd9d3 verified
raw
history blame
4.45 kB
import gradio as gr
from huggingface_hub import InferenceClient
import os
from threading import Event
hf_token = os.getenv("HF_TOKEN")
stop_event = Event()
# 모델 목록 정의
models = {
"deepseek-ai/DeepSeek-Coder-V2-Instruct": "(한국회사)DeepSeek-Coder-V2-Instruct",
"meta-llama/Meta-Llama-3.1-8B-Instruct": "Meta-Llama-3.1-8B-Instruct",
"mistralai/Mixtral-8x7B-Instruct-v0.1": "Mixtral-8x7B-Instruct-v0.1",
"CohereForAI/c4ai-command-r-plus": "Cohere Command-R Plus"
}
# Inference 클라이언트 반환
def get_client(model):
return InferenceClient(model=model, token=hf_token)
# 응답 생성 함수 (스트리밍 방식, 자문자답 방지)
def respond(prompt, system_message, max_tokens, temperature, top_p, selected_model):
stop_event.clear()
client = get_client(selected_model)
# 프롬프트 설정
messages = [
{"role": "system", "content": system_message + "\n주어진 입력에 대해서만 정확히 답변하세요. 추가 질문을 생성하거나 입력 내용을 확장하지 마세요."},
{"role": "user", "content": prompt}
]
try:
response = ""
# 모델에서 응답을 스트리밍
for chunk in client.text_generation(
prompt="\n".join([f"{m['role']}: {m['content']}" for m in messages]),
max_new_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
stream=True,
stop_sequences=["Human:", "User:"] # 자문자답 방지를 위한 정지 시퀀스 추가
):
if stop_event.is_set():
break
if chunk:
response += chunk
if "?" in chunk: # 질문 마크가 포함된 경우 생성 중단
break
yield [(prompt, response.strip())] # 실시간으로 부분적인 응답 반환
except Exception as e:
yield [(prompt, f"오류 발생: {str(e)}")]
# 응답 중단 함수
def stop_generation():
stop_event.set()
# Gradio UI 구성
with gr.Blocks() as demo:
gr.Markdown("# 프롬프트 최적화 Playground")
gr.Markdown("""
**주의사항:**
- '전송' 버튼을 클릭하거나 입력 필드에서 Shift+Enter를 눌러 메시지를 전송할 수 있습니다.
- Enter 키는 줄바꿈으로 작동합니다.
- 입력한 내용에 대해서만 응답하도록 설정되어 있으며, 모델이 추가 질문을 만들거나 입력을 확장하지 않도록 설정됩니다.
""")
with gr.Row():
with gr.Column(scale=1):
with gr.Accordion("모델 설정", open=True):
model = gr.Radio(list(models.keys()), value=list(models.keys())[0], label="언어 모델 선택", info="사용할 언어 모델을 선택하세요")
max_tokens = gr.Slider(minimum=1, maximum=2000, value=500, step=100, label="최대 새 토큰 수")
temperature = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.05, label="온도")
top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.90, step=0.05, label="Top-p (핵 샘플링)")
system_message = gr.Textbox(
value="당신은 정확하고 간결한 응답을 제공하는 AI 어시스턴트입니다. 사용자의 입력에 대해서만 답변하고, 추가 질문이나 확장된 대화를 생성하지 마세요.",
label="시스템 메시지",
lines=5
)
with gr.Column(scale=2):
chatbot = gr.Chatbot(height=400, label="대화 결과")
prompt = gr.Textbox(
label="내용 입력",
lines=3,
placeholder="메시지를 입력하세요. 전송 버튼을 클릭하거나 Shift+Enter를 눌러 전송합니다."
)
with gr.Row():
send = gr.Button("전송")
stop = gr.Button("🛑 생성 중단")
clear = gr.Button("🗑️ 대화 내역 지우기")
# Event handlers
send.click(respond, inputs=[prompt, system_message, max_tokens, temperature, top_p, model], outputs=[chatbot])
prompt.submit(respond, inputs=[prompt, system_message, max_tokens, temperature, top_p, model], outputs=[chatbot])
stop.click(stop_generation)
clear.click(lambda: None, outputs=[chatbot])
# UI 실행
demo.launch()