import gc

import psutil
import streamlit as st
from langchain.vectorstores import Chroma
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

model_id = "hewoo/hehehehe"


def monitor_memory():
    # Show current system-wide memory usage inside the Streamlit page.
    memory_info = psutil.virtual_memory()
    st.write(f"Current memory usage: {memory_info.percent}%")
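

# Note: psutil.virtual_memory().percent is system-wide. To report only this
# process, something like the following would work (an alternative sketch,
# not what the app does above):
#   rss_mb = psutil.Process().memory_info().rss / 1024**2
#   st.write(f"App memory usage: {rss_mb:.1f} MB")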


@st.cache_resource
def load_model():
    # Load the tokenizer and model once; @st.cache_resource reuses the pipeline across reruns.
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id)
    return pipeline(
        "text-generation", model=model, tokenizer=tokenizer,
        max_new_tokens=150, do_sample=True, temperature=0.5,
        top_p=0.85, top_k=40, repetition_penalty=1.2,
        return_full_text=False,  # return only the completion, not the echoed prompt
    )


class CustomEmbedding:
    # Minimal adapter exposing the embed_query/embed_documents interface
    # that LangChain's Chroma wrapper expects from an embedding function.
    def __init__(self, model):
        self.model = model

    def embed_query(self, text):
        return self.model.encode(text, convert_to_tensor=False).tolist()

    def embed_documents(self, texts):
        return [self.model.encode(text, convert_to_tensor=False).tolist() for text in texts]
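
# With a recent LangChain install, the adapter above could likely be replaced by
# the built-in wrapper (an alternative sketch, not what this app uses):
#   from langchain_community.embeddings import HuggingFaceEmbeddings
#   embedding_function = HuggingFaceEmbeddings(model_name="jhgan/ko-sroberta-multitask")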


@st.cache_resource
def load_embedding_model():
    return SentenceTransformer("jhgan/ko-sroberta-multitask")


@st.cache_resource
def load_vectorstore(_embedding_model):
    # The leading underscore tells Streamlit's cache not to hash this argument.
    embedding_function = CustomEmbedding(_embedding_model)
    return Chroma(persist_directory="./chroma_batch_vectors", embedding_function=embedding_function)
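
# The app assumes ./chroma_batch_vectors was populated ahead of time. A minimal
# sketch of how such an index could be built (the document list is hypothetical):
#   docs = ["...", "..."]
#   Chroma.from_texts(docs, embedding=CustomEmbedding(load_embedding_model()),
#                     persist_directory="./chroma_batch_vectors")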


def generate_response(user_input):
    # Retrieve the three most similar chunks and splice them into the prompt.
    retriever = vectorstore.as_retriever(search_kwargs={"k": 3})
    search_results = retriever.get_relevant_documents(user_input)
    context = "\n".join([result.page_content for result in search_results])

    prompt = f"""The following is a Korean-language assistant that answers user questions.
Write an accurate, detailed answer in Korean to the user's question, based on the given context.
If the context contains no relevant information, reply: "Sorry, I cannot find an answer to that question."

Context:
{context}

Question:
{user_input}

Answer:"""

    response = pipe(prompt)[0]["generated_text"]
    return response
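
# If the fine-tuned model was trained with a chat template, building the prompt
# via tokenizer.apply_chat_template may match its training format better than
# the raw f-string above (a possible variant, untested here):
#   messages = [{"role": "user", "content": user_input}]
#   prompt = pipe.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)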


# Heavy resources are created once at startup; the cache decorators reuse them on reruns.
pipe = load_model()
embedding_model = load_embedding_model()
vectorstore = load_vectorstore(embedding_model)

st.title("Chatbot Demo")
st.write("This chatbot uses the Llama 3.2-3B model. Please enter a question.")

monitor_memory()

user_input = st.text_input("Question")
if user_input:
    response = generate_response(user_input)
    st.write("Chatbot response:", response)
    monitor_memory()

    # Free the rendered response; keeping this inside the if-block avoids a
    # NameError on reruns where no question has been entered yet.
    del response
    gc.collect()