# Provenance (Hugging Face file-viewer metadata, preserved as comments so the file parses):
# wakeupmh's picture
# fix: run in hf
# 0452175
# raw
# history blame
# 2.66 kB
import numpy as np
import faiss
import arxiv
from datasets import Dataset
import os
from transformers import DPRContextEncoder, DPRContextEncoderTokenizer
import torch
import logging
# Configure root logging so module-level logging.info(...) calls below are emitted.
logging.basicConfig(level=logging.INFO)
# Define data paths.
# DATA_DIR resolution order: $DATA_DIR env var, then /data (e.g. a mounted
# persistent volume on Hugging Face Spaces) if it exists, else the current directory.
DATA_DIR = os.getenv("DATA_DIR", "/data" if os.path.exists("/data") else ".")
# Directory where build_faiss_index() writes the dataset and the FAISS index file.
DATASET_DIR = os.path.join(DATA_DIR, "rag_dataset")
def fetch_arxiv_papers(query, max_results=10):
    """Query arXiv for the most recently submitted papers matching *query*.

    Returns a list of RAG-ready dicts, each with:
      - "id": sequential index as a string ("0", "1", ...)
      - "text": the paper abstract (summary)
      - "title": the paper title
    """
    searcher = arxiv.Client()
    request = arxiv.Search(
        query=query,
        max_results=max_results,
        sort_by=arxiv.SortCriterion.SubmittedDate,
    )
    papers = []
    for idx, entry in enumerate(searcher.results(request)):
        papers.append({"id": str(idx), "text": entry.summary, "title": entry.title})
    logging.info(f"Fetched {len(papers)} papers from arXiv")
    return papers
def build_faiss_index(papers, dataset_dir=DATASET_DIR):
    """Build and save a dataset with a FAISS index for RAG.

    Encodes each paper's "text" with a DPR context encoder, stores the
    embeddings alongside id/text/title in a Hugging Face ``Dataset`` saved to
    ``<dataset_dir>/dataset``, and writes a flat L2 FAISS index over the same
    embeddings to ``<dataset_dir>/embeddings.faiss``.

    Args:
        papers: non-empty list of dicts with "id", "text" and "title" keys
            (as produced by ``fetch_arxiv_papers``).
        dataset_dir: output directory; created if missing.

    Returns:
        The ``dataset_dir`` path the artifacts were written to.

    Raises:
        ValueError: if ``papers`` is empty (np.vstack would otherwise fail
            with an unhelpful "need at least one array to concatenate").
    """
    if not papers:
        raise ValueError("papers must be a non-empty list; nothing to index")

    # Initialize DPR context encoder + matching tokenizer.
    ctx_encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
    ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")

    # Encode abstracts in small batches to bound peak memory.
    texts = [p["text"] for p in papers]
    embeddings = []
    batch_size = 8
    for start in range(0, len(texts), batch_size):
        batch = texts[start:start + batch_size]
        inputs = ctx_tokenizer(batch, max_length=512, padding=True, truncation=True, return_tensors="pt")
        with torch.no_grad():  # inference only — no gradients needed
            outputs = ctx_encoder(**inputs)
        embeddings.append(outputs.pooler_output.cpu().numpy())

    # Single float32 matrix feeds both the dataset column and the FAISS index,
    # so the stored vectors and the indexed vectors are guaranteed identical.
    embeddings = np.vstack(embeddings).astype(np.float32)
    logging.info(f"Created embeddings with shape {embeddings.shape}")

    # Create dataset (embeddings stored as plain lists for serialization).
    dataset = Dataset.from_dict({
        "id": [p["id"] for p in papers],
        "text": [p["text"] for p in papers],
        "title": [p["title"] for p in papers],
        "embeddings": [emb.tolist() for emb in embeddings],
    })

    # Create an exact (flat) L2 FAISS index over the embeddings.
    dimension = embeddings.shape[1]
    index = faiss.IndexFlatL2(dimension)
    index.add(embeddings)

    # Save dataset and index side by side under dataset_dir.
    os.makedirs(dataset_dir, exist_ok=True)
    dataset_path = os.path.join(dataset_dir, "dataset")
    index_path = os.path.join(dataset_dir, "embeddings.faiss")
    dataset.save_to_disk(dataset_path)
    faiss.write_index(index, index_path)
    logging.info(f"Saved dataset to {dataset_path}")
    logging.info(f"Saved index to {index_path}")

    return dataset_dir