wakeupmh committed
Commit 4660a83 · 1 Parent(s): cb9a068

test: agno

Files changed (3):
  1. _old_app.py +203 -0
  2. app.py +134 -72
  3. requirements.txt +4 -5
_old_app.py ADDED
@@ -0,0 +1,203 @@
+ import streamlit as st
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
+ import os
+ from datasets import load_from_disk, Dataset
+ import torch
+ import logging
+ import pandas as pd
+ import arxiv
+ import requests
+ import xml.etree.ElementTree as ET
+ from agno.embedder.huggingface import HuggingfaceCustomEmbedder
+ from agno.vectordb.lancedb import LanceDb, SearchType
+
+ # Configure logging
+ logging.basicConfig(level=logging.INFO)
+
+ # Define data paths and constants
+ DATA_DIR = "/data" if os.path.exists("/data") else "."
+ DATASET_DIR = os.path.join(DATA_DIR, "rag_dataset")
+ DATASET_PATH = os.path.join(DATASET_DIR, "dataset")
+ MODEL_PATH = "google/flan-t5-base"  # Lighter model
+
+ @st.cache_resource
+ def load_local_model():
+     """Load the local Hugging Face model"""
+     tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
+     model = AutoModelForSeq2SeqLM.from_pretrained(
+         MODEL_PATH,
+         torch_dtype=torch.float32,  # Using float32 for CPU compatibility
+         device_map="auto"
+     )
+     return model, tokenizer
+
+ def fetch_arxiv_papers(query, max_results=5):
+     """Fetch papers from arXiv"""
+     client = arxiv.Client()
+
+     # Clean and prepare the search query
+     search_query = f"ti:{query} OR abs:{query} AND cat:q-bio"
+
+     # Search arXiv
+     search = arxiv.Search(
+         query=search_query,
+         max_results=max_results,
+         sort_by=arxiv.SortCriterion.Relevance
+     )
+
+     papers = []
+     for result in client.results(search):
+         papers.append({
+             'title': result.title,
+             'abstract': result.summary,
+             'url': result.pdf_url,
+             'published': result.published
+         })
+
+     return papers
+
+ def fetch_pubmed_papers(query, max_results=5):
+     """Fetch papers from PubMed"""
+     base_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils"
+
+     # Search for papers
+     search_url = f"{base_url}/esearch.fcgi"
+     search_params = {
+         'db': 'pubmed',
+         'term': query,
+         'retmax': max_results,
+         'sort': 'relevance',
+         'retmode': 'xml'
+     }
+
+     papers = []
+     try:
+         # Get paper IDs
+         response = requests.get(search_url, params=search_params)
+         root = ET.fromstring(response.content)
+         id_list = [id_elem.text for id_elem in root.findall('.//Id')]
+
+         if not id_list:
+             return papers
+
+         # Fetch paper details
+         fetch_url = f"{base_url}/efetch.fcgi"
+         fetch_params = {
+             'db': 'pubmed',
+             'id': ','.join(id_list),
+             'retmode': 'xml'
+         }
+
+         response = requests.get(fetch_url, params=fetch_params)
+         articles = ET.fromstring(response.content)
+
+         for article in articles.findall('.//PubmedArticle'):
+             title = article.find('.//ArticleTitle')
+             abstract = article.find('.//Abstract/AbstractText')
+
+             papers.append({
+                 'title': title.text if title is not None else 'No title available',
+                 'abstract': abstract.text if abstract is not None else 'No abstract available',
+                 'url': f"https://pubmed.ncbi.nlm.nih.gov/{article.find('.//PMID').text}/",
+                 'published': article.find('.//PubDate/Year').text if article.find('.//PubDate/Year') is not None else 'Unknown'
+             })
+
+     except Exception as e:
+         st.error(f"Error fetching PubMed papers: {str(e)}")
+
+     return papers
+
+ def search_research_papers(query):
+     """Search both arXiv and PubMed for papers"""
+     arxiv_papers = fetch_arxiv_papers(query)
+     pubmed_papers = fetch_pubmed_papers(query)
+
+     # Combine and format papers
+     all_papers = []
+     for paper in arxiv_papers + pubmed_papers:
+         all_papers.append({
+             'title': paper['title'],
+             'text': f"Title: {paper['title']}\nAbstract: {paper['abstract']}",
+             'url': paper['url'],
+             'published': paper['published']
+         })
+
+     return pd.DataFrame(all_papers)
+
+ def generate_answer(question, context, max_length=512):
+     """Generate a comprehensive answer using the local model"""
+     model, tokenizer = load_local_model()
+
+     # Format the context as a structured query
+     prompt = f"""Based on the following research papers about autism, provide a detailed answer:
+
+ Question: {question}
+
+ Research Context:
+ {context}
+
+ Please analyze:
+ 1. Main findings
+ 2. Research methods
+ 3. Clinical implications
+ 4. Limitations
+
+ Answer:"""
+
+     # Generate response
+     inputs = tokenizer(prompt, return_tensors="pt", max_length=max_length, truncation=True)
+
+     with torch.inference_mode():
+         outputs = model.generate(
+             **inputs,
+             max_length=max_length,
+             num_beams=4,
+             temperature=0.7,
+             top_p=0.9,
+             repetition_penalty=1.2,
+             early_stopping=True
+         )
+
+     response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+     # Format the response
+     formatted_response = response.replace(". ", ".\n").replace("• ", "\n• ")
+
+     return formatted_response
+
+ # Streamlit App
+ st.title("🧩 AMA Autism")
+ st.write("This app searches through scientific papers to answer your questions about autism. For best results, be specific in your questions.")
+ query = st.text_input("Please ask me anything about autism ✨")
+
+ if query:
+     with st.status("Searching for answers...") as status:
+         # Search for papers
+         df = search_research_papers(query)
+         st.write("Searching for data in PubMed and arXiv...")
+         st.write("Data found!")
+
+         # Get relevant context
+         context = "\n".join([
+             f"{text[:1000]}" for text in df['text'].head(3)
+         ])
+
+         # Generate answer
+         answer = generate_answer(query, context)
+         st.write("Generating answer...")
+         status.update(
+             label="Search complete!", state="complete", expanded=False
+         )
+         if answer and not answer.isspace():
+             st.success("Answer found!")
+             st.write(answer)
+
+             st.write("### Sources used:")
+             for _, row in df.head(3).iterrows():
+                 st.markdown(f"**[{row['title']}]({row['url']})** ({row['published']})")
+                 st.write(f"**Summary:** {row['text'][:200]}...")
+                 st.write("---")
+         else:
+             st.warning("I couldn't find a specific answer in the research papers. Try rephrasing your question.")
+         if df.empty:
+             st.warning("I couldn't find any relevant research papers about this topic. Please try rephrasing your question or ask something else about autism.")
app.py CHANGED
@@ -1,109 +1,169 @@
  import streamlit as st
- from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
  import os
  from datasets import load_from_disk, Dataset
  import torch
  import logging
  import pandas as pd
+ import arxiv
+ import requests
+ import xml.etree.ElementTree as ET
+ from agno.embedder.huggingface import HuggingfaceCustomEmbedder
+ from agno.vectordb.lancedb import LanceDb, SearchType

  # Configure logging
  logging.basicConfig(level=logging.INFO)

- # Define data paths
+ # Define data paths and constants
  DATA_DIR = "/data" if os.path.exists("/data") else "."
  DATASET_DIR = os.path.join(DATA_DIR, "rag_dataset")
  DATASET_PATH = os.path.join(DATASET_DIR, "dataset")
+ MODEL_PATH = "google/flan-t5-base"  # Lighter model

- # Cache models and dataset
  @st.cache_resource
- def load_models():
-     model_name = "google/flan-t5-small"  # Lighter model
-     tokenizer = AutoTokenizer.from_pretrained(model_name)
+ def load_local_model():
+     """Load the local Hugging Face model"""
+     tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
      model = AutoModelForSeq2SeqLM.from_pretrained(
-         model_name,
-         torch_dtype=torch.float16,
-         low_cpu_mem_usage=True,
-         device_map='auto',
-         max_memory={'cpu': '1GB'}
+         MODEL_PATH,
+         torch_dtype=torch.float32,  # Using float32 for CPU compatibility
+         device_map="auto"
      )
-     return tokenizer, model
+     return model, tokenizer

- @st.cache_data(ttl=3600)  # Cache for 1 hour
- def load_dataset(query):
-     # Always fetch fresh results for the specific query
-     with st.spinner("Searching research papers from arXiv and PubMed..."):
-         import faiss_index.index as idx
-         # Ensure both autism and the query terms are included
-         if 'autism' not in query.lower():
-             search_query = f"autism {query}"
-         else:
-             search_query = query
-
-         papers = idx.fetch_papers(search_query, max_results=25)  # This now fetches from both sources
+ def fetch_arxiv_papers(query, max_results=5):
+     """Fetch papers from arXiv"""
+     client = arxiv.Client()

-         if not papers:
-             st.warning("No relevant papers found. Please try rephrasing your question.")
-             return pd.DataFrame(columns=['title', 'text', 'url', 'published'])
-
-         idx.build_faiss_index(papers, dataset_dir=DATASET_DIR)
+     # Clean and prepare the search query
+     search_query = f"ti:{query} OR abs:{query} AND cat:q-bio"
+
+     # Search arXiv
+     search = arxiv.Search(
+         query=search_query,
+         max_results=max_results,
+         sort_by=arxiv.SortCriterion.Relevance
+     )
+
+     papers = []
+     for result in client.results(search):
+         papers.append({
+             'title': result.title,
+             'abstract': result.summary,
+             'url': result.pdf_url,
+             'published': result.published
+         })

-     # Load and convert to pandas for easier handling
-     dataset = load_from_disk(DATASET_PATH)
-     df = pd.DataFrame({
-         'title': dataset['title'],
-         'text': dataset['text'],
-         'url': [p['url'] for p in papers],
-         'published': [p['published'] for p in papers]
-     })
-     return df
+     return papers

- def generate_answer(question, context, max_length=300):
-     tokenizer, model = load_models()
+ def fetch_pubmed_papers(query, max_results=5):
+     """Fetch papers from PubMed"""
+     base_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils"
+
+     # Search for papers
+     search_url = f"{base_url}/esearch.fcgi"
+     search_params = {
+         'db': 'pubmed',
+         'term': query,
+         'retmax': max_results,
+         'sort': 'relevance',
+         'retmode': 'xml'
+     }
+
+     papers = []
+     try:
+         # Get paper IDs
+         response = requests.get(search_url, params=search_params)
+         root = ET.fromstring(response.content)
+         id_list = [id_elem.text for id_elem in root.findall('.//Id')]
+
+         if not id_list:
+             return papers
+
+         # Fetch paper details
+         fetch_url = f"{base_url}/efetch.fcgi"
+         fetch_params = {
+             'db': 'pubmed',
+             'id': ','.join(id_list),
+             'retmode': 'xml'
+         }
+
+         response = requests.get(fetch_url, params=fetch_params)
+         articles = ET.fromstring(response.content)
+
+         for article in articles.findall('.//PubmedArticle'):
+             title = article.find('.//ArticleTitle')
+             abstract = article.find('.//Abstract/AbstractText')
+
+             papers.append({
+                 'title': title.text if title is not None else 'No title available',
+                 'abstract': abstract.text if abstract is not None else 'No abstract available',
+                 'url': f"https://pubmed.ncbi.nlm.nih.gov/{article.find('.//PMID').text}/",
+                 'published': article.find('.//PubDate/Year').text if article.find('.//PubDate/Year') is not None else 'Unknown'
+             })
+
+     except Exception as e:
+         st.error(f"Error fetching PubMed papers: {str(e)}")

-     # Enhanced prompt for more detailed and structured answers
-     prompt = f"""Based on scientific research about autism, provide a comprehensive and structured summary answering the following question.
-     Include the following aspects when relevant:
-     1. Main findings and conclusions
-     2. Supporting evidence or research methods
-     3. Clinical implications or practical applications
-     4. Any limitations or areas needing further research
+     return papers
+
+ def search_research_papers(query):
+     """Search both arXiv and PubMed for papers"""
+     arxiv_papers = fetch_arxiv_papers(query)
+     pubmed_papers = fetch_pubmed_papers(query)

-     Use clear headings and bullet points when appropriate to organize the information.
-     If the context doesn't contain relevant information about autism, respond with 'I cannot find specific information about this topic in the autism research papers.'
+     # Combine and format papers
+     all_papers = []
+     for paper in arxiv_papers + pubmed_papers:
+         all_papers.append({
+             'title': paper['title'],
+             'text': f"Title: {paper['title']}\nAbstract: {paper['abstract']}",
+             'url': paper['url'],
+             'published': paper['published']
+         })

-     Question: {question}
-     Context: {context}
+     return pd.DataFrame(all_papers)
+
+ def generate_answer(question, context, max_length=512):
+     """Generate a comprehensive answer using the local model"""
+     model, tokenizer = load_local_model()

-     Detailed summary:"""
+     # Format the context as a structured query
+     prompt = f"""Based on the following research papers about autism, provide a detailed answer:
+
+ Question: {question}
+
+ Research Context:
+ {context}
+
+ Please analyze:
+ 1. Main findings
+ 2. Research methods
+ 3. Clinical implications
+ 4. Limitations
+
+ Answer:"""

-     # Optimize input processing
-     inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=768)
+     # Generate response
+     inputs = tokenizer(prompt, return_tensors="pt", max_length=max_length, truncation=True)

      with torch.inference_mode():
          outputs = model.generate(
              **inputs,
              max_length=max_length,
-             num_beams=4,
-             temperature=0.8,
+             num_beams=4,
+             temperature=0.7,
              top_p=0.9,
-             repetition_penalty=1.3,
-             length_penalty=1.2,
+             repetition_penalty=1.2,
              early_stopping=True
          )

-     answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-     # Clear GPU memory if possible
-     if torch.cuda.is_available():
-         torch.cuda.empty_cache()
+     response = tokenizer.decode(outputs[0], skip_special_tokens=True)

-     # Enhanced answer validation and formatting
-     if not answer or answer.isspace() or "cannot find" in answer.lower():
-         return "I cannot find specific information about this topic in the autism research papers."
+     # Format the response
+     formatted_response = response.replace(". ", ".\n").replace("• ", "\n• ")

-     # Format the answer with proper line breaks and structure
-     formatted_answer = answer.replace(". ", ".\n").replace("• ", "\n• ")
-     return formatted_answer
+     return formatted_response

  # Streamlit App
  st.title("🧩 AMA Autism")
@@ -112,14 +172,16 @@ query = st.text_input("Please ask me anything about autism ✨")

  if query:
      with st.status("Searching for answers...") as status:
-         # Load dataset
-         df = load_dataset(query)
+         # Search for papers
+         df = search_research_papers(query)
          st.write("Searching for data in PubMed and arXiv...")
+         st.write("Data found!")
+
          # Get relevant context
          context = "\n".join([
              f"{text[:1000]}" for text in df['text'].head(3)
          ])
-         st.write("Data found!")
+
          # Generate answer
          answer = generate_answer(query, context)
          st.write("Generating answer...")
requirements.txt CHANGED
@@ -1,13 +1,12 @@
  streamlit>=1.32.0
  transformers>=4.37.0
  datasets>=2.17.0
- sentence-transformers>=2.3.1
- faiss-cpu>=1.7.4
- arxiv>=2.1.0
  --extra-index-url https://download.pytorch.org/whl/cpu
  torch>=2.2.0
  accelerate>=0.26.0
- bitsandbytes>=0.41.1
  numpy>=1.24.0
  pandas>=2.2.0
- requests>=2.31.0
+ requests>=2.31.0
+ arxiv>=2.1.0
+ lancedb>=0.3.3
+ tantivy>=0.19.2
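
In requirements.txt, faiss-cpu and sentence-transformers give way to lancedb plus tantivy; tantivy is the Rust full-text engine behind LanceDB's full-text index, which is presumably why the two are pinned together. A minimal sketch of that pairing, assuming the lancedb Python API of this era (connect/create_table/create_fts_index; the exact search call varies across lancedb versions, and the path and sample row below are illustrative):

    import lancedb

    # Sketch only: connect to a local LanceDB directory and index text.
    db = lancedb.connect("/data/lancedb")
    table = db.create_table(
        "papers",
        data=[{"text": "Title: ...\nAbstract: ...", "url": "https://example.org"}],
    )
    table.create_fts_index("text")  # full-text index, backed by tantivy
    hits = table.search("autism intervention").limit(3).to_pandas()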