import streamlit as st
import torch
from transformers import RagTokenizer, RagSequenceForGeneration
from sentence_transformers import SentenceTransformer
import faiss
import arxiv
# Title
st.title("arXiv RAG with Streamlit")

# Input: query
query = st.text_input("Enter your query:")
# Fetch the most recent arXiv papers matching the query
def fetch_arxiv_papers(query, max_results=5):
    client = arxiv.Client()
    search = arxiv.Search(
        query=query,
        max_results=max_results,
        sort_by=arxiv.SortCriterion.SubmittedDate
    )
    results = list(client.results(search))
    papers = [{"title": result.title, "summary": result.summary, "pdf_url": result.pdf_url} for result in results]
    return papers
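# Example (hypothetical query): fetch_arxiv_papers("retrieval augmented generation", max_results=2)
# would return a list of dicts shaped like:
# [{"title": "...", "summary": "...", "pdf_url": "https://arxiv.org/pdf/..."}, ...]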
# RAG pipeline: retrieve the most relevant fetched papers with FAISS,
# then generate an answer conditioned on them
def rag_pipeline(query, papers):
    # Load the pre-trained RAG tokenizer and generator. No built-in RagRetriever
    # is attached; retrieval over the fetched papers is done manually below.
    tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
    model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq")
    # Encode paper summaries into dense embeddings
    embedder = SentenceTransformer('all-MiniLM-L6-v2')
    paper_embeddings = embedder.encode([paper["summary"] for paper in papers])
    # Build a FAISS index (exact L2 search) over the summary embeddings
    index = faiss.IndexFlatL2(paper_embeddings.shape[1])
    index.add(paper_embeddings)
    # Retrieve the most relevant papers for the query (top 2)
    query_embedding = embedder.encode([query])
    k = min(2, len(papers))
    distances, indices = index.search(query_embedding, k=k)
    relevant_papers = [papers[i] for i in indices[0]]
    # Format each retrieved document the way RAG's retriever would:
    # "<title> / <text> // <question>"
    contexts = [paper["title"] + " / " + paper["summary"] + " // " + query for paper in relevant_papers]
    context_inputs = tokenizer.generator(contexts, return_tensors="pt", padding=True, truncation=True)
    # Use negated L2 distances as document scores (higher = more relevant)
    doc_scores = torch.tensor(-distances, dtype=torch.float)
    # Generate an answer conditioned on the retrieved contexts
    generated_ids = model.generate(
        context_input_ids=context_inputs["input_ids"],
        context_attention_mask=context_inputs["attention_mask"],
        doc_scores=doc_scores,
        n_docs=k,
    )
    answer = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return answer, relevant_papers
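# Note: rag_pipeline reloads the model and embedder on every query, which is slow.
# A minimal sketch of caching the heavy objects with Streamlit (assuming a Streamlit
# version new enough to provide st.cache_resource), kept commented out here:
#
# @st.cache_resource
# def load_models():
#     tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
#     model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq")
#     embedder = SentenceTransformer('all-MiniLM-L6-v2')
#     return tokenizer, model, embedder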
# Run the app
if query:
    st.write("Fetching arXiv papers...")
    papers = fetch_arxiv_papers(query)
    st.write(f"Found {len(papers)} papers.")
    if not papers:
        st.warning("No papers found for this query.")
        st.stop()
    st.write("Running RAG pipeline...")
    answer, relevant_papers = rag_pipeline(query, papers)
    st.write("### Answer:")
    st.write(answer)
    st.write("### Relevant Papers:")
    for paper in relevant_papers:
        st.write(f"**Title:** {paper['title']}")
        st.write(f"**Summary:** {paper['summary']}")
        st.write(f"**PDF URL:** {paper['pdf_url']}")
        st.write("---")