import logging
import os

import arxiv
import faiss
import numpy as np
import torch
from datasets import Dataset
from transformers import DPRContextEncoder, DPRContextEncoderTokenizer
# Configure logging
logging.basicConfig(level=logging.INFO)
# Define data paths
DATA_DIR = os.getenv("DATA_DIR", "/data" if os.path.exists("/data") else ".")
DATASET_DIR = os.path.join(DATA_DIR, "rag_dataset")
def fetch_arxiv_papers(query, max_results=10):
    """Fetch papers from arXiv and format them for RAG"""
    client = arxiv.Client()
    search = arxiv.Search(
        query=f"{query} AND (cat:q-bio.NC OR cat:q-bio.QM OR cat:q-bio.GN OR cat:q-bio.CB OR cat:q-bio.MN)",  # Restrict to biology and medical categories
        max_results=max_results,
        sort_by=arxiv.SortCriterion.Relevance,  # Relevance-based sorting
    )
    results = list(client.results(search))
    papers = [{"id": str(i), "text": result.summary, "title": result.title} for i, result in enumerate(results)]
    logging.info(f"Fetched {len(papers)} papers from arXiv")
    return papers

def build_faiss_index(papers, dataset_dir=DATASET_DIR):
    """Build and save dataset with FAISS index for RAG"""
    # Load the DPR context encoder in 8-bit to reduce memory usage
    ctx_encoder = DPRContextEncoder.from_pretrained(
        "facebook/dpr-ctx_encoder-single-nq-base", device_map="auto", load_in_8bit=True
    )
    ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")

    # Encode abstracts in small batches to keep peak memory low
    texts = [p["text"] for p in papers]
    embeddings = []
    batch_size = 4
    for i in range(0, len(texts), batch_size):
        batch = texts[i:i + batch_size]
        inputs = ctx_tokenizer(batch, max_length=256, padding=True, truncation=True, return_tensors="pt")  # Cap sequence length to save memory
        inputs = {k: v.to(ctx_encoder.device) for k, v in inputs.items()}  # Match the encoder's device
        with torch.no_grad():
            outputs = ctx_encoder(**inputs)
        batch_embeddings = outputs.pooler_output.cpu().numpy()
        embeddings.append(batch_embeddings)
        del outputs  # Explicit cleanup
        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # Clear cached GPU memory between batches
    embeddings = np.vstack(embeddings)
    logging.info(f"Created embeddings with shape {embeddings.shape}")
    # Create dataset
    dataset = Dataset.from_dict({
        "id": [p["id"] for p in papers],
        "text": [p["text"] for p in papers],
        "title": [p["title"] for p in papers],
        "embeddings": [emb.tolist() for emb in embeddings],
    })
    # Create FAISS index: an exact (flat) L2 index, which needs no training
    # and is appropriate for a corpus this small
    dimension = embeddings.shape[1]
    index = faiss.IndexFlatL2(dimension)
    index.add(embeddings.astype(np.float32))
    # Save dataset and index
    os.makedirs(dataset_dir, exist_ok=True)
    dataset_path = os.path.join(dataset_dir, "dataset")
    index_path = os.path.join(dataset_dir, "embeddings.faiss")
    dataset.save_to_disk(dataset_path)
    faiss.write_index(index, index_path)
    logging.info(f"Saved dataset to {dataset_path}")
    logging.info(f"Saved index to {index_path}")
    return dataset_dir
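
# Minimal usage sketch, assuming network access for the arXiv API and the
# model download; the "CRISPR gene editing" query string is illustrative only.
if __name__ == "__main__":
    papers = fetch_arxiv_papers("CRISPR gene editing", max_results=10)
    out_dir = build_faiss_index(papers)

    # Quick sanity check: reload the saved artifacts and query the index with
    # the first stored embedding (its nearest neighbour should be itself)
    from datasets import load_from_disk
    saved = load_from_disk(os.path.join(out_dir, "dataset"))
    index = faiss.read_index(os.path.join(out_dir, "embeddings.faiss"))
    query_vec = np.array(saved[0]["embeddings"], dtype=np.float32).reshape(1, -1)
    distances, ids = index.search(query_vec, 3)
    logging.info(f"Nearest papers to '{saved[0]['title']}': {ids[0].tolist()}")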