import gradio as gr from transformers import AutoTokenizer, AutoModelForCausalLM import torch import gc class ModelManager: def __init__(self): self.model = None self.tokenizer = None self.model_name = "CohereForAI/c4ai-command-r-plus-4bit" def load_model(self): if self.model is None: try: print("모델 로딩 중... 시간이 걸릴 수 있습니다.") self.tokenizer = AutoTokenizer.from_pretrained(self.model_name) self.model = AutoModelForCausalLM.from_pretrained( self.model_name, torch_dtype=torch.float16, device_map="auto", trust_remote_code=True, load_in_4bit=True, low_cpu_mem_usage=True ) print("모델 로딩 완료!") return True except Exception as e: print(f"모델 로딩 실패: {e}") return False return True def generate(self, message, history, max_tokens=1000, temperature=0.7): if not self.load_model(): return "모델 로딩에 실패했습니다." try: # 채팅 히스토리 구성 conversation = [] for human, assistant in history: conversation.append({"role": "user", "content": human}) if assistant: conversation.append({"role": "assistant", "content": assistant}) conversation.append({"role": "user", "content": message}) # 토큰화 input_ids = self.tokenizer.apply_chat_template( conversation, return_tensors="pt", add_generation_prompt=True ) if torch.cuda.is_available(): input_ids = input_ids.to("cuda") # 생성 with torch.no_grad(): outputs = self.model.generate( input_ids, max_new_tokens=max_tokens, temperature=temperature, do_sample=True, pad_token_id=self.tokenizer.eos_token_id, eos_token_id=self.tokenizer.eos_token_id ) response = self.tokenizer.decode( outputs[0][input_ids.shape[-1]:], skip_special_tokens=True ) return response except Exception as e: return f"생성 중 오류 발생: {str(e)}" finally: # 메모리 정리 if torch.cuda.is_available(): torch.cuda.empty_cache() gc.collect() # 모델 매니저 인스턴스 model_manager = ModelManager() def chat_fn(message, history, max_tokens, temperature): if not message.strip(): return history, "" # 사용자 메시지 추가 history.append([message, "생성 중..."]) # 봇 응답 생성 response = model_manager.generate(message, history[:-1], max_tokens, temperature) history[-1][1] = response return history, "" # Gradio 인터페이스 with gr.Blocks(title="Command R+ Chat") as demo: gr.Markdown(""" # 🤖 Command R+ 4bit 채팅봇 Cohere의 Command R+ 4bit 양자화 모델과 대화할 수 있습니다. ⚠️ 첫 실행 시 모델 로딩에 시간이 걸릴 수 있습니다. """) chatbot = gr.Chatbot( height=500, show_label=False, show_copy_button=True ) with gr.Row(): msg = gr.Textbox( label="메시지 입력", placeholder="Command R+에게 질문하세요...", lines=2, scale=4 ) submit = gr.Button("전송 📤", variant="primary", scale=1) with gr.Row(): clear = gr.Button("대화 초기화 🗑️") with gr.Accordion("고급 설정", open=False): max_tokens = gr.Slider( minimum=100, maximum=2000, value=1000, step=100, label="최대 토큰 수" ) temperature = gr.Slider( minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature (창의성)" ) # 이벤트 핸들러 msg.submit( chat_fn, [msg, chatbot, max_tokens, temperature], [chatbot, msg] ) submit.click( chat_fn, [msg, chatbot, max_tokens, temperature], [chatbot, msg] ) clear.click(lambda: ([], ""), outputs=[chatbot, msg]) if __name__ == "__main__": demo.launch()