Spaces:
Sleeping
Sleeping
import gradio as gr | |
from transformers import AutoTokenizer, AutoModelForCausalLM | |
import torch | |
import gc | |
class ModelManager: | |
def __init__(self): | |
self.model = None | |
self.tokenizer = None | |
self.model_name = "CohereForAI/c4ai-command-r-plus-4bit" | |
def load_model(self): | |
if self.model is None: | |
try: | |
print("λͺ¨λΈ λ‘λ© μ€... μκ°μ΄ 걸릴 μ μμ΅λλ€.") | |
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name) | |
self.model = AutoModelForCausalLM.from_pretrained( | |
self.model_name, | |
torch_dtype=torch.float16, | |
device_map="auto", | |
trust_remote_code=True, | |
load_in_4bit=True, | |
low_cpu_mem_usage=True | |
) | |
print("λͺ¨λΈ λ‘λ© μλ£!") | |
return True | |
except Exception as e: | |
print(f"λͺ¨λΈ λ‘λ© μ€ν¨: {e}") | |
return False | |
return True | |
def generate(self, message, history, max_tokens=1000, temperature=0.7): | |
if not self.load_model(): | |
return "λͺ¨λΈ λ‘λ©μ μ€ν¨νμ΅λλ€." | |
try: | |
# μ±ν νμ€ν 리 κ΅¬μ± | |
conversation = [] | |
for human, assistant in history: | |
conversation.append({"role": "user", "content": human}) | |
if assistant: | |
conversation.append({"role": "assistant", "content": assistant}) | |
conversation.append({"role": "user", "content": message}) | |
# ν ν°ν | |
input_ids = self.tokenizer.apply_chat_template( | |
conversation, | |
return_tensors="pt", | |
add_generation_prompt=True | |
) | |
if torch.cuda.is_available(): | |
input_ids = input_ids.to("cuda") | |
# μμ± | |
with torch.no_grad(): | |
outputs = self.model.generate( | |
input_ids, | |
max_new_tokens=max_tokens, | |
temperature=temperature, | |
do_sample=True, | |
pad_token_id=self.tokenizer.eos_token_id, | |
eos_token_id=self.tokenizer.eos_token_id | |
) | |
response = self.tokenizer.decode( | |
outputs[0][input_ids.shape[-1]:], | |
skip_special_tokens=True | |
) | |
return response | |
except Exception as e: | |
return f"μμ± μ€ μ€λ₯ λ°μ: {str(e)}" | |
finally: | |
# λ©λͺ¨λ¦¬ μ 리 | |
if torch.cuda.is_available(): | |
torch.cuda.empty_cache() | |
gc.collect() | |
# λͺ¨λΈ λ§€λμ μΈμ€ν΄μ€ | |
model_manager = ModelManager() | |
def chat_fn(message, history, max_tokens, temperature): | |
if not message.strip(): | |
return history, "" | |
# μ¬μ©μ λ©μμ§ μΆκ° | |
history.append([message, "μμ± μ€..."]) | |
# λ΄ μλ΅ μμ± | |
response = model_manager.generate(message, history[:-1], max_tokens, temperature) | |
history[-1][1] = response | |
return history, "" | |
# Gradio μΈν°νμ΄μ€ | |
with gr.Blocks(title="Command R+ Chat") as demo: | |
gr.Markdown(""" | |
# π€ Command R+ 4bit μ±ν λ΄ | |
Cohereμ Command R+ 4bit μμν λͺ¨λΈκ³Ό λνν μ μμ΅λλ€. | |
β οΈ μ²« μ€ν μ λͺ¨λΈ λ‘λ©μ μκ°μ΄ 걸릴 μ μμ΅λλ€. | |
""") | |
chatbot = gr.Chatbot( | |
height=500, | |
show_label=False, | |
show_copy_button=True | |
) | |
with gr.Row(): | |
msg = gr.Textbox( | |
label="λ©μμ§ μ λ ₯", | |
placeholder="Command R+μκ² μ§λ¬ΈνμΈμ...", | |
lines=2, | |
scale=4 | |
) | |
submit = gr.Button("μ μ‘ π€", variant="primary", scale=1) | |
with gr.Row(): | |
clear = gr.Button("λν μ΄κΈ°ν ποΈ") | |
with gr.Accordion("κ³ κΈ μ€μ ", open=False): | |
max_tokens = gr.Slider( | |
minimum=100, | |
maximum=2000, | |
value=1000, | |
step=100, | |
label="μ΅λ ν ν° μ" | |
) | |
temperature = gr.Slider( | |
minimum=0.1, | |
maximum=1.0, | |
value=0.7, | |
step=0.1, | |
label="Temperature (μ°½μμ±)" | |
) | |
# μ΄λ²€νΈ νΈλ€λ¬ | |
msg.submit( | |
chat_fn, | |
[msg, chatbot, max_tokens, temperature], | |
[chatbot, msg] | |
) | |
submit.click( | |
chat_fn, | |
[msg, chatbot, max_tokens, temperature], | |
[chatbot, msg] | |
) | |
clear.click(lambda: ([], ""), outputs=[chatbot, msg]) | |
if __name__ == "__main__": | |
demo.launch() |