feat: add dataset
- app.py +29 -39
- faiss_index/index.py +26 -10
- requirements.txt +2 -1
app.py
CHANGED
@@ -1,46 +1,41 @@
 import streamlit as st
 from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
-from sentence_transformers import SentenceTransformer
 import faiss
-import
-import
+import os
+from datasets import load_from_disk

 # Title
-st.title("
+st.title("AMA Austim 🧩")

 # Input: Query
-query = st.text_input("
+query = st.text_input("Please ask me anything about autism ✨")

-#
-def
-
-    search = arxiv.Search(
-        query=query,
-        max_results=max_results,
-        sort_by=arxiv.SortCriterion.SubmittedDate
-    )
-    results = list(client.results(search))
-    papers = [{"title": result.title, "summary": result.summary, "pdf_url": result.pdf_url} for result in results]
-    return papers
-
-# Load FAISS index
-def load_faiss_index(index_file="faiss_index.index"):
-    import os
-    if not os.path.exists(index_file):
+# Load or create RAG dataset
+def load_rag_dataset(dataset_dir="rag_dataset"):
+    if not os.path.exists(dataset_dir):
         # Import the build function from the other file
         import faiss_index.index as faiss_index_index

         # Fetch some initial papers to build the index
         initial_papers = faiss_index_index.fetch_arxiv_papers("autism research", max_results=100)
-        faiss_index_index.build_faiss_index(initial_papers,
+        dataset_dir = faiss_index_index.build_faiss_index(initial_papers, dataset_dir)

-
+    # Load the dataset and index
+    dataset = load_from_disk(os.path.join(dataset_dir, "dataset"))
+    index = faiss.read_index(os.path.join(dataset_dir, "embeddings.faiss"))
+
+    return dataset, index

 # RAG Pipeline
-def rag_pipeline(query,
-    # Load pre-trained RAG model
+def rag_pipeline(query, dataset, index):
+    # Load pre-trained RAG model and configure retriever
     tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
-    retriever = RagRetriever.from_pretrained(
+    retriever = RagRetriever.from_pretrained(
+        "facebook/rag-sequence-nq",
+        index_name="custom",
+        passages_path=os.path.join("rag_dataset", "dataset"),
+        index_path=os.path.join("rag_dataset", "embeddings.faiss")
+    )
     model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq", retriever=retriever)

     # Generate answer using RAG
@@ -52,22 +47,17 @@ def rag_pipeline(query, papers, index):

 # Run the app
 if query:
-    st.write("
-
-
-
-    st.write("Loading FAISS index...")
-    index = load_faiss_index()
-
+    st.write("Loading or creating RAG dataset...")
+    dataset, index = load_rag_dataset()
+
     st.write("Running RAG pipeline...")
-    answer = rag_pipeline(query,
+    answer = rag_pipeline(query, dataset, index)

     st.write("### Answer:")
     st.write(answer)

-    st.write("###
-    for
-        st.write(f"**Title:** {
-        st.write(f"**Summary:** {
-        st.write(f"**PDF URL:** {paper['pdf_url']}")
+    st.write("### Retrieved Papers:")
+    for i in range(min(5, len(dataset))):
+        st.write(f"**Title:** {dataset[i]['title']}")
+        st.write(f"**Summary:** {dataset[i]['text'][:200]}...")
     st.write("---")
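Note on the elided section: the first hunk stops at the `# Generate answer using RAG` comment, so the generation body itself is unchanged by this commit and not shown. For orientation, that step typically looks like the sketch below when using the transformers RAG API; the `generate_answer` helper name is illustrative, not code from this repo.

```python
# Hedged sketch of a typical RAG generation step (not taken from this commit).
# `tokenizer` and `model` are the objects built inside rag_pipeline() above.
def generate_answer(tokenizer, model, query):
    # Encode the question with the RAG question-encoder tokenizer
    inputs = tokenizer(query, return_tensors="pt")
    # The model retrieves passages through its attached retriever, then generates
    generated_ids = model.generate(input_ids=inputs["input_ids"])
    # Decode the best generated sequence back to plain text
    return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
```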
faiss_index/index.py
CHANGED
@@ -2,6 +2,8 @@ import numpy as np
 import faiss
 from sentence_transformers import SentenceTransformer
 import arxiv
+from datasets import Dataset
+import os

 # Fetch arXiv papers
 def fetch_arxiv_papers(query, max_results=10):
@@ -12,22 +14,36 @@ def fetch_arxiv_papers(query, max_results=10):
         sort_by=arxiv.SortCriterion.SubmittedDate
     )
     results = list(client.results(search))
-    papers = [{"title": result.title, "
+    papers = [{"title": result.title, "text": result.summary, "id": str(i)} for i, result in enumerate(results)]
     return papers

-# Build and save FAISS index
-def build_faiss_index(papers,
-
-
+# Build and save dataset with FAISS index
+def build_faiss_index(papers, dataset_dir="rag_dataset"):
+    # Create dataset
+    dataset = Dataset.from_dict({
+        "id": [p["id"] for p in papers],
+        "title": [p["title"] for p in papers],
+        "text": [p["text"] for p in papers],
+    })

+    # Create embeddings
+    embedder = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
+    embeddings = embedder.encode(dataset["text"], show_progress_bar=True)
+
+    # Add embeddings to dataset
+    dataset = dataset.add_column("embeddings", [emb.tolist() for emb in embeddings])
+
     # Create FAISS index
-    dimension =
+    dimension = embeddings.shape[1]
     index = faiss.IndexFlatL2(dimension)
-    index.add(
+    index.add(embeddings.astype(np.float32))
+
+    # Save dataset and index
+    os.makedirs(dataset_dir, exist_ok=True)
+    dataset.save_to_disk(os.path.join(dataset_dir, "dataset"))
+    faiss.write_index(index, os.path.join(dataset_dir, "embeddings.faiss"))

-
-    faiss.write_index(index, index_file)
-    print(f"FAISS index saved to {index_file}")
+    return dataset_dir

 # Example usage
 if __name__ == "__main__":
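Once `build_faiss_index` has written `rag_dataset/dataset` and `rag_dataset/embeddings.faiss`, the two artifacts can be queried directly. A minimal sketch, assuming the query is embedded with the same `all-MiniLM-L6-v2` model used for the passages; the `search_papers` helper and `k=5` are illustrative and not part of this commit.

```python
# Hedged usage sketch: query the saved dataset + FAISS index produced by build_faiss_index.
import os
import numpy as np
import faiss
from datasets import load_from_disk
from sentence_transformers import SentenceTransformer

def search_papers(query, dataset_dir="rag_dataset", k=5):
    dataset = load_from_disk(os.path.join(dataset_dir, "dataset"))
    index = faiss.read_index(os.path.join(dataset_dir, "embeddings.faiss"))
    embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
    query_vec = embedder.encode([query]).astype(np.float32)
    distances, ids = index.search(query_vec, k)            # exact L2 nearest-neighbour search
    return [dataset[int(i)] for i in ids[0] if i != -1]    # map row ids back to paper records
```

Because the index is an `IndexFlatL2`, `search` scans every stored embedding, which is fine for the ~100 papers fetched at startup.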
requirements.txt
CHANGED
@@ -3,4 +3,5 @@ transformers
 datasets
 sentence-transformers
 faiss-cpu
-arxiv
+arxiv
+torch