fix: shearch
- app.py +32 -23
- faiss_index/index.py +24 -4
app.py
CHANGED
@@ -30,12 +30,13 @@ def load_models():
 
 @st.cache_data(ttl=3600)  # Cache for 1 hour
 def load_dataset(query):
-    #
-
-
-
-
-
+    # Always fetch fresh results for the specific query
+    with st.spinner("Searching autism research papers..."):
+        import faiss_index.index as idx
+        # Make the query more specific to autism and b12
+        search_query = f"autism {query} AND (cat:q-bio.NC OR cat:q-bio.QM OR cat:q-bio.GN OR cat:q-bio.CB OR cat:q-bio.MN)"
+        papers = idx.fetch_arxiv_papers(search_query, max_results=25)
+        idx.build_faiss_index(papers, dataset_dir=DATASET_DIR)
 
     # Load and convert to pandas for easier handling
     dataset = load_from_disk(DATASET_PATH)
@@ -45,20 +46,24 @@ def load_dataset(query):
     })
     return df
 
-def generate_answer(question, context, max_length=150):
+def generate_answer(question, context, max_length=150):
     tokenizer, model = load_models()
 
-    #
-    prompt = f"Based on scientific research about autism
+    # Improve prompt to focus on autism-related information
+    prompt = f"""Based on scientific research about autism, answer the following question.
+    If the context doesn't contain relevant information about autism, respond with 'I cannot find specific information about this topic in the autism research papers.'
+
+    Question: {question}
+    Context: {context}"""
 
     # Optimize input processing
     inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
 
-    with torch.inference_mode():
+    with torch.inference_mode():
         outputs = model.generate(
             **inputs,
             max_length=max_length,
-            num_beams=2,
+            num_beams=2,
             temperature=0.7,
             top_p=0.9,
             repetition_penalty=1.2,
@@ -71,7 +76,11 @@ def generate_answer(question, context, max_length=150): # Reduced max length
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
 
-
+    # Additional validation of the answer
+    if not answer or answer.isspace() or "cannot find" in answer.lower():
+        return "I cannot find specific information about this topic in the autism research papers."
+
+    return answer
 
 # Streamlit App
 st.title("🧩 AMA Autism")
@@ -90,14 +99,14 @@ if query:
         # Generate answer
         answer = generate_answer(query, context)
 
-
-
-
-
-
-
-
-
-
-
-
+        if answer and not answer.isspace():
+            st.success("Answer found!")
+            st.write(answer)
+
+            st.write("### Sources Used:")
+            for _, row in df.head(3).iterrows():
+                st.write(f"**Title:** {row['title']}")
+                st.write(f"**Summary:** {row['text'][:200]}...")
+                st.write("---")
+        else:
+            st.warning("I couldn't find a specific answer in the research papers. Try rephrasing your question.")
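For reference, a minimal sketch of what the new search_query f-string in load_dataset() expands to; the sample query text is hypothetical, not taken from the commit:

    # Hypothetical example of the query string load_dataset() now builds
    # before calling fetch_arxiv_papers()
    query = "b12 deficiency"  # sample user input, for illustration only
    search_query = (
        f"autism {query} AND (cat:q-bio.NC OR cat:q-bio.QM OR "
        f"cat:q-bio.GN OR cat:q-bio.CB OR cat:q-bio.MN)"
    )
    print(search_query)
    # autism b12 deficiency AND (cat:q-bio.NC OR cat:q-bio.QM OR cat:q-bio.GN OR cat:q-bio.CB OR cat:q-bio.MN)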
faiss_index/index.py
CHANGED
@@ -17,14 +17,34 @@ DATASET_DIR = os.path.join(DATA_DIR, "rag_dataset")
 def fetch_arxiv_papers(query, max_results=10):
     """Fetch papers from arXiv and format them for RAG"""
     client = arxiv.Client()
+
+    # Construct a more focused search query
+    search_terms = query.lower().split()
+    if 'autism' not in search_terms:
+        search_terms.insert(0, 'autism')
+
+    # Add specific category filters for medical and biological papers
+    search_query = f"({' AND '.join(search_terms)}) AND (cat:q-bio.NC OR cat:q-bio.QM OR cat:q-bio.GN OR cat:q-bio.CB OR cat:q-bio.MN)"
+
     search = arxiv.Search(
-        query=
+        query=search_query,
         max_results=max_results,
-        sort_by=arxiv.SortCriterion.Relevance
+        sort_by=arxiv.SortCriterion.Relevance
     )
+
     results = list(client.results(search))
-    papers = [
-
+    papers = []
+
+    # Filter results to ensure they're relevant to autism
+    for i, result in enumerate(results):
+        if 'autism' in result.title.lower() or 'autism' in result.summary.lower():
+            papers.append({
+                "id": str(i),
+                "text": result.summary,
+                "title": result.title
+            })
+
+    logging.info(f"Fetched {len(papers)} relevant papers from arXiv")
     return papers
 
 def build_faiss_index(papers, dataset_dir=DATASET_DIR):
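A rough standalone usage sketch for the updated fetch_arxiv_papers()/build_faiss_index() pair, assuming the arxiv package is installed and the module is importable from the Space's root; the sample query and the max_results value are hypothetical, with signatures and result keys taken from the diff above:

    # Hypothetical standalone usage; not part of the commit
    import logging
    import faiss_index.index as idx

    logging.basicConfig(level=logging.INFO)

    # 'autism' is prepended inside fetch_arxiv_papers when the query lacks it,
    # and results whose title/summary never mention autism are filtered out
    papers = idx.fetch_arxiv_papers("b12 deficiency", max_results=25)
    for paper in papers[:3]:
        print(paper["id"], paper["title"])

    # Builds and saves the FAISS index/dataset under the default DATASET_DIR
    idx.build_faiss_index(papers)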