RustX commited on
Commit
a9d3fa8
Β·
1 Parent(s): 1cf7a10

Create chatbot.py

Browse files
Files changed (1) hide show
  1. modules/chatbot.py +49 -0
modules/chatbot.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from langchain.chat_models import ChatOpenAI
3
+ from langchain.chains import ConversationalRetrievalChain
4
+ from langchain.prompts.prompt import PromptTemplate
5
+
6
+
7
+ class Chatbot:
8
+ _template = """λ‹€μŒ λŒ€ν™”μ™€ 후속 질문이 주어지면 후속 μ§ˆλ¬Έμ„ λ…λ¦½ν˜• 질문으둜 λ°”κΎΈμ‹­μ‹œμ˜€.
9
+ 질문이 CSV 파일의 정보에 κ΄€ν•œ 것이라고 κ°€μ •ν•  수 μžˆμŠ΅λ‹ˆλ‹€.
10
+ Chat History:
11
+ {chat_history}
12
+ Follow-up entry: {question}
13
+ Standalone question:"""
14
+
15
+ CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)
16
+
17
+ qa_template = """"csv 파일의 정보λ₯Ό 기반으둜 μ§ˆλ¬Έμ— λ‹΅ν•˜λŠ” AI λŒ€ν™” λΉ„μ„œμž…λ‹ˆλ‹€.
18
+ csv 파일의 데이터와 질문이 제곡되며 μ‚¬μš©μžκ°€ ν•„μš”ν•œ 정보λ₯Ό 찾도둝 도와야 ν•©λ‹ˆλ‹€.
19
+ μ•Œκ³  μžˆλŠ” 정보에 λŒ€ν•΄μ„œλ§Œ μ‘λ‹΅ν•˜μ‹­μ‹œμ˜€. 닡을 지어내렀고 ν•˜μ§€ λ§ˆμ„Έμš”.
20
+ κ·€ν•˜μ˜ 닡변은 짧고 μΉœκ·Όν•˜λ©° λ™μΌν•œ μ–Έμ–΄λ‘œ μž‘μ„±λ˜μ–΄μ•Ό ν•©λ‹ˆλ‹€.
21
+ question: {question}
22
+ =========
23
+ {context}
24
+ =======
25
+ """
26
+
27
+ QA_PROMPT = PromptTemplate(template=qa_template, input_variables=["question", "context"])
28
+
29
+ def __init__(self, model_name, temperature, vectors):
30
+ self.model_name = model_name
31
+ self.temperature = temperature
32
+ self.vectors = vectors
33
+
34
+ def conversational_chat(self, query):
35
+ """
36
+ Starts a conversational chat with a model via Langchain
37
+ """
38
+
39
+ chain = ConversationalRetrievalChain.from_llm(
40
+ llm=ChatOpenAI(model_name=self.model_name, temperature=self.temperature),
41
+ condense_question_prompt=self.CONDENSE_QUESTION_PROMPT,
42
+ qa_prompt=self.QA_PROMPT,
43
+ retriever=self.vectors.as_retriever(),
44
+ )
45
+ result = chain({"question": query, "chat_history": st.session_state["history"]})
46
+
47
+ st.session_state["history"].append((query, result["answer"]))
48
+
49
+ return result["answer"]