import gc

import psutil
import streamlit as st
from langchain.vectorstores import Chroma
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# Hugging Face Hub ID of the chat model served by this app.
model_id = "hewoo/hehehehe"


def monitor_memory():
    """Show current system memory usage in the app."""
    memory_info = psutil.virtual_memory()
    st.write(f"Current memory usage: {memory_info.percent}%")


@st.cache_resource
def load_model():
    """Load the tokenizer and model once, then wrap them in a generation pipeline."""
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id)
    return pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=150,
        temperature=0.5,
        top_p=0.85,
        top_k=40,
        repetition_penalty=1.2,
    )
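
# Since the app reports memory usage below, loading the weights in half
# precision can help. A minimal sketch, assuming a PyTorch backend with
# accelerate available (not required by the rest of this app):
#
#     import torch
#     model = AutoModelForCausalLM.from_pretrained(
#         model_id, torch_dtype=torch.float16, device_map="auto"
#     )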


class CustomEmbedding:
    """Adapter exposing the embed_query/embed_documents interface LangChain expects."""

    def __init__(self, model):
        self.model = model

    def embed_query(self, text):
        # encode() returns a NumPy array by default; Chroma wants plain lists.
        return self.model.encode(text).tolist()

    def embed_documents(self, texts):
        return [self.model.encode(text).tolist() for text in texts]
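
# Usage sketch: embed_query returns a plain list of floats, e.g. 384 values
# for all-MiniLM-L6-v2 (the model loaded in load_embedding_model below):
#
#     emb = CustomEmbedding(SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2"))
#     vec = emb.embed_query("hello")  # list[float], len(vec) == 384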


@st.cache_resource
def load_embedding_model():
    return SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")


@st.cache_resource
def load_vectorstore(_embedding_model):
    # The leading underscore tells st.cache_resource not to hash the model
    # object, which is not hashable and would otherwise raise an error.
    embedding_function = CustomEmbedding(_embedding_model)
    return Chroma(persist_directory="./chroma_batch_vectors", embedding_function=embedding_function)
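
# This app only reads the index at ./chroma_batch_vectors. A minimal sketch of
# how such an index could be built offline (hypothetical chunks; not part of
# this app):
#
#     chunks = ["first text chunk", "second text chunk"]
#     store = Chroma.from_texts(
#         chunks,
#         embedding=CustomEmbedding(load_embedding_model()),
#         persist_directory="./chroma_batch_vectors",
#     )
#     store.persist()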


def generate_response(user_input):
    # Retrieve the three most relevant chunks and prepend them to the prompt.
    retriever = vectorstore.as_retriever(search_kwargs={"k": 3})
    search_results = retriever.get_relevant_documents(user_input)
    context = "\n".join(result.page_content for result in search_results)
    input_text = f"Context: {context}\nQuestion: {user_input}"
    response = pipe(input_text)[0]["generated_text"]
    return response
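
# Note: text-generation pipelines include the prompt in "generated_text" by
# default. To return only the model's answer, one option is to pass
# return_full_text=False when building the pipeline in load_model.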


# Load cached resources once at startup.
pipe = load_model()
embedding_model = load_embedding_model()
vectorstore = load_vectorstore(embedding_model)

st.title("Chatbot Demo")
st.write("This chatbot uses the Llama 3.2-3B model. Please enter a question.")

monitor_memory()

user_input = st.text_input("Question")
if user_input:
    response = generate_response(user_input)
    st.write("Chatbot response:", response)
    monitor_memory()

    # Drop the response and force a collection once it has been rendered.
    del response
    gc.collect()