wakeupmh committed
Commit f99a008 · Parent: 8903db2

fix: performance

Files changed (3):
  1. app.py +12 -9
  2. faiss_index/index.py +13 -8
  3. requirements.txt +1 -1
app.py CHANGED
@@ -17,33 +17,36 @@ DATASET_PATH = os.path.join(DATASET_DIR, "dataset")
 # Cache models and dataset
 @st.cache_resource
 def load_models():
-    model_name = "t5-base"
-    tokenizer = AutoTokenizer.from_pretrained(model_name)
-    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
+    model_name = "google/flan-t5-small" # Lighter model
+    tokenizer = AutoTokenizer.from_pretrained(model_name, device_map="auto")
+    model = AutoModelForSeq2SeqLM.from_pretrained(model_name, device_map="auto", load_in_8bit=True)
     return tokenizer, model
 
 @st.cache_data
-def load_dataset():
+def load_dataset(query):
     # Create initial dataset if it doesn't exist
     if not os.path.exists(DATASET_PATH):
         with st.spinner("Building initial dataset from autism research papers..."):
             import faiss_index.index as idx
-            papers = idx.fetch_arxiv_papers("autism research", max_results=100)
+            papers = idx.fetch_arxiv_papers(f"{query} AND (cat:q-bio.NC OR cat:q-bio.QM OR cat:q-bio.GN OR cat:q-bio.CB OR cat:q-bio.MN)", max_results=50) # More focused search
             idx.build_faiss_index(papers, dataset_dir=DATASET_DIR)
 
     # Load and convert to pandas for easier handling
     dataset = load_from_disk(DATASET_PATH)
-    return pd.DataFrame({
+    df = pd.DataFrame({
         'title': dataset['title'],
         'text': dataset['text']
     })
+    return df
 
 def generate_answer(question, context, max_length=200):
     tokenizer, model = load_models()
 
-    # Encode the question and context
+    # Add context about medical information
+    prompt = f"Based on scientific research about autism and health: question: {question} context: {context}"
+
     inputs = tokenizer(
-        f"question: {question} context: {context}",
+        prompt,
         add_special_tokens=True,
         return_tensors="pt",
         max_length=512,
@@ -72,7 +75,7 @@ query = st.text_input("Please ask me anything about autism ✨")
 if query:
     with st.status("Searching for answers..."):
         # Load dataset
-        df = load_dataset()
+        df = load_dataset(query)
 
         # Get relevant context
         context = "\n".join([
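
Note (not part of this commit): load_in_8bit relies on the bitsandbytes package and generally expects a CUDA device, while the requirements.txt change below pins the CPU-only torch wheel; device_map="auto" is also not meaningful for a tokenizer. If 8-bit loading fails on a CPU-only install, a plain CPU load is a safe fallback. A minimal sketch under those assumptions:

    # Hypothetical CPU-only fallback for load_models(); not part of this commit.
    import streamlit as st
    from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

    @st.cache_resource
    def load_models():
        model_name = "google/flan-t5-small"
        # Tokenizers are not placed on devices, so device_map is omitted here.
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Skip load_in_8bit on CPU-only installs; bitsandbytes targets GPUs.
        model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        return tokenizer, model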
faiss_index/index.py CHANGED
@@ -18,9 +18,9 @@ def fetch_arxiv_papers(query, max_results=10):
     """Fetch papers from arXiv and format them for RAG"""
     client = arxiv.Client()
     search = arxiv.Search(
-        query=query,
+        query=f"{query} AND (cat:q-bio.NC OR cat:q-bio.QM OR cat:q-bio.GN OR cat:q-bio.CB OR cat:q-bio.MN)", # Focus on biology and medical categories
         max_results=max_results,
-        sort_by=arxiv.SortCriterion.SubmittedDate
+        sort_by=arxiv.SortCriterion.Relevance # Changed to relevance-based sorting
     )
     results = list(client.results(search))
     papers = [{"id": str(i), "text": result.summary, "title": result.title} for i, result in enumerate(results)]
@@ -29,21 +29,24 @@ def fetch_arxiv_papers(query, max_results=10):
 
 def build_faiss_index(papers, dataset_dir=DATASET_DIR):
     """Build and save dataset with FAISS index for RAG"""
-    # Initialize DPR encoder
-    ctx_encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
+    # Initialize smaller DPR encoder
+    ctx_encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base", device_map="auto", load_in_8bit=True)
     ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
 
-    # Create embeddings
+    # Create embeddings with smaller batches and memory optimization
     texts = [p["text"] for p in papers]
     embeddings = []
-    batch_size = 8
+    batch_size = 4 # Smaller batch size
+
     for i in range(0, len(texts), batch_size):
         batch = texts[i:i + batch_size]
-        inputs = ctx_tokenizer(batch, max_length=512, padding=True, truncation=True, return_tensors="pt")
+        inputs = ctx_tokenizer(batch, max_length=256, padding=True, truncation=True, return_tensors="pt") # Reduced max_length
         with torch.no_grad():
             outputs = ctx_encoder(**inputs)
         batch_embeddings = outputs.pooler_output.cpu().numpy()
         embeddings.append(batch_embeddings)
+        del outputs # Explicit cleanup
+        torch.cuda.empty_cache() # Clear GPU memory
 
     embeddings = np.vstack(embeddings)
     logging.info(f"Created embeddings with shape {embeddings.shape}")
@@ -58,7 +61,9 @@ def build_faiss_index(papers, dataset_dir=DATASET_DIR):
 
     # Create FAISS index
     dimension = embeddings.shape[1]
-    index = faiss.IndexFlatL2(dimension)
+    quantizer = faiss.IndexFlatL2(dimension)
+    index = faiss.IndexQuantizer(dimension, quantizer, 8)
+    index.train(embeddings.astype(np.float32))
     index.add(embeddings.astype(np.float32))
 
     # Save dataset and index
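
Note (not part of this commit): if the installed faiss-cpu build does not expose faiss.IndexQuantizer with this constructor, a commonly used trained index that pairs an IndexFlatL2 coarse quantizer with clustering is IndexIVFFlat. A minimal sketch under that assumption (build_quantized_index is a hypothetical helper, not in the repo):

    # Hypothetical alternative for the quantized-index step; not part of this commit.
    import faiss
    import numpy as np

    def build_quantized_index(embeddings: np.ndarray) -> faiss.Index:
        dimension = embeddings.shape[1]
        nlist = min(8, len(embeddings))           # few coarse clusters for a small corpus
        quantizer = faiss.IndexFlatL2(dimension)  # same L2 metric as the original index
        index = faiss.IndexIVFFlat(quantizer, dimension, nlist)
        index.train(embeddings.astype(np.float32))  # IVF indexes must be trained before add()
        index.add(embeddings.astype(np.float32))
        return index

With only ~50 abstracts a plain IndexFlatL2 would also stay small and needs no training, and torch.cuda.empty_cache() is a no-op when CUDA has never been initialized, so the cleanup lines are harmless on CPU.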
requirements.txt CHANGED
@@ -4,5 +4,5 @@ datasets
 sentence-transformers
 faiss-cpu
 arxiv
-torch
+torch --index-url https://download.pytorch.org/whl/cpu
 accelerate>=0.26.0
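
Note (not part of this commit): pip's requirements format supports --index-url and --extra-index-url as standalone lines rather than as per-requirement options, so the combined torch line above may be rejected by some pip versions. A hedged variant that keeps PyPI available for the remaining packages:

    # Hypothetical requirements.txt fragment; not part of this commit.
    --extra-index-url https://download.pytorch.org/whl/cpu
    torch

--extra-index-url adds the CPU wheel index alongside PyPI, so packages that are not hosted on the PyTorch index still resolve normally.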