wakeupmh committed
Commit 42d1dd5 · 1 Parent(s): 8bb473c

fix: improve mem usage

Files changed (3):
  1. app.py (+20 −19)
  2. faiss_index/index.py (+40 −30)
  3. requirements.txt (+10 −8)
app.py CHANGED

@@ -22,17 +22,19 @@ def load_models():
     model = AutoModelForSeq2SeqLM.from_pretrained(
         model_name,
         torch_dtype=torch.float16,
-        low_cpu_mem_usage=True
+        low_cpu_mem_usage=True,
+        device_map='auto',
+        max_memory={'cpu': '1GB'}
     )
     return tokenizer, model
 
-@st.cache_data
+@st.cache_data(ttl=3600)  # Cache for 1 hour
 def load_dataset(query):
     # Create initial dataset if it doesn't exist
     if not os.path.exists(DATASET_PATH):
         with st.spinner("Building initial dataset from autism research papers..."):
             import faiss_index.index as idx
-            papers = idx.fetch_arxiv_papers(f"{query} AND (cat:q-bio.NC OR cat:q-bio.QM OR cat:q-bio.GN OR cat:q-bio.CB OR cat:q-bio.MN)", max_results=50)  # More focused search
+            papers = idx.fetch_arxiv_papers(f"{query} AND (cat:q-bio.NC OR cat:q-bio.QM OR cat:q-bio.GN OR cat:q-bio.CB OR cat:q-bio.MN)", max_results=25)  # Reduced max results
             idx.build_faiss_index(papers, dataset_dir=DATASET_DIR)
 
     # Load and convert to pandas for easier handling
@@ -43,32 +45,31 @@ def load_dataset(query):
     })
     return df
 
-def generate_answer(question, context, max_length=200):
+def generate_answer(question, context, max_length=150):  # Reduced max length
     tokenizer, model = load_models()
 
     # Add context about medical information
     prompt = f"Based on scientific research about autism and health: question: {question} context: {context}"
 
-    inputs = tokenizer(
-        prompt,
-        add_special_tokens=True,
-        return_tensors="pt",
-        max_length=512,
-        truncation=True,
-        padding=True
-    )
+    # Optimize input processing
+    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
 
-    # Get model predictions
-    with torch.no_grad():
+    with torch.inference_mode():  # More efficient than no_grad
         outputs = model.generate(
-            inputs["input_ids"],
+            **inputs,
             max_length=max_length,
-            min_length=30,
-            num_beams=4,
-            length_penalty=2.0,
+            num_beams=2,  # Reduced beam search
+            temperature=0.7,
+            top_p=0.9,
+            repetition_penalty=1.2,
             early_stopping=True
         )
-    answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+    answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+    # Clear GPU memory if possible
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
 
     return answer if answer and not answer.isspace() else "I cannot find a specific answer to this question in the provided context."
 
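Note (not part of this commit): in transformers, temperature and top_p only take effect when sampling is enabled; with num_beams=2 and the default do_sample=False, generate() ignores them (newer versions emit a warning about unused sampling flags). A minimal sketch of the same call with sampling explicitly switched on:

    # Hedged sketch: do_sample=True is required for temperature/top_p to apply
    outputs = model.generate(
        **inputs,
        max_length=max_length,
        do_sample=True,
        num_beams=2,
        temperature=0.7,
        top_p=0.9,
        repetition_penalty=1.2,
        early_stopping=True
    )

Alternatively, dropping num_beams entirely (greedy sampling) would save further memory, since beam search keeps multiple candidate sequences alive.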
faiss_index/index.py CHANGED

@@ -30,7 +30,11 @@ def fetch_arxiv_papers(query, max_results=10):
 def build_faiss_index(papers, dataset_dir=DATASET_DIR):
     """Build and save dataset with FAISS index for RAG"""
     # Initialize smaller DPR encoder
-    ctx_encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base", device_map="auto", load_in_8bit=True)
+    ctx_encoder = DPRContextEncoder.from_pretrained(
+        "facebook/dpr-ctx_encoder-single-nq-base",
+        torch_dtype=torch.float16,
+        low_cpu_mem_usage=True
+    )
     ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
 
     # Create embeddings with smaller batches and memory optimization
@@ -38,40 +42,46 @@ def build_faiss_index(papers, dataset_dir=DATASET_DIR):
     embeddings = []
     batch_size = 4  # Smaller batch size
 
-    for i in range(0, len(texts), batch_size):
-        batch = texts[i:i + batch_size]
-        inputs = ctx_tokenizer(batch, max_length=256, padding=True, truncation=True, return_tensors="pt")  # Reduced max_length
-        with torch.no_grad():
+    with torch.inference_mode():
+        for i in range(0, len(texts), batch_size):
+            batch_texts = texts[i:i + batch_size]
+            inputs = ctx_tokenizer(
+                batch_texts,
+                max_length=256,  # Reduced from default
+                padding=True,
+                truncation=True,
+                return_tensors="pt"
+            )
             outputs = ctx_encoder(**inputs)
-        batch_embeddings = outputs.pooler_output.cpu().numpy()
-        embeddings.append(batch_embeddings)
-        del outputs  # Explicit cleanup
-        torch.cuda.empty_cache()  # Clear GPU memory
+            embeddings.extend(outputs.pooler_output.cpu().numpy())
+
+            # Clear memory
+            del outputs
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
 
-    embeddings = np.vstack(embeddings)
-    logging.info(f"Created embeddings with shape {embeddings.shape}")
+    # Convert to numpy array and build FAISS index
+    embeddings = np.array(embeddings)
+    dimension = embeddings.shape[1]
+
+    # Use more efficient index type
+    index = faiss.IndexFlatIP(dimension)  # Simple but efficient dot-product index
 
-    # Create dataset
+    # Normalize vectors to use dot product as similarity
+    faiss.normalize_L2(embeddings)
+    index.add(embeddings)
+
+    # Create and save the dataset
     dataset = Dataset.from_dict({
-        "id": [p["id"] for p in papers],
-        "text": [p["text"] for p in papers],
-        "title": [p["title"] for p in papers],
-        "embeddings": [emb.tolist() for emb in embeddings],
+        "text": texts,
+        "embeddings": embeddings,
+        "title": [p["title"] for p in papers]
     })
 
-    # Create FAISS index
-    dimension = embeddings.shape[1]
-    quantizer = faiss.IndexFlatL2(dimension)
-    index = faiss.IndexQuantizer(dimension, quantizer, 8)
-    index.train(embeddings.astype(np.float32))
-    index.add(embeddings.astype(np.float32))
-
-    # Save dataset and index
+    # Create directory if it doesn't exist
     os.makedirs(dataset_dir, exist_ok=True)
-    dataset_path = os.path.join(dataset_dir, "dataset")
-    index_path = os.path.join(dataset_dir, "embeddings.faiss")
-    dataset.save_to_disk(dataset_path)
-    faiss.write_index(index, index_path)
-    logging.info(f"Saved dataset to {dataset_path}")
-    logging.info(f"Saved index to {index_path}")
+
+    # Save dataset
+    dataset.save_to_disk(os.path.join(dataset_dir, "dataset"))
+    logging.info(f"Dataset saved to {dataset_dir}")
     return dataset_dir
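
Note (not part of this commit): faiss only accepts float32 arrays, and with the encoder loaded in torch.float16 the pooled embeddings come back as float16, so np.array(embeddings) needs an explicit cast before normalize_L2/add. The rewrite also drops the old faiss.write_index call, so the index now exists only in memory and is lost when the process exits. A minimal sketch of both fixes, reusing index and dataset_dir from the hunk above:

    # Hedged sketch: cast to float32 for faiss, and persist the index again
    embeddings = np.array(embeddings, dtype=np.float32)
    faiss.normalize_L2(embeddings)  # after this, inner product == cosine similarity
    index.add(embeddings)
    faiss.write_index(index, os.path.join(dataset_dir, "embeddings.faiss"))

At query time the query vector must be normalized the same way, so that IndexFlatIP's inner product really behaves as cosine similarity:

    # q: float32 array of shape (1, dimension) from the DPR question encoder
    faiss.normalize_L2(q)
    scores, ids = index.search(q, 5)  # top-5 most similar papers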
requirements.txt CHANGED

@@ -1,10 +1,12 @@
-streamlit
-transformers
-datasets
-sentence-transformers
-faiss-cpu
-arxiv
+streamlit>=1.32.0
+transformers>=4.37.0
+datasets>=2.17.0
+sentence-transformers>=2.3.1
+faiss-cpu>=1.7.4
+arxiv>=2.1.0
 --extra-index-url https://download.pytorch.org/whl/cpu
-torch
+torch>=2.2.0
 accelerate>=0.26.0
 bitsandbytes>=0.41.1
+numpy>=1.24.0
+pandas>=2.2.0
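
Note: pip reads the --extra-index-url directive from the requirements file itself, so the CPU-only torch wheel is picked up by a plain install:

    pip install -r requirements.txt

bitsandbytes stays pinned even though the 8-bit DPR load was dropped in this commit, so it could likely be removed in a follow-up.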