wakeupmh commited on
Commit
cc0b0d6
·
1 Parent(s): 92c1c48

fix: search

Browse files
Files changed (2) hide show
  1. app.py +32 -23
  2. faiss_index/index.py +24 -4
app.py CHANGED
@@ -30,12 +30,13 @@ def load_models():
30
 
31
  @st.cache_data(ttl=3600) # Cache for 1 hour
32
  def load_dataset(query):
33
- # Create initial dataset if it doesn't exist
34
- if not os.path.exists(DATASET_PATH):
35
- with st.spinner("Building initial dataset from autism research papers..."):
36
- import faiss_index.index as idx
37
- papers = idx.fetch_arxiv_papers(f"{query} AND (cat:q-bio.NC OR cat:q-bio.QM OR cat:q-bio.GN OR cat:q-bio.CB OR cat:q-bio.MN)", max_results=25) # Reduced max results
38
- idx.build_faiss_index(papers, dataset_dir=DATASET_DIR)
 
39
 
40
  # Load and convert to pandas for easier handling
41
  dataset = load_from_disk(DATASET_PATH)
@@ -45,20 +46,24 @@ def load_dataset(query):
45
  })
46
  return df
47
 
48
- def generate_answer(question, context, max_length=150): # Reduced max length
49
  tokenizer, model = load_models()
50
 
51
- # Add context about medical information
52
- prompt = f"Based on scientific research about autism and health: question: {question} context: {context}"
 
 
 
 
53
 
54
  # Optimize input processing
55
  inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
56
 
57
- with torch.inference_mode(): # More efficient than no_grad
58
  outputs = model.generate(
59
  **inputs,
60
  max_length=max_length,
61
- num_beams=2, # Reduced beam search
62
  temperature=0.7,
63
  top_p=0.9,
64
  repetition_penalty=1.2,
@@ -71,7 +76,11 @@ def generate_answer(question, context, max_length=150): # Reduced max length
71
  if torch.cuda.is_available():
72
  torch.cuda.empty_cache()
73
 
74
- return answer if answer and not answer.isspace() else "I cannot find a specific answer to this question in the provided context."
 
 
 
 
75
 
76
  # Streamlit App
77
  st.title("🧩 AMA Autism")
@@ -90,14 +99,14 @@ if query:
90
  # Generate answer
91
  answer = generate_answer(query, context)
92
 
93
- if answer and not answer.isspace():
94
- st.success("Answer found!")
95
- st.write(answer)
96
-
97
- st.write("### Sources Used:")
98
- for _, row in df.head(3).iterrows():
99
- st.write(f"**Title:** {row['title']}")
100
- st.write(f"**Summary:** {row['text'][:200]}...")
101
- st.write("---")
102
- else:
103
- st.warning("I couldn't find a specific answer in the research papers. Try rephrasing your question.")
 
30
 
31
  @st.cache_data(ttl=3600) # Cache for 1 hour
32
  def load_dataset(query):
33
+ # Always fetch fresh results for the specific query
34
+ with st.spinner("Searching autism research papers..."):
35
+ import faiss_index.index as idx
36
+ # Make the query more specific to autism and b12
37
+ search_query = f"autism {query} AND (cat:q-bio.NC OR cat:q-bio.QM OR cat:q-bio.GN OR cat:q-bio.CB OR cat:q-bio.MN)"
38
+ papers = idx.fetch_arxiv_papers(search_query, max_results=25)
39
+ idx.build_faiss_index(papers, dataset_dir=DATASET_DIR)
40
 
41
  # Load and convert to pandas for easier handling
42
  dataset = load_from_disk(DATASET_PATH)
 
46
  })
47
  return df
48
 
49
+ def generate_answer(question, context, max_length=150):
50
  tokenizer, model = load_models()
51
 
52
+ # Improve prompt to focus on autism-related information
53
+ prompt = f"""Based on scientific research about autism, answer the following question.
54
+ If the context doesn't contain relevant information about autism, respond with 'I cannot find specific information about this topic in the autism research papers.'
55
+
56
+ Question: {question}
57
+ Context: {context}"""
58
 
59
  # Optimize input processing
60
  inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
61
 
