fix: performance
- app.py +12 -9
- faiss_index/index.py +13 -8
- requirements.txt +1 -1
app.py
CHANGED
@@ -17,33 +17,36 @@ DATASET_PATH = os.path.join(DATASET_DIR, "dataset")
 # Cache models and dataset
 @st.cache_resource
 def load_models():
-    model_name = "t5-
-    tokenizer = AutoTokenizer.from_pretrained(model_name)
-    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
+    model_name = "google/flan-t5-small"  # Lighter model
+    tokenizer = AutoTokenizer.from_pretrained(model_name, device_map="auto")
+    model = AutoModelForSeq2SeqLM.from_pretrained(model_name, device_map="auto", load_in_8bit=True)
     return tokenizer, model

 @st.cache_data
-def load_dataset():
+def load_dataset(query):
     # Create initial dataset if it doesn't exist
     if not os.path.exists(DATASET_PATH):
         with st.spinner("Building initial dataset from autism research papers..."):
             import faiss_index.index as idx
-            papers = idx.fetch_arxiv_papers("
+            papers = idx.fetch_arxiv_papers(f"{query} AND (cat:q-bio.NC OR cat:q-bio.QM OR cat:q-bio.GN OR cat:q-bio.CB OR cat:q-bio.MN)", max_results=50)  # More focused search
             idx.build_faiss_index(papers, dataset_dir=DATASET_DIR)

     # Load and convert to pandas for easier handling
     dataset = load_from_disk(DATASET_PATH)
-
+    df = pd.DataFrame({
         'title': dataset['title'],
         'text': dataset['text']
     })
+    return df

 def generate_answer(question, context, max_length=200):
     tokenizer, model = load_models()

-    #
+    # Add context about medical information
+    prompt = f"Based on scientific research about autism and health: question: {question} context: {context}"
+
     inputs = tokenizer(
-
+        prompt,
         add_special_tokens=True,
         return_tensors="pt",
         max_length=512,

@@ -72,7 +75,7 @@ query = st.text_input("Please ask me anything about autism ✨")
 if query:
     with st.status("Searching for answers..."):
         # Load dataset
-        df = load_dataset()
+        df = load_dataset(query)

         # Get relevant context
         context = "\n".join([
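A note on the new generate_answer path: the hunk above ends inside the tokenizer call, so here is a minimal sketch of how the function can be completed around the new prompt, assuming load_models() returns (tokenizer, model) as shown; the truncation flag and generation/decoding calls below are placeholders, not the Space's actual lines (note also that load_in_8bit=True normally depends on the bitsandbytes package being installed).

    # Sketch only: completion of generate_answer under the assumptions above.
    def generate_answer(question, context, max_length=200):
        tokenizer, model = load_models()

        # Medically framed prompt introduced in this commit
        prompt = f"Based on scientific research about autism and health: question: {question} context: {context}"
        inputs = tokenizer(
            prompt,
            add_special_tokens=True,
            return_tensors="pt",
            max_length=512,
            truncation=True,  # assumed flag: keep long contexts within the 512-token limit
        )
        output_ids = model.generate(**inputs, max_length=max_length)
        return tokenizer.decode(output_ids[0], skip_special_tokens=True)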
faiss_index/index.py
CHANGED
@@ -18,9 +18,9 @@ def fetch_arxiv_papers(query, max_results=10):
     """Fetch papers from arXiv and format them for RAG"""
     client = arxiv.Client()
     search = arxiv.Search(
-        query=query,
+        query=f"{query} AND (cat:q-bio.NC OR cat:q-bio.QM OR cat:q-bio.GN OR cat:q-bio.CB OR cat:q-bio.MN)",  # Focus on biology and medical categories
         max_results=max_results,
-        sort_by=arxiv.SortCriterion.
+        sort_by=arxiv.SortCriterion.Relevance  # Changed to relevance-based sorting
     )
     results = list(client.results(search))
     papers = [{"id": str(i), "text": result.summary, "title": result.title} for i, result in enumerate(results)]

@@ -29,21 +29,24 @@ def fetch_arxiv_papers(query, max_results=10):

 def build_faiss_index(papers, dataset_dir=DATASET_DIR):
     """Build and save dataset with FAISS index for RAG"""
-    # Initialize DPR encoder
-    ctx_encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
+    # Initialize smaller DPR encoder
+    ctx_encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base", device_map="auto", load_in_8bit=True)
     ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")

-    # Create embeddings
+    # Create embeddings with smaller batches and memory optimization
     texts = [p["text"] for p in papers]
     embeddings = []
-    batch_size =
+    batch_size = 4  # Smaller batch size
+
     for i in range(0, len(texts), batch_size):
         batch = texts[i:i + batch_size]
-        inputs = ctx_tokenizer(batch, max_length=
+        inputs = ctx_tokenizer(batch, max_length=256, padding=True, truncation=True, return_tensors="pt")  # Reduced max_length
         with torch.no_grad():
             outputs = ctx_encoder(**inputs)
         batch_embeddings = outputs.pooler_output.cpu().numpy()
         embeddings.append(batch_embeddings)
+        del outputs  # Explicit cleanup
+        torch.cuda.empty_cache()  # Clear GPU memory

     embeddings = np.vstack(embeddings)
     logging.info(f"Created embeddings with shape {embeddings.shape}")

@@ -58,7 +61,9 @@ def build_faiss_index(papers, dataset_dir=DATASET_DIR):

     # Create FAISS index
     dimension = embeddings.shape[1]
-
+    quantizer = faiss.IndexFlatL2(dimension)
+    index = faiss.IndexQuantizer(dimension, quantizer, 8)
+    index.train(embeddings.astype(np.float32))
     index.add(embeddings.astype(np.float32))

     # Save dataset and index
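On the new index construction: faiss.IndexQuantizer is not an index class I can find in the faiss Python API, so the train/add calls above may fail at runtime. If the intent is an 8-bit compressed index to cut memory, a minimal sketch using faiss.IndexScalarQuantizer (a standard faiss class for scalar quantization) would look like this; treat the class choice as an assumption about the intent, not the Space's actual code.

    # Sketch only: 8-bit scalar quantization as one way to shrink the index in memory.
    # `embeddings` is the (n, dimension) float array built earlier in build_faiss_index.
    dimension = embeddings.shape[1]
    index = faiss.IndexScalarQuantizer(dimension, faiss.ScalarQuantizer.QT_8bit)
    index.train(embeddings.astype(np.float32))  # scalar quantizers need a training pass
    index.add(embeddings.astype(np.float32))

If the flat quantizer plus train() pattern was instead aimed at an IVF index, faiss.IndexIVFFlat(quantizer, dimension, nlist) follows the same train/add flow.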
requirements.txt
CHANGED
@@ -4,5 +4,5 @@ datasets
 sentence-transformers
 faiss-cpu
 arxiv
-torch
+torch --index-url https://download.pytorch.org/whl/cpu
 accelerate>=0.26.0
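One caveat on the torch line: as far as I know, pip's requirements-file format only accepts --index-url as a standalone, file-level option rather than appended to a package specifier, so the single-line form above may be rejected. If so, a commonly used equivalent that still pulls the CPU-only wheel while leaving PyPI in place for the other packages is:

    --extra-index-url https://download.pytorch.org/whl/cpu
    torch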