import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from sentence_transformers import SentenceTransformer
from langchain.vectorstores import Chroma
import gc
import psutil
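# Note: this import layout assumes an older LangChain release; in newer versions the
# Chroma vector store is imported from langchain_community.vectorstores instead.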
# Model ID (must be a publicly available model)
model_id = "hewoo/hehehehe"

# Memory monitoring helper
def monitor_memory():
    memory_info = psutil.virtual_memory()
    st.write(f"Current memory usage: {memory_info.percent}%")
# Load the model and text-generation pipeline (cached across Streamlit reruns)
@st.cache_resource
def load_model():
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id)
    return pipeline("text-generation", model=model, tokenizer=tokenizer,
                    max_new_tokens=150, temperature=0.5, top_p=0.85,
                    top_k=40, repetition_penalty=1.2)
# Custom embedding wrapper exposing the embed_query/embed_documents interface expected by Chroma
class CustomEmbedding:
    def __init__(self, model):
        self.model = model

    def embed_query(self, text):
        return self.model.encode(text, convert_to_tensor=False).tolist()

    def embed_documents(self, texts):
        return [self.model.encode(text, convert_to_tensor=False).tolist() for text in texts]
# Korean embedding model and vector store setup
@st.cache_resource
def load_embedding_model():
    return SentenceTransformer("jhgan/ko-sroberta-multitask")

@st.cache_resource
def load_vectorstore(_embedding_model):  # leading underscore keeps st.cache_resource from hashing the model object
    embedding_function = CustomEmbedding(_embedding_model)
    return Chroma(persist_directory="./chroma_batch_vectors", embedding_function=embedding_function)
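# Note (assumption): ./chroma_batch_vectors is expected to already contain a Chroma
# collection persisted with the same ko-sroberta-multitask embeddings; this app only
# reads from it and never indexes new documents.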
# Generate a response to the user's question using retrieved context
def generate_response(user_input):
    # Retrieve the 3 most relevant document chunks from the vector store
    retriever = vectorstore.as_retriever(search_kwargs={"k": 3})
    search_results = retriever.get_relevant_documents(user_input)
    context = "\n".join([result.page_content for result in search_results])
    prompt = f"""You are a Korean-language assistant that answers the user's question.
Based on the given context, write an accurate and detailed answer in Korean.
If the context contains no relevant information, reply "Sorry, I could not find an answer to that question."
Context:
{context}
Question:
{user_input}
Answer:"""
    response = pipe(prompt)[0]["generated_text"]
    return response
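# Note: text-generation pipelines return the prompt together with the continuation by default;
# if only the newly generated text is wanted, return_full_text=False could be passed when
# building the pipeline in load_model().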
# Load the generation pipeline, embedding model, and vector store
pipe = load_model()
embedding_model = load_embedding_model()
vectorstore = load_vectorstore(embedding_model)
# Streamlit app UI
st.title("Chatbot Demo")
st.write("This chatbot uses the Llama 3.2-3B model. Please enter a question.")
monitor_memory()  # show current memory usage
# Get user input
user_input = st.text_input("Question")
if user_input:
    response = generate_response(user_input)
    st.write("Chatbot response:", response)
    monitor_memory()  # refresh memory usage after generation

    # Free memory
    del response
    gc.collect()
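# To run this app locally (a minimal sketch, assuming the file is saved as app.py and
# ./chroma_batch_vectors already holds the persisted vectors):
#   pip install streamlit transformers sentence-transformers langchain chromadb psutil torch
#   streamlit run app.py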