chahah committed
Commit bde2b54 · verified · 1 Parent(s): 1b0478e

Update app.py

Files changed (1):
  1. app.py +52 -62
app.py CHANGED
@@ -22,62 +22,12 @@ rate_limiter = InMemoryRateLimiter(
     check_every_n_seconds=0.01,  # Wake up every 100 ms to check whether allowed to make a request
     max_bucket_size=10,  # Controls the maximum burst size.
 )
-"""
-# get data
-urlsfile = open("urls.txt")
-urls = urlsfile.readlines()
-urls = [url.replace("\n","") for url in urls]
-urlsfile.close()
-
-# Load, chunk and index the contents of the blog.
-loader = WebBaseLoader(urls)
-docs = loader.load()
-
-# load arxiv papers
-arxivfile = open("arxiv.txt")
-arxivs = arxivfile.readlines()
-arxivs = [arxiv.replace("\n","") for arxiv in arxivs]
-arxivfile.close()
 
 retriever = ArxivRetriever(
     load_max_docs=2,
     get_full_documents=True,
 )
 
-for arxiv in arxivs:
-    doc = retriever.invoke(arxiv)
-    doc[0].metadata['Published'] = str(doc[0].metadata['Published'])
-    docs.append(doc[0])
-
-def format_docs(docs):
-    return "\n\n".join(doc.page_content for doc in docs)
-
-def RAG(llm, docs, embeddings):
-
-    # Split text
-    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
-    splits = text_splitter.split_documents(docs)
-
-    # Create vector store
-    vectorstore = Chroma.from_documents(documents=splits, embedding=embeddings)
-
-    # Retrieve and generate using the relevant snippets of the documents
-    retriever = vectorstore.as_retriever()
-
-    # Prompt basis example for RAG systems
-    prompt = hub.pull("rlm/rag-prompt")
-
-    # Create the chain
-    rag_chain = (
-        {"context": retriever | format_docs, "question": RunnablePassthrough()}
-        | prompt
-        | llm
-        | StrOutputParser()
-    )
-
-    return rag_chain
-
 # LLM model
 llm = ChatMistralAI(model="mistral-large-latest", rate_limiter=rate_limiter)
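
Note: after this commit the ArxivRetriever block above is the only ingestion path left; the urls.txt/arxiv.txt batch loading is gone. For reference, a minimal sketch of what that block does at runtime, assuming the langchain-community and arxiv packages are installed (the paper id is an arbitrary example):

from langchain_community.retrievers import ArxivRetriever

retriever = ArxivRetriever(
    load_max_docs=2,          # cap on papers fetched per query
    get_full_documents=True,  # full paper text rather than just the abstract
)

docs = retriever.invoke("1706.03762")      # look up a paper by arXiv id
print(docs[0].metadata["Title"])
print(str(docs[0].metadata["Published"]))  # datetime -> str, as app.py does before indexing
print(docs[0].page_content[:200])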
@@ -87,10 +37,48 @@ embed_model = "sentence-transformers/multi-qa-distilbert-cos-v1"
 embeddings = HuggingFaceInstructEmbeddings(model_name=embed_model)
 # embeddings = MistralAIEmbeddings()
 
-# RAG chain
-rag_chain = RAG(llm, docs, embeddings)
-
-def handle_prompt(message, history):
+def initialize(arxivcode):
+    # Fetch the requested paper and stringify its datetime metadata for Chroma
+    docs = retriever.invoke(arxivcode)
+    docs[0].metadata['Published'] = str(docs[0].metadata['Published'])
+
+    def format_docs(docs):
+        return "\n\n".join(doc.page_content for doc in docs)
+
+    def RAG(llm, docs, embeddings):
+
+        # Split text
+        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
+        splits = text_splitter.split_documents(docs)
+
+        # Create vector store
+        vectorstore = Chroma.from_documents(documents=splits, embedding=embeddings)
+
+        # Retrieve and generate using the relevant snippets of the documents
+        retriever = vectorstore.as_retriever()
+
+        # Prompt basis example for RAG systems
+        prompt = hub.pull("rlm/rag-prompt")
+
+        # Create the chain
+        rag_chain = (
+            {"context": retriever | format_docs, "question": RunnablePassthrough()}
+            | prompt
+            | llm
+            | StrOutputParser()
+        )
+
+        return rag_chain
+
+    return RAG(llm, docs, embeddings)
+
+rag_chain = None
+
+def handle_prompt(message, history, arxivcode):
+    global rag_chain
+    if rag_chain is None:
+        # Build the RAG chain for the requested paper on the first message
+        rag_chain = initialize(arxivcode)
+
     try:
         # Stream output
         out=""
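
Note: the chain assembled inside RAG above is a standard LCEL pipeline (retriever | prompt | llm | parser), so it exposes both invoke and stream. The diff elides the body of the streaming loop in handle_prompt; the sketch below shows how such a chain is typically consumed, and the loop is an assumption inferred from the out="" / yield out context lines, not the commit's exact code:

# One-shot call: returns the parsed answer string.
answer = rag_chain.invoke("What problem does the paper address?")

# Streaming: accumulate chunks and yield the running text so the Gradio
# chat window updates incrementally (assumed shape of the elided loop).
out = ""
for chunk in rag_chain.stream("What problem does the paper address?"):
    out += chunk
    # yield out  # inside a generator such as handle_prompt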
@@ -99,17 +87,19 @@ def handle_prompt(message, history):
         yield out
     except:
         raise gr.Error("Requests rate limit exceeded")
-"""
 
-def handle_prompt(message, history, input1):
-    return f"arxiv code: {input1}, {message}"
 
-greetingsmessage = "Hi, I'm your personal arXiv reader. Input the arXiv number of the paper:"
+greetingsmessage = "Hi, I'm your personal arXiv reader. Ask me questions about the arXiv paper above"
 
-demo = gr.ChatInterface(handle_prompt, type="messages", theme=gr.themes.Soft(),
-                        description=greetingsmessage,
-                        additional_inputs=[gr.Textbox("", label="arxiv.code")]
-                        )
+with gr.Blocks() as demo:
+
+    arxiv_code = gr.Textbox("", label="arxiv.number")
+
+    gr.ChatInterface(handle_prompt, type="messages", theme=gr.themes.Soft(),
+                     description=greetingsmessage,
+                     additional_inputs=[arxiv_code]
+                     )
 
-demo.launch()
+if __name__ == "__main__":
+    demo.launch()
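
Note: gr.ChatInterface forwards the current value of each additional_inputs component as an extra positional argument, so handle_prompt receives the textbox content as arxivcode with every message. Since rag_chain is cached at module level, it is built from whichever arXiv id accompanies the first message and is never rebuilt if the user later edits the textbox. A sketch of one way around that, keyed by the id (the _chains cache is invented for illustration, not part of this commit):

_chains = {}  # hypothetical cache: arXiv id -> RAG chain

def handle_prompt(message, history, arxivcode):
    if arxivcode not in _chains:
        _chains[arxivcode] = initialize(arxivcode)
    out = ""
    for chunk in _chains[arxivcode].stream(message):
        out += chunk
        yield out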
 
 
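Note: the "Requests rate limit exceeded" error raised above is the fallback for the client-side throttle configured at the top of the file. That limiter comes from langchain_core; below is a minimal sketch of the full construction, where only check_every_n_seconds and max_bucket_size are confirmed by the diff's context lines and requests_per_second is an assumed value:

from langchain_core.rate_limiters import InMemoryRateLimiter
from langchain_mistralai import ChatMistralAI

rate_limiter = InMemoryRateLimiter(
    requests_per_second=1,       # assumed; this line sits above the diff window
    check_every_n_seconds=0.01,  # wake up every 100 ms to check whether a request is allowed
    max_bucket_size=10,          # maximum burst size
)

llm = ChatMistralAI(model="mistral-large-latest", rate_limiter=rate_limiter)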