62
+ with torch.inference_mode():
63
  outputs = model.generate(
64
  **inputs,
65
  max_length=max_length,
66
+ num_beams=2,
67
  temperature=0.7,
68
  top_p=0.9,
69
  repetition_penalty=1.2,
 
76
  if torch.cuda.is_available():
77
  torch.cuda.empty_cache()
78
 
79
+ # Additional validation of the answer
80
+ if not answer or answer.isspace() or "cannot find" in answer.lower():
81
+ return "I cannot find specific information about this topic in the autism research papers."
82
+
83
+ return answer
84
 
85
  # Streamlit App
86
  st.title("🧩 AMA Autism")
 
99
  # Generate answer
100
  answer = generate_answer(query, context)
101
 
102
+ if answer and not answer.isspace():
103
+ st.success("Answer found!")
104
+ st.write(answer)
105
+
106
+ st.write("### Sources Used:")
107
+ for _, row in df.head(3).iterrows():
108
+ st.write(f"**Title:** {row['title']}")
109
+ st.write(f"**Summary:** {row['text'][:200]}...")
110
+ st.write("---")
111
+ else:
112
+ st.warning("I couldn't find a specific answer in the research papers. Try rephrasing your question.")
faiss_index/index.py CHANGED
@@ -17,14 +17,34 @@ DATASET_DIR = os.path.join(DATA_DIR, "rag_dataset")
17
  def fetch_arxiv_papers(query, max_results=10):
18
  """Fetch papers from arXiv and format them for RAG"""
19
  client = arxiv.Client()
 
 
 
 
 
 
 
 
 
20
  search = arxiv.Search(
21
- query=f"{query} AND (cat:q-bio.NC OR cat:q-bio.QM OR cat:q-bio.GN OR cat:q-bio.CB OR cat:q-bio.MN)", # Focus on biology and medical categories
22
  max_results=max_results,
23
- sort_by=arxiv.SortCriterion.Relevance # Changed to relevance-based sorting
24
  )
 
25
  results = list(client.results(search))
26
- papers = [{"id": str(i), "text": result.summary, "title": result.title} for i, result in enumerate(results)]
27
- logging.info(f"Fetched {len(papers)} papers from arXiv")
 
 
 
 
 
 
 
 
 
 
28
  return papers
29
 
30
  def build_faiss_index(papers, dataset_dir=DATASET_DIR):
 
17
def fetch_arxiv_papers(query, max_results=10):
    """Fetch autism-related papers from arXiv and format them for RAG.

    Args:
        query: Free-text search string; the term 'autism' is prepended to
            the term list if the user did not include it.
        max_results: Maximum number of results to request from arXiv.

    Returns:
        A list of dicts with keys "id" (sequential string index), "text"
        (the paper's summary) and "title". Only results whose title or
        summary mentions 'autism' are kept.
    """
    client = arxiv.Client()

    # Construct a more focused search query: every whitespace-separated
    # term must match (AND), and 'autism' is forced into the term list.
    search_terms = query.lower().split()
    if 'autism' not in search_terms:
        search_terms.insert(0, 'autism')

    # Add specific category filters for medical and biological papers.
    search_query = f"({' AND '.join(search_terms)}) AND (cat:q-bio.NC OR cat:q-bio.QM OR cat:q-bio.GN OR cat:q-bio.CB OR cat:q-bio.MN)"

    search = arxiv.Search(
        query=search_query,
        max_results=max_results,
        sort_by=arxiv.SortCriterion.Relevance,
    )

    results = list(client.results(search))

    # Keep only results that explicitly mention autism, then assign ids
    # AFTER filtering so the id range is sequential and gap-free.
    # (Previously the pre-filter enumerate index was reused, so filtered-out
    # papers left holes in the id sequence handed to the index builder.)
    relevant = (
        r for r in results
        if 'autism' in r.title.lower() or 'autism' in r.summary.lower()
    )
    papers = [
        {"id": str(i), "text": result.summary, "title": result.title}
        for i, result in enumerate(relevant)
    ]

    logging.info(f"Fetched {len(papers)} relevant papers from arXiv")
    return papers
49
 
50
  def build_faiss_index(papers, dataset_dir=DATASET_DIR):