# hehetest / app.py — Streamlit RAG chatbot demo (Hugging Face Space by hewoo, commit 1dc17cb)
import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from sentence_transformers import SentenceTransformer
from langchain.vectorstores import Chroma
import gc
import psutil
# ๋ชจ๋ธ ID (๊ณต๊ฐœ๋œ ๋ชจ๋ธ์ด์–ด์•ผ ํ•จ)
model_id = "hewoo/hehehehe"
# Memory-usage display helper
def monitor_memory():
    """Write the host's current RAM utilisation (percent) into the Streamlit page."""
    mem = psutil.virtual_memory()
    st.write(f"ํ˜„์žฌ ๋ฉ”๋ชจ๋ฆฌ ์‚ฌ์šฉ๋Ÿ‰: {mem.percent}%")
# ์บ์‹œ๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ๋ชจ๋ธ ๋ฐ ํŒŒ์ดํ”„๋ผ์ธ ๋กœ๋“œ
@st.cache_resource
def load_model():
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
return pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=150, temperature=0.5, top_p=0.85, top_k=40, repetition_penalty=1.2)
# ์‚ฌ์šฉ์ž ์ •์˜ ์ž„๋ฒ ๋”ฉ ํด๋ž˜์Šค
class CustomEmbedding:
def __init__(self, model):
self.model = model
def embed_query(self, text):
return self.model.encode(text, convert_to_tensor=True).tolist()
def embed_documents(self, texts):
return [self.model.encode(text, convert_to_tensor=True).tolist() for text in texts]
# ์ž„๋ฒ ๋”ฉ ๋ชจ๋ธ ๋ฐ ๋ฒกํ„ฐ ์Šคํ† ์–ด ์„ค์ •
@st.cache_resource
def load_embedding_model():
return SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
@st.cache_resource
def load_vectorstore(embedding_model):
    """Open the persisted Chroma vector store, wrapping `embedding_model`
    in the adapter LangChain expects for query/document embedding."""
    return Chroma(
        persist_directory="./chroma_batch_vectors",
        embedding_function=CustomEmbedding(embedding_model),
    )
# ์งˆ๋ฌธ์— ๋Œ€ํ•œ ์‘๋‹ต ์ƒ์„ฑ ํ•จ์ˆ˜
def generate_response(user_input):
retriever = vectorstore.as_retriever(search_kwargs={"k": 3})
search_results = retriever.get_relevant_documents(user_input)
context = "\n".join([result.page_content for result in search_results])
input_text = f"๋งฅ๋ฝ: {context}\n์งˆ๋ฌธ: {user_input}"
response = pipe(input_text)[0]["generated_text"]
return response
# ๋ชจ๋ธ ๋ฐ ์ž„๋ฒ ๋”ฉ ๋ชจ๋ธ ๋กœ๋“œ
pipe = load_model()
embedding_model = load_embedding_model()
vectorstore = load_vectorstore(embedding_model)
# Streamlit ์•ฑ UI
st.title("์ฑ—๋ด‡ ๋ฐ๋ชจ")
st.write("Llama 3.2-3B ๋ชจ๋ธ์„ ์‚ฌ์šฉํ•œ ์ฑ—๋ด‡์ž…๋‹ˆ๋‹ค. ์งˆ๋ฌธ์„ ์ž…๋ ฅํ•ด ์ฃผ์„ธ์š”.")
monitor_memory() # ๋ฉ”๋ชจ๋ฆฌ ์‚ฌ์šฉ๋Ÿ‰ ํ™•์ธ
# ์‚ฌ์šฉ์ž ์ž…๋ ฅ ๋ฐ›๊ธฐ
user_input = st.text_input("์งˆ๋ฌธ")
if user_input:
response = generate_response(user_input)
st.write("์ฑ—๋ด‡ ์‘๋‹ต:", response)
monitor_memory() # ๋ฉ”๋ชจ๋ฆฌ ์ƒํƒœ ์—…๋ฐ์ดํŠธ
# ๋ฉ”๋ชจ๋ฆฌ ํ•ด์ œ
del response
gc.collect()