import streamlit as st from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline from sentence_transformers import SentenceTransformer from langchain.vectorstores import Chroma import gc import psutil # 모델 ID (공개된 모델이어야 함) model_id = "hewoo/hehehehe" # 메모리 모니터링 함수 def monitor_memory(): memory_info = psutil.virtual_memory() st.write(f"현재 메모리 사용량: {memory_info.percent}%") # 캐시를 사용하여 모델 및 파이프라인 로드 @st.cache_resource def load_model(): tokenizer = AutoTokenizer.from_pretrained(model_id) model = AutoModelForCausalLM.from_pretrained(model_id) return pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=150, temperature=0.5, top_p=0.85, top_k=40, repetition_penalty=1.2) # 사용자 정의 임베딩 클래스 class CustomEmbedding: def __init__(self, model): self.model = model def embed_query(self, text): return self.model.encode(text, convert_to_tensor=True).tolist() def embed_documents(self, texts): return [self.model.encode(text, convert_to_tensor=True).tolist() for text in texts] # 임베딩 모델 및 벡터 스토어 설정 @st.cache_resource def load_embedding_model(): return SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2") @st.cache_resource def load_vectorstore(embedding_model): embedding_function = CustomEmbedding(embedding_model) return Chroma(persist_directory="./chroma_batch_vectors", embedding_function=embedding_function) # 질문에 대한 응답 생성 함수 def generate_response(user_input): retriever = vectorstore.as_retriever(search_kwargs={"k": 3}) search_results = retriever.get_relevant_documents(user_input) context = "\n".join([result.page_content for result in search_results]) input_text = f"맥락: {context}\n질문: {user_input}" response = pipe(input_text)[0]["generated_text"] return response # 모델 및 임베딩 모델 로드 pipe = load_model() embedding_model = load_embedding_model() vectorstore = load_vectorstore(embedding_model) # Streamlit 앱 UI st.title("챗봇 데모") st.write("Llama 3.2-3B 모델을 사용한 챗봇입니다. 질문을 입력해 주세요.") monitor_memory() # 메모리 사용량 확인 # 사용자 입력 받기 user_input = st.text_input("질문") if user_input: response = generate_response(user_input) st.write("챗봇 응답:", response) monitor_memory() # 메모리 상태 업데이트 # 메모리 해제 del response gc.collect()