Spaces:
Sleeping
Sleeping
fix: improve mem usage
Browse files- app.py +20 -19
- faiss_index/index.py +40 -30
- requirements.txt +10 -8
app.py
CHANGED
@@ -22,17 +22,19 @@ def load_models():
|
|
22 |
model = AutoModelForSeq2SeqLM.from_pretrained(
|
23 |
model_name,
|
24 |
torch_dtype=torch.float16,
|
25 |
-
low_cpu_mem_usage=True
|
|
|
|
|
26 |
)
|
27 |
return tokenizer, model
|
28 |
|
29 |
-
@st.cache_data
|
30 |
def load_dataset(query):
|
31 |
# Create initial dataset if it doesn't exist
|
32 |
if not os.path.exists(DATASET_PATH):
|
33 |
with st.spinner("Building initial dataset from autism research papers..."):
|
34 |
import faiss_index.index as idx
|
35 |
-
papers = idx.fetch_arxiv_papers(f"{query} AND (cat:q-bio.NC OR cat:q-bio.QM OR cat:q-bio.GN OR cat:q-bio.CB OR cat:q-bio.MN)", max_results=
|
36 |
idx.build_faiss_index(papers, dataset_dir=DATASET_DIR)
|
37 |
|
38 |
# Load and convert to pandas for easier handling
|
@@ -43,32 +45,31 @@ def load_dataset(query):
|
|
43 |
})
|
44 |
return df
|
45 |
|
46 |
-
def generate_answer(question, context, max_length=
|
47 |
tokenizer, model = load_models()
|
48 |
|
49 |
# Add context about medical information
|
50 |
prompt = f"Based on scientific research about autism and health: question: {question} context: {context}"
|
51 |
|
52 |
-
|
53 |
-
|
54 |
-
add_special_tokens=True,
|
55 |
-
return_tensors="pt",
|
56 |
-
max_length=512,
|
57 |
-
truncation=True,
|
58 |
-
padding=True
|
59 |
-
)
|
60 |
|
61 |
-
#
|
62 |
-
with torch.no_grad():
|
63 |
outputs = model.generate(
|
64 |
-
inputs
|
65 |
max_length=max_length,
|
66 |
-
|
67 |
-
|
68 |
-
|
|
|
69 |
early_stopping=True
|
70 |
)
|
71 |
-
|
|
|
|
|
|
|
|
|
|
|
72 |
|
73 |
return answer if answer and not answer.isspace() else "I cannot find a specific answer to this question in the provided context."
|
74 |
|
|
|
22 |
model = AutoModelForSeq2SeqLM.from_pretrained(
|
23 |
model_name,
|
24 |
torch_dtype=torch.float16,
|
25 |
+
low_cpu_mem_usage=True,
|
26 |
+
device_map='auto',
|
27 |
+
max_memory={'cpu': '1GB'}
|
28 |
)
|
29 |
return tokenizer, model
|
30 |
|
31 |
+
@st.cache_data(ttl=3600) # Cache for 1 hour
|
32 |
def load_dataset(query):
|
33 |
# Create initial dataset if it doesn't exist
|
34 |
if not os.path.exists(DATASET_PATH):
|
35 |
with st.spinner("Building initial dataset from autism research papers..."):
|
36 |
import faiss_index.index as idx
|
37 |
+
papers = idx.fetch_arxiv_papers(f"{query} AND (cat:q-bio.NC OR cat:q-bio.QM OR cat:q-bio.GN OR cat:q-bio.CB OR cat:q-bio.MN)", max_results=25) # Reduced max results
|
38 |
idx.build_faiss_index(papers, dataset_dir=DATASET_DIR)
|
39 |
|
40 |
# Load and convert to pandas for easier handling
|
|
|
45 |
})
|
46 |
return df
|
47 |
|
48 |
+
def generate_answer(question, context, max_length=150): # Reduced max length
|
49 |
tokenizer, model = load_models()
|
50 |
|
51 |
# Add context about medical information
|
52 |
prompt = f"Based on scientific research about autism and health: question: {question} context: {context}"
|
53 |
|
54 |
+
# Optimize input processing
|
55 |
+
inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
|
|
|
|
|
|
|
|
|
|
|
|
|
56 |
|
57 |
+
with torch.inference_mode(): # More efficient than no_grad
|
|
|
58 |
outputs = model.generate(
|
59 |
+
**inputs,
|
60 |
max_length=max_length,
|
61 |
+
num_beams=2, # Reduced beam search
|
62 |
+
temperature=0.7,
|
63 |
+
top_p=0.9,
|
64 |
+
repetition_penalty=1.2,
|
65 |
early_stopping=True
|
66 |
)
|
67 |
+
|
68 |
+
answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
69 |
+
|
70 |
+
# Clear GPU memory if possible
|
71 |
+
if torch.cuda.is_available():
|
72 |
+
torch.cuda.empty_cache()
|
73 |
|
74 |
return answer if answer and not answer.isspace() else "I cannot find a specific answer to this question in the provided context."
|
75 |
|
faiss_index/index.py
CHANGED
@@ -30,7 +30,11 @@ def fetch_arxiv_papers(query, max_results=10):
|
|
30 |
def build_faiss_index(papers, dataset_dir=DATASET_DIR):
|
31 |
"""Build and save dataset with FAISS index for RAG"""
|
32 |
# Initialize smaller DPR encoder
|
33 |
-
ctx_encoder = DPRContextEncoder.from_pretrained(
|
|
|
|
|
|
|
|
|
34 |
ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
|
35 |
|
36 |
# Create embeddings with smaller batches and memory optimization
|
@@ -38,40 +42,46 @@ def build_faiss_index(papers, dataset_dir=DATASET_DIR):
|
|
38 |
embeddings = []
|
39 |
batch_size = 4 # Smaller batch size
|
40 |
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
45 |
outputs = ctx_encoder(**inputs)
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
|
51 |
-
|
52 |
-
|
|
|
53 |
|
54 |
-
# Create dataset
|
55 |
dataset = Dataset.from_dict({
|
56 |
-
"
|
57 |
-
"
|
58 |
-
"title": [p["title"] for p in papers]
|
59 |
-
"embeddings": [emb.tolist() for emb in embeddings],
|
60 |
})
|
61 |
|
62 |
-
# Create
|
63 |
-
dimension = embeddings.shape[1]
|
64 |
-
quantizer = faiss.IndexFlatL2(dimension)
|
65 |
-
index = faiss.IndexQuantizer(dimension, quantizer, 8)
|
66 |
-
index.train(embeddings.astype(np.float32))
|
67 |
-
index.add(embeddings.astype(np.float32))
|
68 |
-
|
69 |
-
# Save dataset and index
|
70 |
os.makedirs(dataset_dir, exist_ok=True)
|
71 |
-
|
72 |
-
|
73 |
-
dataset.save_to_disk(
|
74 |
-
|
75 |
-
logging.info(f"Saved dataset to {dataset_path}")
|
76 |
-
logging.info(f"Saved index to {index_path}")
|
77 |
return dataset_dir
|
|
|
30 |
def build_faiss_index(papers, dataset_dir=DATASET_DIR):
|
31 |
"""Build and save dataset with FAISS index for RAG"""
|
32 |
# Initialize smaller DPR encoder
|
33 |
+
ctx_encoder = DPRContextEncoder.from_pretrained(
|
34 |
+
"facebook/dpr-ctx_encoder-single-nq-base",
|
35 |
+
torch_dtype=torch.float16,
|
36 |
+
low_cpu_mem_usage=True
|
37 |
+
)
|
38 |
ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
|
39 |
|
40 |
# Create embeddings with smaller batches and memory optimization
|
|
|
42 |
embeddings = []
|
43 |
batch_size = 4 # Smaller batch size
|
44 |
|
45 |
+
with torch.inference_mode():
|
46 |
+
for i in range(0, len(texts), batch_size):
|
47 |
+
batch_texts = texts[i:i + batch_size]
|
48 |
+
inputs = ctx_tokenizer(
|
49 |
+
batch_texts,
|
50 |
+
max_length=256, # Reduced from default
|
51 |
+
padding=True,
|
52 |
+
truncation=True,
|
53 |
+
return_tensors="pt"
|
54 |
+
)
|
55 |
outputs = ctx_encoder(**inputs)
|
56 |
+
embeddings.extend(outputs.pooler_output.cpu().numpy())
|
57 |
+
|
58 |
+
# Clear memory
|
59 |
+
del outputs
|
60 |
+
if torch.cuda.is_available():
|
61 |
+
torch.cuda.empty_cache()
|
62 |
+
|
63 |
+
# Convert to numpy array and build FAISS index
|
64 |
+
embeddings = np.array(embeddings)
|
65 |
+
dimension = embeddings.shape[1]
|
66 |
+
|
67 |
+
# Use more efficient index type
|
68 |
+
index = faiss.IndexFlatIP(dimension) # Simple but efficient dot-product index
|
69 |
|
70 |
+
# Normalize vectors to use dot product as similarity
|
71 |
+
faiss.normalize_L2(embeddings)
|
72 |
+
index.add(embeddings)
|
73 |
|
74 |
+
# Create and save the dataset
|
75 |
dataset = Dataset.from_dict({
|
76 |
+
"text": texts,
|
77 |
+
"embeddings": embeddings,
|
78 |
+
"title": [p["title"] for p in papers]
|
|
|
79 |
})
|
80 |
|
81 |
+
# Create directory if it doesn't exist
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
82 |
os.makedirs(dataset_dir, exist_ok=True)
|
83 |
+
|
84 |
+
# Save dataset
|
85 |
+
dataset.save_to_disk(os.path.join(dataset_dir, "dataset"))
|
86 |
+
logging.info(f"Dataset saved to {dataset_dir}")
|
|
|
|
|
87 |
return dataset_dir
|
requirements.txt
CHANGED
@@ -1,10 +1,12 @@
|
|
1 |
-
streamlit
|
2 |
-
transformers
|
3 |
-
datasets
|
4 |
-
sentence-transformers
|
5 |
-
faiss-cpu
|
6 |
-
arxiv
|
7 |
--extra-index-url https://download.pytorch.org/whl/cpu
|
8 |
-
torch
|
9 |
accelerate>=0.26.0
|
10 |
-
bitsandbytes>=0.41.1
|
|
|
|
|
|
1 |
+
streamlit>=1.32.0
|
2 |
+
transformers>=4.37.0
|
3 |
+
datasets>=2.17.0
|
4 |
+
sentence-transformers>=2.3.1
|
5 |
+
faiss-cpu>=1.7.4
|
6 |
+
arxiv>=2.1.0
|
7 |
--extra-index-url https://download.pytorch.org/whl/cpu
|
8 |
+
torch>=2.2.0
|
9 |
accelerate>=0.26.0
|
10 |
+
bitsandbytes>=0.41.1
|
11 |
+
numpy>=1.24.0
|
12 |
+
pandas>=2.2.0
|