hewoo commited on
Commit
c7f958e
ยท
verified ยท
1 Parent(s): 0c5bb50

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -2
app.py CHANGED
@@ -15,12 +15,31 @@ model = AutoModelForCausalLM.from_pretrained(model_id, use_auth_token=token)
15
  # ํ…์ŠคํŠธ ์ƒ์„ฑ ํŒŒ์ดํ”„๋ผ์ธ ์„ค์ •
16
  pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=150, temperature=0.5, top_p=0.85, top_k=40, repetition_penalty=1.2)
17
 
18
- # ์ž„๋ฒ ๋”ฉ ๋ชจ๋ธ ๋ฐ ๊ฒ€์ƒ‰ ๊ธฐ๋Šฅ ์„ค์ •
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  embedding_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
 
 
 
20
  persist_directory = "./chroma_batch_vectors"
21
- vectorstore = Chroma(persist_directory=persist_directory, embedding_function=embedding_model.encode)
 
 
22
  retriever = vectorstore.as_retriever(search_kwargs={"k": 3})
23
 
 
24
  # ๊ฒ€์ƒ‰ ๊ฒฐ๊ณผ ์š”์•ฝ ํ•จ์ˆ˜
25
  def summarize_results(search_results):
26
  combined_text = "\n".join([result.page_content for result in search_results])
 
15
  # ํ…์ŠคํŠธ ์ƒ์„ฑ ํŒŒ์ดํ”„๋ผ์ธ ์„ค์ •
16
  pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=150, temperature=0.5, top_p=0.85, top_k=40, repetition_penalty=1.2)
17
 
18
+
19
+
20
+ # ์‚ฌ์šฉ์ž ์ •์˜ ์ž„๋ฒ ๋”ฉ ํด๋ž˜์Šค ์ƒ์„ฑ
21
+ class CustomEmbedding:
22
+ def __init__(self, model):
23
+ self.model = model
24
+
25
+ def embed_query(self, text):
26
+ return self.model.encode(text, convert_to_tensor=False).tolist()
27
+
28
+ def embed_documents(self, texts):
29
+ return [self.model.encode(text, convert_to_tensor=False).tolist() for text in texts]
30
+
31
+ # ์ž„๋ฒ ๋”ฉ ๋ชจ๋ธ ์„ค์ •
32
  embedding_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
33
+ embedding_function = CustomEmbedding(embedding_model)
34
+
35
+ # Chroma ๋ฒกํ„ฐ ์Šคํ† ์–ด ์„ค์ •
36
  persist_directory = "./chroma_batch_vectors"
37
+ vectorstore = Chroma(persist_directory=persist_directory, embedding_function=embedding_function)
38
+
39
+ # ๊ฒ€์ƒ‰ ๊ธฐ๋Šฅ ์„ค์ •
40
  retriever = vectorstore.as_retriever(search_kwargs={"k": 3})
41
 
42
+
43
  # ๊ฒ€์ƒ‰ ๊ฒฐ๊ณผ ์š”์•ฝ ํ•จ์ˆ˜
44
  def summarize_results(search_results):
45
  combined_text = "\n".join([result.page_content for result in search_results])