wakeupmh committed
Commit f1586e3 · 1 Parent(s): 2a0dd15

feat: first commit

Files changed (2)
  1. app.py +70 -0
  2. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,70 @@
+ import streamlit as st
+ from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
+ from sentence_transformers import SentenceTransformer
+ import faiss
+ import numpy as np
+ import arxiv
+
+ # Title
+ st.title("arXiv RAG with Streamlit")
+
+ # Input: Query
+ query = st.text_input("Enter your query:")
+
+ # Fetch the most recently submitted arXiv papers matching the query
+ def fetch_arxiv_papers(query, max_results=5):
+     client = arxiv.Client()
+     search = arxiv.Search(
+         query=query,
+         max_results=max_results,
+         sort_by=arxiv.SortCriterion.SubmittedDate
+     )
+     results = list(client.results(search))
+     papers = [{"title": result.title, "summary": result.summary, "pdf_url": result.pdf_url} for result in results]
+     return papers
+
+ # RAG Pipeline
+ def rag_pipeline(query, papers):
+     # Load the pre-trained RAG model. index_name="exact" with a dummy dataset
+     # avoids needing a pre-built index; index_name="custom" would require
+     # passing passages_path and index_path and fails without them.
+     tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
+     retriever = RagRetriever.from_pretrained(
+         "facebook/rag-sequence-nq", index_name="exact", use_dummy_dataset=True
+     )
+     model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq", retriever=retriever)
+
+     # Encode paper summaries into dense embeddings (FAISS expects float32)
+     embedder = SentenceTransformer('all-MiniLM-L6-v2')
+     paper_embeddings = np.asarray(embedder.encode([paper["summary"] for paper in papers]), dtype=np.float32)
+
+     # Build a FAISS index over the summary embeddings
+     index = faiss.IndexFlatL2(paper_embeddings.shape[1])
+     index.add(paper_embeddings)
+
+     # Retrieve relevant papers; never ask for more neighbors than we indexed
+     query_embedding = np.asarray(embedder.encode([query]), dtype=np.float32)
+     k = min(2, len(papers))  # Top 2 relevant papers at most
+     distances, indices = index.search(query_embedding, k)
+     relevant_papers = [papers[i] for i in indices[0]]
+
+     # Generate an answer with RAG. Note: generation uses the model's own
+     # retriever (here the dummy index), not the FAISS results above, which
+     # are only displayed alongside the answer.
+     inputs = tokenizer(query, return_tensors="pt")
+     generated_ids = model.generate(input_ids=inputs["input_ids"])
+     answer = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
+
+     return answer, relevant_papers
+
+ # Run the app
+ if query:
+     st.write("Fetching arXiv papers...")
+     papers = fetch_arxiv_papers(query)
+     st.write(f"Found {len(papers)} papers.")
+
+     if papers:
+         st.write("Running RAG pipeline...")
+         answer, relevant_papers = rag_pipeline(query, papers)
+
+         st.write("### Answer:")
+         st.write(answer)
+
+         st.write("### Relevant Papers:")
+         for paper in relevant_papers:
+             st.write(f"**Title:** {paper['title']}")
+             st.write(f"**Summary:** {paper['summary']}")
+             st.write(f"**PDF URL:** {paper['pdf_url']}")
+             st.write("---")
+     else:
+         st.warning("No papers found for this query.")
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ streamlit
+ transformers
+ datasets
+ sentence-transformers
+ faiss-cpu
+ arxiv