import logging
import os

import arxiv
import faiss
import numpy as np
import torch
from datasets import Dataset
from transformers import DPRContextEncoder, DPRContextEncoderTokenizer

# Configure logging
logging.basicConfig(level=logging.INFO)
# Define data paths
DATA_DIR = os.getenv("DATA_DIR", "/data" if os.path.exists("/data") else ".")
DATASET_DIR = os.path.join(DATA_DIR, "rag_dataset")
def fetch_arxiv_papers(query, max_results=10):
    """Fetch papers from arXiv and format them for RAG."""
    client = arxiv.Client()
    search = arxiv.Search(
        # Restrict results to biology and medical categories
        query=f"{query} AND (cat:q-bio.NC OR cat:q-bio.QM OR cat:q-bio.GN OR cat:q-bio.CB OR cat:q-bio.MN)",
        max_results=max_results,
        sort_by=arxiv.SortCriterion.Relevance,  # Relevance-based sorting
    )
    results = list(client.results(search))
    papers = [
        {"id": str(i), "text": result.summary, "title": result.title}
        for i, result in enumerate(results)
    ]
    logging.info(f"Fetched {len(papers)} papers from arXiv")
    return papers
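
# Hedged usage sketch: the query string and result count below are
# placeholders for illustration, not taken from the original pipeline.
#   papers = fetch_arxiv_papers("hippocampal memory formation", max_results=5)
#   print(papers[0]["title"])
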
def build_faiss_index(papers, dataset_dir=DATASET_DIR):
    """Build and save a dataset with a FAISS index for RAG."""
    # Run on GPU when available; fp16 halves memory, but some fp16 ops are
    # not implemented on CPU, so fall back to fp32 there
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    ctx_encoder = DPRContextEncoder.from_pretrained(
        "facebook/dpr-ctx_encoder-single-nq-base",
        torch_dtype=torch.float16 if device.type == "cuda" else torch.float32,
        low_cpu_mem_usage=True,
    ).to(device)
    ctx_encoder.eval()
    ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained(
        "facebook/dpr-ctx_encoder-single-nq-base"
    )
    # Encode abstracts in small batches to keep peak memory low
    texts = [p["text"] for p in papers]
    embeddings = []
    batch_size = 4  # Small batches for memory-constrained hardware
    with torch.inference_mode():
        for i in range(0, len(texts), batch_size):
            batch_texts = texts[i:i + batch_size]
            inputs = ctx_tokenizer(
                batch_texts,
                max_length=256,  # Shorter than the 512 default to save memory
                padding=True,
                truncation=True,
                return_tensors="pt",
            ).to(device)
            outputs = ctx_encoder(**inputs)
            embeddings.extend(outputs.pooler_output.cpu().numpy())
            # Release per-batch tensors before the next iteration
            del outputs
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
    # Stack into a single float32 matrix; FAISS requires float32 input
    embeddings = np.array(embeddings, dtype=np.float32)
    dimension = embeddings.shape[1]
    # L2-normalize so inner product equals cosine similarity; clip the norms
    # to avoid a division by zero on any degenerate all-zero row
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    embeddings = embeddings / np.clip(norms, 1e-12, None)
    # Build an exact inner-product index over the normalized embeddings
    index = faiss.IndexFlatIP(dimension)
    index.add(embeddings)
    # Create the dataset pairing each abstract with its embedding and title
    dataset = Dataset.from_dict({
        "text": texts,
        "embeddings": embeddings.tolist(),  # Lists serialize cleanly to disk
        "title": [p["title"] for p in papers],
    })
    # Create the output directory if it doesn't exist
    os.makedirs(dataset_dir, exist_ok=True)
    # Save the dataset, and persist the index too; it was built but never
    # written before ("index.faiss" is a filename chosen here, adjust it to
    # whatever the loading side expects)
    dataset.save_to_disk(os.path.join(dataset_dir, "dataset"))
    faiss.write_index(index, os.path.join(dataset_dir, "index.faiss"))
    logging.info(f"Dataset and FAISS index saved to {dataset_dir}")
    return dataset_dir
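
# Illustrative entry-point sketch: the query string and result count are
# placeholders, not part of the functions above.
if __name__ == "__main__":
    papers = fetch_arxiv_papers("cancer genomics", max_results=10)
    build_faiss_index(papers)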