Spaces:
Sleeping
Sleeping
import gradio as gr | |
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer | |
import torch | |
# 모델 및 토크나이저 로드 | |
model_id = "meta-llama/Llama-3.3-70B-Instruct" # 사용하려는 LLaMA 모델 ID | |
tokenizer = AutoTokenizer.from_pretrained(model_id) | |
model = AutoModelForCausalLM.from_pretrained( | |
model_id, | |
torch_dtype=torch.bfloat16, | |
device_map="auto", | |
load_in_8bit=False # 메모리 절약을 위해 8-bit 로드 사용 가능 | |
) | |
# 텍스트 생성 파이프라인 설정 | |
text_generator = pipeline( | |
"text-generation", | |
model=model, | |
tokenizer=tokenizer, | |
device_map="auto", | |
torch_dtype=torch.bfloat16, | |
max_length=2048, # 필요에 따라 조정 | |
) | |
def generate_response( | |
user_input, | |
system_prompt, | |
max_new_tokens, | |
temperature, | |
top_p | |
): | |
""" | |
사용자 입력과 옵션을 받아 모델의 응답을 생성하는 함수 | |
""" | |
# 시스템 프롬프트와 사용자 입력을 결합 | |
full_prompt = system_prompt + "\n" + user_input | |
# 텍스트 생성 | |
outputs = text_generator( | |
full_prompt, | |
max_new_tokens=max_new_tokens, | |
temperature=temperature, | |
top_p=top_p, | |
eos_token_id=tokenizer.eos_token_id, | |
pad_token_id=tokenizer.eos_token_id, | |
) | |
# 생성된 텍스트 반환 | |
return outputs[0]['generated_text'][len(full_prompt):].strip() | |
# Gradio 인터페이스 구성 | |
with gr.Blocks() as demo: | |
gr.Markdown("# LLaMA 기반 대화형 챗봇") | |
with gr.Row(): | |
with gr.Column(): | |
system_prompt = gr.Textbox( | |
label="시스템 프롬프트", | |
value="You are a helpful assistant.", | |
lines=2 | |
) | |
user_input = gr.Textbox( | |
label="사용자 입력", | |
placeholder="질문을 입력하세요...", | |
lines=4 | |
) | |
with gr.Column(): | |
max_new_tokens = gr.Slider( | |
label="Max New Tokens", | |
minimum=16, | |
maximum=2048, | |
step=16, | |
value=256 | |
) | |
temperature = gr.Slider( | |
label="Temperature", | |
minimum=0.1, | |
maximum=1.0, | |
step=0.1, | |
value=0.7 | |
) | |
top_p = gr.Slider( | |
label="Top-p (nucleus sampling)", | |
minimum=0.1, | |
maximum=1.0, | |
step=0.1, | |
value=0.9 | |
) | |
generate_button = gr.Button("생성") | |
output = gr.Textbox( | |
label="응답", | |
lines=10 | |
) | |
# 버튼 클릭 시 응답 생성 | |
generate_button.click( | |
fn=generate_response, | |
inputs=[user_input, system_prompt, max_new_tokens, temperature, top_p], | |
outputs=output | |
) | |
# Gradio 앱 실행 | |
if __name__ == "__main__": | |
demo.launch() | |