import gradio as gr import torch from transformers import AutoModelForCausalLM, AutoTokenizer import os # --- 모델 로드 --- # 모델 경로 설정 (Hugging Face 모델 ID) model_id = "microsoft/bitnet-b1.58-2B-4T" # 모델 로드 시 경고 메시지를 최소화하기 위해 로깅 레벨 설정 os.environ["TRANSFORMERS_VERBOSITY"] = "error" # AutoModelForCausalLM과 AutoTokenizer를 로드합니다. # BitNet 모델은 trust_remote_code=True가 필요합니다. # GitHub 특정 브랜치에서 설치한 transformers를 사용합니다. try: print(f"모델 로딩 중: {model_id}...") # GPU가 사용 가능하면 bf16 사용 if torch.cuda.is_available(): # torch_dtype을 명시적으로 설정하여 로드 오류 방지 시도 model = AutoModelForCausalLM.from_pretrained( model_id, torch_dtype=torch.bfloat16, trust_remote_code=True ).to("cuda") # GPU로 모델 이동 tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True) print("GPU를 사용하여 모델 로드 완료.") else: # CPU 사용 시 torch_dtype 생략 또는 float32 model = AutoModelForCausalLM.from_pretrained( model_id, trust_remote_code=True ) tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True) print("CPU를 사용하여 모델 로드 완료. 성능이 느릴 수 있습니다.") except Exception as e: print(f"모델 로드 중 오류 발생: {e}") tokenizer = None model = None print("모델 로드에 실패했습니다. 애플리케이션이 제대로 동작하지 않을 수 있습니다.") # --- 텍스트 생성 함수 --- def generate_text(prompt, max_length=100, temperature=0.7): if model is None or tokenizer is None: return "모델 로드에 실패하여 텍스트 생성을 할 수 없습니다." try: # 프롬프트 토큰화 inputs = tokenizer(prompt, return_tensors="pt") # GPU 사용 가능 시 GPU로 입력 이동 if torch.cuda.is_available(): inputs = {k: v.to("cuda") for k, v in inputs.items()} # 텍스트 생성 # LLaMA 3 토크나이저를 사용하므로 chat template 적용 가능 (선택 사항) # 메시지 형식을 사용하지 않고 직접 프롬프트 입력 시 아래 코드 사용 outputs = model.generate( **inputs, max_new_tokens=max_length, temperature=temperature, do_sample=True, # 샘플링 활성화 pad_token_id=tokenizer.eos_token_id # 패딩 토큰 ID 설정 (필요시) ) # 생성된 텍스트 디코딩 # 입력 프롬프트 부분을 제외하고 생성된 부분만 디코딩 generated_text = tokenizer.decode(outputs[0][inputs['input_ids'].shape[-1]:], skip_special_tokens=True) return generated_text except Exception as e: return f"텍스트 생성 중 오류 발생: {e}" # --- Gradio 인터페이스 설정 --- if model is not None and tokenizer is not None: interface = gr.Interface( fn=generate_text, inputs=[ gr.Textbox(lines=2, placeholder="텍스트를 입력하세요...", label="입력 프롬프트"), gr.Slider(minimum=10, maximum=500, value=100, label="최대 생성 길이"), gr.Slider(minimum=0.1, maximum=1.0, value=0.7, label="Temperature (창의성)") ], outputs=gr.Textbox(label="생성된 텍스트"), title="BitNet b1.58-2B-4T 텍스트 생성 데모", description="BitNet b1.58-2B-4T 모델을 사용하여 텍스트를 생성합니다." ) # Gradio 앱 실행 # Hugging Face Spaces에서는 share=True가 자동으로 설정됩니다. interface.launch() else: print("모델 로드 실패로 인해 Gradio 인터페이스를 실행할 수 없습니다.")