# arXiv RAG demo — Streamlit app.
# (Removed Hugging Face Spaces page residue — "Spaces: Sleeping" — that was
# accidentally captured when this file was copied from the Spaces UI.)
import os

import arxiv
import faiss
import numpy as np
import streamlit as st
from sentence_transformers import SentenceTransformer
from transformers import RagRetriever, RagSequenceForGeneration, RagTokenizer
# --- UI header: page title and the free-text query box ---
st.title("arXiv RAG with Streamlit")

# Input: Query (empty string until the user types something)
query = st.text_input("Enter your query:")
# Fetch arXiv papers | |
def fetch_arxiv_papers(query, max_results=5):
    """Fetch recent arXiv papers matching *query*.

    Args:
        query: Free-text arXiv search query.
        max_results: Maximum number of papers to return (default 5).

    Returns:
        A list of dicts with keys "title", "summary", and "pdf_url",
        ordered by submission date (newest first).
    """
    client = arxiv.Client()
    search = arxiv.Search(
        query=query,
        max_results=max_results,
        sort_by=arxiv.SortCriterion.SubmittedDate,
    )
    # Materialize the result generator so callers get a plain list.
    results = list(client.results(search))
    return [
        {"title": r.title, "summary": r.summary, "pdf_url": r.pdf_url}
        for r in results
    ]
# Load FAISS index | |
def load_faiss_index(index_file="faiss_index.index"):
    """Load the FAISS index from disk, building it first if it is missing.

    Args:
        index_file: Path of the serialized FAISS index file.

    Returns:
        The deserialized FAISS index object (via ``faiss.read_index``).
    """
    if not os.path.exists(index_file):
        # Lazily import the index-building module so it is only required
        # on first run, when the index must be built from scratch.
        # NOTE(review): the module path "faiss_index.index" matches the
        # default index *file* name above — confirm this package actually
        # exists and is not a copy/paste slip.
        import faiss_index.index as faiss_index_index

        # Seed the index with an initial batch of papers.
        # NOTE(review): the seed query is hard-coded; confirm "autism
        # research" is the intended default corpus for first-run builds.
        initial_papers = faiss_index_index.fetch_arxiv_papers(
            "autism research", max_results=100
        )
        faiss_index_index.build_faiss_index(initial_papers, index_file)
    return faiss.read_index(index_file)
# RAG Pipeline | |
def rag_pipeline(query, papers, index): | |
# Load pre-trained RAG model | |
tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq") | |
retriever = RagRetriever.from_pretrained("facebook/rag-sequence-nq", index_name="custom", passages=papers, index=index) | |
model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq", retriever=retriever) | |
# Generate answer using RAG | |
inputs = tokenizer(query, return_tensors="pt") | |
generated_ids = model.generate(inputs["input_ids"]) | |
answer = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0] | |
return answer | |
# Run the app | |
# Run the app once the user has entered a query.
if query:
    st.write("Fetching arXiv papers...")
    papers = fetch_arxiv_papers(query)
    st.write(f"Found {len(papers)} papers.")

    st.write("Loading FAISS index...")
    index = load_faiss_index()

    st.write("Running RAG pipeline...")
    answer = rag_pipeline(query, papers, index)

    st.write("### Answer:")
    st.write(answer)

    # List the papers that were fetched for this query.
    st.write("### Relevant Papers:")
    for paper in papers:
        st.write(f"**Title:** {paper['title']}")
        st.write(f"**Summary:** {paper['summary']}")
        st.write(f"**PDF URL:** {paper['pdf_url']}")
        st.write("---")