# command-r-plus / app.py
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import gc
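
# Note: loading this 4-bit checkpoint assumes bitsandbytes (for the quantized
# weights) and accelerate (for device_map="auto") are installed.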

class ModelManager:
    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.model_name = "CohereForAI/c4ai-command-r-plus-4bit"

    def load_model(self):
        if self.model is None:
            try:
                print("Loading model... this may take a while.")
                self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
                # The checkpoint is already 4-bit quantized (its bitsandbytes
                # config ships with the model), so passing load_in_4bit again
                # is unnecessary and conflicts with the saved quantization.
                self.model = AutoModelForCausalLM.from_pretrained(
                    self.model_name,
                    torch_dtype=torch.float16,
                    device_map="auto",
                    trust_remote_code=True,
                    low_cpu_mem_usage=True
                )
                print("Model loaded!")
                return True
            except Exception as e:
                print(f"Model loading failed: {e}")
                return False
        return True
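
    # generate() rebuilds the full conversation each turn, re-tokenizes it with
    # the model's chat template, and decodes only the tokens produced after the
    # prompt.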
    def generate(self, message, history, max_tokens=1000, temperature=0.7):
        if not self.load_model():
            return "Failed to load the model."
        try:
            # Build the conversation from the chat history
            conversation = []
            for human, assistant in history:
                conversation.append({"role": "user", "content": human})
                if assistant:
                    conversation.append({"role": "assistant", "content": assistant})
            conversation.append({"role": "user", "content": message})
            # Tokenize with the model's chat template
            input_ids = self.tokenizer.apply_chat_template(
                conversation,
                return_tensors="pt",
                add_generation_prompt=True
            )
            if torch.cuda.is_available():
                input_ids = input_ids.to("cuda")
            # Generate
            with torch.no_grad():
                outputs = self.model.generate(
                    input_ids,
                    max_new_tokens=max_tokens,
                    temperature=temperature,
                    do_sample=True,
                    pad_token_id=self.tokenizer.eos_token_id,
                    eos_token_id=self.tokenizer.eos_token_id
                )
            # Decode only the newly generated tokens, not the prompt
            response = self.tokenizer.decode(
                outputs[0][input_ids.shape[-1]:],
                skip_special_tokens=True
            )
            return response
        except Exception as e:
            return f"Error during generation: {str(e)}"
        finally:
            # Free GPU memory between requests
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            gc.collect()
# λͺ¨λΈ λ§€λ‹ˆμ € μΈμŠ€ν„΄μŠ€
model_manager = ModelManager()
def chat_fn(message, history, max_tokens, temperature):
if not message.strip():
return history, ""
# μ‚¬μš©μž λ©”μ‹œμ§€ μΆ”κ°€
history.append([message, "생성 쀑..."])
# 봇 응닡 생성
response = model_manager.generate(message, history[:-1], max_tokens, temperature)
history[-1][1] = response
return history, ""
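
# Optional streaming variant (a sketch, not wired into the UI below): yields the
# partially generated reply as tokens arrive, using transformers'
# TextIteratorStreamer. The helper name generate_streaming is illustrative, and
# history handling is omitted for brevity.
def generate_streaming(message, max_tokens=1000, temperature=0.7):
    from threading import Thread
    from transformers import TextIteratorStreamer
    if not model_manager.load_model():
        yield "Failed to load the model."
        return
    input_ids = model_manager.tokenizer.apply_chat_template(
        [{"role": "user", "content": message}],
        return_tensors="pt",
        add_generation_prompt=True
    )
    if torch.cuda.is_available():
        input_ids = input_ids.to("cuda")
    streamer = TextIteratorStreamer(
        model_manager.tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    # Run generation in a background thread; the streamer yields decoded text
    Thread(target=model_manager.model.generate, kwargs=dict(
        input_ids=input_ids,
        max_new_tokens=max_tokens,
        temperature=temperature,
        do_sample=True,
        streamer=streamer,
    )).start()
    partial = ""
    for chunk in streamer:
        partial += chunk
        yield partial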

# Gradio interface
with gr.Blocks(title="Command R+ Chat") as demo:
    gr.Markdown("""
    # πŸ€– Command R+ 4bit Chatbot
    Chat with Cohere's 4-bit quantized Command R+ model.
    ⚠️ The first run may take a while while the model loads.
    """)
    chatbot = gr.Chatbot(
        height=500,
        show_label=False,
        show_copy_button=True
    )
    with gr.Row():
        msg = gr.Textbox(
            label="Message",
            placeholder="Ask Command R+ a question...",
            lines=2,
            scale=4
        )
        submit = gr.Button("Send πŸ“€", variant="primary", scale=1)
    with gr.Row():
        clear = gr.Button("Clear chat πŸ—‘οΈ")
    with gr.Accordion("Advanced settings", open=False):
        max_tokens = gr.Slider(
            minimum=100,
            maximum=2000,
            value=1000,
            step=100,
            label="Max new tokens"
        )
        temperature = gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.7,
            step=0.1,
            label="Temperature (creativity)"
        )
    # Event handlers
    msg.submit(
        chat_fn,
        [msg, chatbot, max_tokens, temperature],
        [chatbot, msg]
    )
    submit.click(
        chat_fn,
        [msg, chatbot, max_tokens, temperature],
        [chatbot, msg]
    )
    clear.click(lambda: ([], ""), outputs=[chatbot, msg])
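
# Optional: calling demo.queue() before launch() would make Gradio process
# requests one at a time, avoiding contention on a single GPU.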

if __name__ == "__main__":
    demo.launch()