wakeupmh committed
Commit 99637f2 · 1 Parent(s): f1586e3

fix: faiss error

Files changed (2)
  1. app.py +12 -18
  2. faiss.index.py +36 -0
app.py CHANGED
@@ -23,32 +23,23 @@ def fetch_arxiv_papers(query, max_results=5):
     papers = [{"title": result.title, "summary": result.summary, "pdf_url": result.pdf_url} for result in results]
     return papers
 
+# Load FAISS index
+def load_faiss_index(index_file="faiss_index.index"):
+    return faiss.read_index(index_file)
+
 # RAG Pipeline
-def rag_pipeline(query, papers):
+def rag_pipeline(query, papers, index):
     # Load pre-trained RAG model
     tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
-    retriever = RagRetriever.from_pretrained("facebook/rag-sequence-nq", index_name="custom")
+    retriever = RagRetriever.from_pretrained("facebook/rag-sequence-nq", index_name="custom", passages=papers, index=index)
     model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq", retriever=retriever)
 
-    # Encode papers into embeddings
-    embedder = SentenceTransformer('all-MiniLM-L6-v2')
-    paper_embeddings = embedder.encode([paper["summary"] for paper in papers])
-
-    # Build FAISS index
-    index = faiss.IndexFlatL2(paper_embeddings.shape[1])
-    index.add(paper_embeddings)
-
-    # Retrieve relevant papers
-    query_embedding = embedder.encode([query])
-    distances, indices = index.search(query_embedding, k=2)  # Top 2 relevant papers
-    relevant_papers = [papers[i] for i in indices[0]]
-
     # Generate answer using RAG
     inputs = tokenizer(query, return_tensors="pt")
     generated_ids = model.generate(inputs["input_ids"])
     answer = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
 
-    return answer, relevant_papers
+    return answer
 
 # Run the app
 if query:
@@ -56,14 +47,17 @@ if query:
     papers = fetch_arxiv_papers(query)
     st.write(f"Found {len(papers)} papers.")
 
+    st.write("Loading FAISS index...")
+    index = load_faiss_index()
+
     st.write("Running RAG pipeline...")
-    answer, relevant_papers = rag_pipeline(query, papers)
+    answer = rag_pipeline(query, papers, index)
 
     st.write("### Answer:")
     st.write(answer)
 
     st.write("### Relevant Papers:")
-    for paper in relevant_papers:
+    for paper in papers:
         st.write(f"**Title:** {paper['title']}")
         st.write(f"**Summary:** {paper['summary']}")
         st.write(f"**PDF URL:** {paper['pdf_url']}")
faiss.index.py ADDED
@@ -0,0 +1,36 @@
+import numpy as np
+import faiss
+from sentence_transformers import SentenceTransformer
+import arxiv
+
+# Fetch arXiv papers
+def fetch_arxiv_papers(query, max_results=10):
+    client = arxiv.Client()
+    search = arxiv.Search(
+        query=query,
+        max_results=max_results,
+        sort_by=arxiv.SortCriterion.SubmittedDate
+    )
+    results = list(client.results(search))
+    papers = [{"title": result.title, "summary": result.summary, "pdf_url": result.pdf_url} for result in results]
+    return papers
+
+# Build and save FAISS index
+def build_faiss_index(papers, index_file="faiss_index.index"):
+    embedder = SentenceTransformer('all-MiniLM-L6-v2')
+    paper_embeddings = embedder.encode([paper["summary"] for paper in papers])
+
+    # Create FAISS index
+    dimension = paper_embeddings.shape[1]
+    index = faiss.IndexFlatL2(dimension)
+    index.add(paper_embeddings)
+
+    # Save index to disk
+    faiss.write_index(index, index_file)
+    print(f"FAISS index saved to {index_file}")
+
+# Example usage
+if __name__ == "__main__":
+    query = "quantum computing"
+    papers = fetch_arxiv_papers(query)
+    build_faiss_index(papers)
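For reference, the custom-index path in transformers' RagRetriever generally expects on-disk paths rather than in-memory objects, so the passages=/index= keywords in app.py may need adjusting. A hedged sketch, with placeholder paths, following the library's custom-index usage:

```python
from transformers import RagRetriever

# Sketch only: paths are placeholders. passages_path points to a
# datasets.Dataset (with "title"/"text" columns) saved via save_to_disk,
# and index_path to the FAISS index built over its "embeddings" column.
retriever = RagRetriever.from_pretrained(
    "facebook/rag-sequence-nq",
    index_name="custom",
    passages_path="path/to/my_dataset",
    index_path="path/to/my_index.faiss",
)
```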