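# NOTE: the following lines are page chrome captured from the Hugging Face
# file viewer; they are not part of the program.
# hehetest / app.py
# hewoo's picture
# Update app.py
# ee37e7f verified
# raw
# history blame
# 2.24 kB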
import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from sentence_transformers import SentenceTransformer
from langchain.vectorstores import Chroma
import os
# Hugging Face model ID
model_id = "hewoo/hehehehe"
# Optional Hub access token read from the environment (None when HF_API_TOKEN
# is unset). If needed, the user could instead be asked to enter a token.
token = os.getenv("HF_API_TOKEN")


@st.cache_resource
def _load_generation_pipeline(model_id, token):
    """Load the tokenizer, causal LM and text-generation pipeline once.

    Streamlit re-executes the whole script on every widget interaction;
    caching the loaded objects avoids re-loading the model on each rerun.

    Args:
        model_id: Hugging Face Hub repository ID of the model to load.
        token: Hub access token, or None.

    Returns:
        A ``(tokenizer, model, pipe)`` tuple.
    """
    # `token=` replaces the deprecated `use_auth_token=` keyword.
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
    model = AutoModelForCausalLM.from_pretrained(model_id, token=token)
    # NOTE(review): temperature/top_p/top_k only take effect when sampling is
    # enabled — confirm the model's generation config sets do_sample.
    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=150,
        temperature=0.5,
        top_p=0.85,
        top_k=40,
        repetition_penalty=1.2,
    )
    return tokenizer, model, pipe


# Load the model and tokenizer and set up the text-generation pipeline
tokenizer, model, pipe = _load_generation_pipeline(model_id, token)
# Custom embedding class
class CustomEmbedding:
    """Adapter exposing ``embed_query`` / ``embed_documents`` over an encoder.

    The wrapped ``model`` must provide ``encode(x, convert_to_tensor=True)``
    whose result has a ``.tolist()`` method; ``x`` is either a single string
    or a list of strings.
    """

    def __init__(self, model):
        """Store the encoder used to compute embeddings."""
        self.model = model

    def embed_query(self, text):
        """Return the embedding of a single text as a plain list."""
        return self.model.encode(text, convert_to_tensor=True).tolist()

    def embed_documents(self, texts):
        """Return one embedding (as a plain list) per input text.

        All texts are encoded in a single ``encode`` call so the model can
        batch them, instead of one call per text.
        """
        texts = list(texts)
        if not texts:
            # Nothing to encode; avoid calling the model with an empty batch.
            return []
        return self.model.encode(texts, convert_to_tensor=True).tolist()
# Embedding model and vector store setup
# Directory of the persisted Chroma vector store
persist_directory = "./chroma_batch_vectors"  # may need adjusting for the Spaces environment


@st.cache_resource
def _load_retriever(persist_directory):
    """Build the embedding model, Chroma vector store and retriever once.

    Streamlit re-executes the whole script on every widget interaction;
    caching keeps these objects from being rebuilt on each rerun.

    Args:
        persist_directory: Path of the persisted Chroma collection.

    Returns:
        An ``(embedding_model, embedding_function, vectorstore, retriever)``
        tuple.
    """
    embedding_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
    embedding_function = CustomEmbedding(embedding_model)
    # Chroma vector store setup
    vectorstore = Chroma(
        persist_directory=persist_directory,
        embedding_function=embedding_function,
    )
    # Retrieve the 3 most similar documents per query.
    retriever = vectorstore.as_retriever(search_kwargs={"k": 3})
    return embedding_model, embedding_function, vectorstore, retriever


embedding_model, embedding_function, vectorstore, retriever = _load_retriever(
    persist_directory
)
# Response generation function for a question
def generate_response(user_input):
    """Answer a question using retrieved context and the generation pipeline.

    Args:
        user_input: The user's question text.

    Returns:
        The text generated by the model, without the prompt echoed back.
    """
    search_results = retriever.get_relevant_documents(user_input)
    context = "\n".join(doc.page_content for doc in search_results)
    # Prompt labels are Korean for "context" and "question".
    input_text = f"맥락: {context}\n질문: {user_input}"
    # By default the text-generation pipeline returns prompt + completion;
    # return_full_text=False keeps only the newly generated answer.
    outputs = pipe(input_text, return_full_text=False)
    return outputs[0]["generated_text"]
# Streamlit app UI
st.title("챗봇 test")
st.write("Llama 3.2-3B 모델을 사용한 챗봇입니다. 질문을 입력해 주세요.")

# Read the user's question
user_input = st.text_input("질문")
# Skip empty or whitespace-only input so no retrieval/generation is run for it.
question = user_input.strip() if user_input else ""
if question:
    response = generate_response(question)
    st.write("챗봇 응답:", response)