wakeupmh commited on
Commit
cc41495
·
1 Parent(s): 54a5022

fix: search

Browse files
Files changed (2) hide show
  1. app.py +7 -2
  2. faiss_index/index.py +30 -21
app.py CHANGED
@@ -33,8 +33,12 @@ def load_dataset(query):
33
  # Always fetch fresh results for the specific query
34
  with st.spinner("Searching autism research papers..."):
35
  import faiss_index.index as idx
36
- # Make the query more specific to autism and b12
37
- search_query = f"{query} AND (cat:q-bio.NC OR cat:q-bio.QM OR cat:q-bio.GN OR cat:q-bio.CB OR cat:q-bio.MN)"
 
 
 
 
38
  papers = idx.fetch_arxiv_papers(search_query, max_results=25)
39
  if not papers:
40
  st.warning("No relevant papers found. Please try rephrasing your question.")
@@ -88,6 +92,7 @@ def generate_answer(question, context, max_length=150):
88
 
89
  # Streamlit App
90
  st.title("🧩 AMA Autism")
 
91
  query = st.text_input("Please ask me anything about autism ✨")
92
 
93
  if query:
 
33
  # Always fetch fresh results for the specific query
34
  with st.spinner("Searching autism research papers..."):
35
  import faiss_index.index as idx
36
+ # Ensure both autism and the query terms are included
37
+ if 'autism' not in query.lower():
38
+ search_query = f"autism {query}"
39
+ else:
40
+ search_query = query
41
+
42
  papers = idx.fetch_arxiv_papers(search_query, max_results=25)
43
  if not papers:
44
  st.warning("No relevant papers found. Please try rephrasing your question.")
 
92
 
93
  # Streamlit App
94
  st.title("🧩 AMA Autism")
95
+ st.write("This app searches through scientific papers to answer your questions about autism. For best results, be specific in your questions.")
96
  query = st.text_input("Please ask me anything about autism ✨")
97
 
98
  if query:
faiss_index/index.py CHANGED
@@ -18,34 +18,43 @@ def fetch_arxiv_papers(query, max_results=10):
18
  """Fetch papers from arXiv and format them for RAG"""
19
  client = arxiv.Client()
20
 
21
- # Construct a more focused search query
22
- search_terms = query.lower().split()
23
- if 'autism' not in search_terms:
24
- search_terms.insert(0, 'autism')
25
 
26
- # Add specific category filters for medical and biological papers
27
- search_query = f"({' AND '.join(search_terms)}) AND (cat:q-bio.NC OR cat:q-bio.QM OR cat:q-bio.GN OR cat:q-bio.CB OR cat:q-bio.MN)"
 
 
 
28
 
29
  search = arxiv.Search(
30
  query=search_query,
31
- max_results=max_results,
32
  sort_by=arxiv.SortCriterion.Relevance
33
  )
34
 
35
- results = list(client.results(search))
36
- papers = []
37
-
38
- # Filter results to ensure they're relevant to autism
39
- for i, result in enumerate(results):
40
- if 'autism' in result.title.lower() or 'autism' in result.summary.lower():
41
- papers.append({
42
- "id": str(i),
43
- "text": result.summary,
44
- "title": result.title
45
- })
46
-
47
- logging.info(f"Fetched {len(papers)} relevant papers from arXiv")
48
- return papers
 
 
 
 
 
 
 
49
 
50
  def build_faiss_index(papers, dataset_dir=DATASET_DIR):
51
  """Build and save dataset with FAISS index for RAG"""
 
18
  """Fetch papers from arXiv and format them for RAG"""
19
  client = arxiv.Client()
20
 
21
+ # Clean and prepare the search query
22
+ query = query.replace('and', '').strip() # Remove 'and' as it's treated as AND operator
23
+ terms = [term.strip() for term in query.split() if term.strip()]
 
24
 
25
+ # Create a more flexible search query
26
+ search_query = ' OR '.join([f'abs:"{term}" OR ti:"{term}"' for term in terms])
27
+ search_query = f'({search_query}) AND (cat:q-bio* OR cat:med*)'
28
+
29
+ logging.info(f"Searching arXiv with query: {search_query}")
30
 
31
  search = arxiv.Search(
32
  query=search_query,
33
+ max_results=max_results * 2, # Get more results to filter
34
  sort_by=arxiv.SortCriterion.Relevance
35
  )
36
 
37
+ try:
38
+ results = list(client.results(search))
39
+ papers = []
40
+
41
+ for i, result in enumerate(results):
42
+ # Include paper if it contains any of the search terms
43
+ text = (result.title + " " + result.summary).lower()
44
+ if any(term.lower() in text for term in terms):
45
+ papers.append({
46
+ "id": str(i),
47
+ "text": result.summary,
48
+ "title": result.title
49
+ })
50
+ if len(papers) >= max_results:
51
+ break
52
+
53
+ logging.info(f"Found {len(papers)} relevant papers from arXiv")
54
+ return papers
55
+ except Exception as e:
56
+ logging.error(f"Error fetching papers from arXiv: {str(e)}")
57
+ return []
58
 
59
  def build_faiss_index(papers, dataset_dir=DATASET_DIR):
60
  """Build and save dataset with FAISS index for RAG"""