wakeupmh committed
Commit 7133da4 · 1 Parent(s): 158b752

refactor: using smol

Files changed (5)
  1. .DS_Store +0 -0
  2. _old_app.py +0 -203
  3. app.py +13 -14
  4. faiss_index/__init__.py +0 -1
  5. faiss_index/index.py +0 -232
.DS_Store ADDED
Binary file (6.15 kB).
 
_old_app.py DELETED
@@ -1,203 +0,0 @@
-import streamlit as st
-from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
-import os
-from datasets import load_from_disk, Dataset
-import torch
-import logging
-import pandas as pd
-import arxiv
-import requests
-import xml.etree.ElementTree as ET
-from agno.embedder.huggingface import HuggingfaceCustomEmbedder
-from agno.vectordb.lancedb import LanceDb, SearchType
-
-# Configure logging
-logging.basicConfig(level=logging.INFO)
-
-# Define data paths and constants
-DATA_DIR = "/data" if os.path.exists("/data") else "."
-DATASET_DIR = os.path.join(DATA_DIR, "rag_dataset")
-DATASET_PATH = os.path.join(DATASET_DIR, "dataset")
-MODEL_PATH = "google/flan-t5-base" # Lighter model
-
-@st.cache_resource
-def load_local_model():
-    """Load the local Hugging Face model"""
-    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
-    model = AutoModelForSeq2SeqLM.from_pretrained(
-        MODEL_PATH,
-        torch_dtype=torch.float32, # Using float32 for CPU compatibility
-        device_map="auto"
-    )
-    return model, tokenizer
-
-def fetch_arxiv_papers(query, max_results=5):
-    """Fetch papers from arXiv"""
-    client = arxiv.Client()
-
-    # Clean and prepare the search query
-    search_query = f"ti:{query} OR abs:{query} AND cat:q-bio"
-
-    # Search arXiv
-    search = arxiv.Search(
-        query=search_query,
-        max_results=max_results,
-        sort_by=arxiv.SortCriterion.Relevance
-    )
-
-    papers = []
-    for result in client.results(search):
-        papers.append({
-            'title': result.title,
-            'abstract': result.summary,
-            'url': result.pdf_url,
-            'published': result.published
-        })
-
-    return papers
-
-def fetch_pubmed_papers(query, max_results=5):
-    """Fetch papers from PubMed"""
-    base_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils"
-
-    # Search for papers
-    search_url = f"{base_url}/esearch.fcgi"
-    search_params = {
-        'db': 'pubmed',
-        'term': query,
-        'retmax': max_results,
-        'sort': 'relevance',
-        'retmode': 'xml'
-    }
-
-    papers = []
-    try:
-        # Get paper IDs
-        response = requests.get(search_url, params=search_params)
-        root = ET.fromstring(response.content)
-        id_list = [id_elem.text for id_elem in root.findall('.//Id')]
-
-        if not id_list:
-            return papers
-
-        # Fetch paper details
-        fetch_url = f"{base_url}/efetch.fcgi"
-        fetch_params = {
-            'db': 'pubmed',
-            'id': ','.join(id_list),
-            'retmode': 'xml'
-        }
-
-        response = requests.get(fetch_url, params=fetch_params)
-        articles = ET.fromstring(response.content)
-
-        for article in articles.findall('.//PubmedArticle'):
-            title = article.find('.//ArticleTitle')
-            abstract = article.find('.//Abstract/AbstractText')
-
-            papers.append({
-                'title': title.text if title is not None else 'No title available',
-                'abstract': abstract.text if abstract is not None else 'No abstract available',
-                'url': f"https://pubmed.ncbi.nlm.nih.gov/{article.find('.//PMID').text}/",
-                'published': article.find('.//PubDate/Year').text if article.find('.//PubDate/Year') is not None else 'Unknown'
-            })
-
-    except Exception as e:
-        st.error(f"Error fetching PubMed papers: {str(e)}")
-
-    return papers
-
-def search_research_papers(query):
-    """Search both arXiv and PubMed for papers"""
-    arxiv_papers = fetch_arxiv_papers(query)
-    pubmed_papers = fetch_pubmed_papers(query)
-
-    # Combine and format papers
-    all_papers = []
-    for paper in arxiv_papers + pubmed_papers:
-        all_papers.append({
-            'title': paper['title'],
-            'text': f"Title: {paper['title']}\nAbstract: {paper['abstract']}",
-            'url': paper['url'],
-            'published': paper['published']
-        })
-
-    return pd.DataFrame(all_papers)
-
-def generate_answer(question, context, max_length=512):
-    """Generate a comprehensive answer using the local model"""
-    model, tokenizer = load_local_model()
-
-    # Format the context as a structured query
-    prompt = f"""Based on the following research papers about autism, provide a detailed answer:
-
-Question: {question}
-
-Research Context:
-{context}
-
-Please analyze:
-1. Main findings
-2. Research methods
-3. Clinical implications
-4. Limitations
-
-Answer:"""
-
-    # Generate response
-    inputs = tokenizer(prompt, return_tensors="pt", max_length=max_length, truncation=True)
-
-    with torch.inference_mode():
-        outputs = model.generate(
-            **inputs,
-            max_length=max_length,
-            num_beams=4,
-            temperature=0.7,
-            top_p=0.9,
-            repetition_penalty=1.2,
-            early_stopping=True
-        )
-
-    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-    # Format the response
-    formatted_response = response.replace(". ", ".\n").replace("• ", "\n• ")
-
-    return formatted_response
-
-# Streamlit App
-st.title("🧩 AMA Autism")
-st.write("This app searches through scientific papers to answer your questions about autism. For best results, be specific in your questions.")
-query = st.text_input("Please ask me anything about autism ✨")
-
-if query:
-    with st.status("Searching for answers...") as status:
-        # Search for papers
-        df = search_research_papers(query)
-        st.write("Searching for data in PubMed and arXiv...")
-        st.write("Data found!")
-
-        # Get relevant context
-        context = "\n".join([
-            f"{text[:1000]}" for text in df['text'].head(3)
-        ])
-
-        # Generate answer
-        answer = generate_answer(query, context)
-        st.write("Generating answer...")
-        status.update(
-            label="Search complete!", state="complete", expanded=False
-        )
-    if answer and not answer.isspace():
-        st.success("Answer found!")
-        st.write(answer)
-
-        st.write("### Sources used:")
-        for _, row in df.head(3).iterrows():
-            st.markdown(f"**[{row['title']}]({row['url']})** ({row['published']})")
-            st.write(f"**Summary:** {row['text'][:200]}...")
-            st.write("---")
-    else:
-        st.warning("I couldn't find a specific answer in the research papers. Try rephrasing your question.")
-    if df.empty:
-        st.warning("I couldn't find any relevant research papers about this topic. Please try rephrasing your question or ask something else about autism.")
 
app.py CHANGED
@@ -17,7 +17,7 @@ DATA_DIR = "/data" if os.path.exists("/data") else "."
 DATASET_DIR = os.path.join(DATA_DIR, "rag_dataset")
 DATASET_PATH = os.path.join(DATASET_DIR, "dataset")
 TOKENIZER_MODEL = "google/flan-t5-small"
-SUMMARIZATION_MODEL= "Falconsai/text_summarization"
+SUMMARIZATION_MODEL= "HuggingFaceTB/SmolVLM-256M-Instruct"
 # SUMMARIZATION_MODEL="rhaymison/t5-portuguese-small-summarization"
 
 @st.cache_resource
@@ -211,26 +211,25 @@ def generate_answer(question, context, max_length=512):
 
     # Format the input for T5 (it expects a specific format)
     input_text = f"""Objective:
-Generate a clear, informative, and well-structured answer about autism, making the content easy to understand for a general audience. Use the provided research papers to support your explanations.
+Provide a clear, simple, and well-structured answer about autism that is easy to understand for a general audience. Use the provided research papers as references.
 
 Question: {clean_question}
-
 Research Papers:
 {clean_context}
 
 Instructions:
-
-Start with a simple explanation
-- Clearly define what autism is in an easy-to-understand way, avoiding overly complex technical terms.
+Start with a simple definition
+- Explain what autism is in a short and clear way, avoiding technical terms.
 - Use real-life examples
-- Whenever possible, include practical examples to illustrate key concepts.
-- Relates research in an accessible way
-- Instead of just referencing papers, explain their findings in a way that anyone can understand. Example: "A study from X University found that..."
-- Avoid scientific jargon
-- If a technical term is necessary, provide a simple explanation.
-- Organize the response into sections
-- Use lists and short paragraphs to improve readability.
-Write your answer in a friendly and accessible tone, ensuring that anyone, regardless of their background knowledge, can understand the information provided."""
+- Give practical and relatable examples to help illustrate key points.
+- Explain research in simple words
+- Instead of just citing studies, summarize their key findings in a way that anyone can understand. Example: "A study from X University found that..."
+- Avoid complex words
+- If a scientific term is needed, provide a short and simple explanation.
+- Use clear formatting
+- Write in short paragraphs, bullet points, or numbered lists to improve readability.
+- Keep a friendly tone
+- Make the response engaging and easy to follow, so people without prior knowledge can understand."""
 
     try:
         # T5 expects a specific format for the input
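The substantive change in app.py is the SUMMARIZATION_MODEL swap from Falconsai/text_summarization to HuggingFaceTB/SmolVLM-256M-Instruct. SmolVLM-256M-Instruct is an instruction-tuned vision-language checkpoint rather than a T5-style seq2seq summarizer, so it is normally loaded through a processor and a vision-to-sequence model class rather than AutoModelForSeq2SeqLM. A minimal sketch of that loading path follows; it is not part of this commit and assumes a recent transformers release with Vision2Seq support and that a text-only chat prompt is acceptable, which may behave differently from a dedicated summarization model.

# Sketch only: typical loading path for HuggingFaceTB/SmolVLM-256M-Instruct.
# Not part of this commit; assumes a recent `transformers` with Vision2Seq support.
import torch
from transformers import AutoProcessor, AutoModelForVision2Seq

MODEL_ID = "HuggingFaceTB/SmolVLM-256M-Instruct"

processor = AutoProcessor.from_pretrained(MODEL_ID)
model = AutoModelForVision2Seq.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float32,  # float32 for CPU; bfloat16 on supported GPUs
)

# Text-only chat prompt (no image passed); whether this fits the app's
# summarization use is an assumption, since the checkpoint is multimodal.
messages = [
    {"role": "user", "content": [{"type": "text", "text": "Summarize: ..."}]}
]
prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
inputs = processor(text=prompt, return_tensors="pt")

with torch.inference_mode():
    generated = model.generate(**inputs, max_new_tokens=200)
print(processor.batch_decode(generated, skip_special_tokens=True)[0])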
faiss_index/__init__.py DELETED
@@ -1 +0,0 @@
-# This file makes the faiss_index directory a Python package
 
faiss_index/index.py DELETED
@@ -1,232 +0,0 @@
-import numpy as np
-import faiss
-import arxiv
-from datasets import Dataset
-import os
-from transformers import DPRContextEncoder, DPRContextEncoderTokenizer
-import torch
-import logging
-import requests
-from datetime import datetime
-import xml.etree.ElementTree as ET
-from time import sleep
-
-# Configure logging
-logging.basicConfig(level=logging.INFO)
-
-# Define data paths
-DATA_DIR = os.getenv("DATA_DIR", "/data" if os.path.exists("/data") else ".")
-DATASET_DIR = os.path.join(DATA_DIR, "rag_dataset")
-
-def fetch_arxiv_papers(query, max_results=10):
-    """Fetch papers from arXiv and format them for RAG"""
-    client = arxiv.Client()
-
-    # Clean and prepare the search query
-    query = query.replace('and', '').strip()
-    terms = [term.strip() for term in query.split() if term.strip()]
-
-    # Always include autism in the search
-    if 'autism' not in [t.lower() for t in terms]:
-        terms.insert(0, 'autism')
-
-    # Create search query with required autism term
-    term_queries = []
-    for term in terms:
-        if term.lower() != "autism":
-            term_queries.append(f'abs:"{term}" OR ti:"{term}"')
-
-    search_query = '(abs:"autism" OR ti:"autism")'
-    if term_queries:
-        search_query += f' AND ({" OR ".join(term_queries)})'
-    search_query = f'({search_query}) AND (cat:q-bio* OR cat:med*)'
-
-    logging.info(f"Searching arXiv with query: {search_query}")
-
-    search = arxiv.Search(
-        query=search_query,
-        max_results=max_results * 2,
-        sort_by=arxiv.SortCriterion.Relevance
-    )
-
-    try:
-        results = list(client.results(search))
-        papers = []
-
-        for i, result in enumerate(results):
-            # Only include papers that mention autism
-            text = (result.title + " " + result.summary).lower()
-            if 'autism' in text:
-                papers.append({
-                    "id": f"arxiv_{i}",
-                    "text": result.summary,
-                    "title": result.title,
-                    "url": result.entry_id,
-                    "published": result.published.strftime("%Y-%m-%d")
-                })
-                if len(papers) >= max_results:
-                    break
-
-        logging.info(f"Found {len(papers)} relevant papers about autism from arXiv")
-        return papers
-    except Exception as e:
-        logging.error(f"Error fetching papers from arXiv: {str(e)}")
-        return []
-
-def fetch_pubmed_papers(query, max_results=10):
-    """Fetch papers from PubMed using E-utilities"""
-    base_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils"
-
-    # Search for papers
-    search_url = f"{base_url}/esearch.fcgi"
-    search_params = {
-        'db': 'pubmed',
-        'term': f"{query} AND autism",
-        'retmax': max_results,
-        'sort': 'relevance',
-        'retmode': 'xml'
-    }
-
-    try:
-        # Get paper IDs
-        response = requests.get(search_url, params=search_params)
-        root = ET.fromstring(response.content)
-        id_list = [id_elem.text for id_elem in root.findall('.//Id')]
-
-        if not id_list:
-            return []
-
-        # Fetch paper details
-        fetch_url = f"{base_url}/efetch.fcgi"
-        fetch_params = {
-            'db': 'pubmed',
-            'id': ','.join(id_list),
-            'retmode': 'xml'
-        }
-
-        response = requests.get(fetch_url, params=fetch_params)
-        root = ET.fromstring(response.content)
-        papers = []
-
-        for article in root.findall('.//PubmedArticle'):
-            try:
-                # Extract article information
-                title = article.find('.//ArticleTitle').text
-                abstract = article.find('.//Abstract/AbstractText')
-                abstract = abstract.text if abstract is not None else ""
-
-                if 'autism' in (title + abstract).lower():
-                    pmid = article.find('.//PMID').text
-                    date = article.find('.//PubDate')
-                    year = date.find('Year').text if date.find('Year') is not None else "Unknown"
-
-                    papers.append({
-                        "id": f"pubmed_{pmid}",
-                        "text": abstract,
-                        "title": title,
-                        "url": f"https://pubmed.ncbi.nlm.nih.gov/{pmid}/",
-                        "published": year
-                    })
-            except Exception as e:
-                logging.warning(f"Error processing PubMed article: {str(e)}")
-                continue
-
-        logging.info(f"Found {len(papers)} relevant papers from PubMed")
-        return papers
-
-    except Exception as e:
-        logging.error(f"Error fetching papers from PubMed: {str(e)}")
-        return []
-
-def fetch_papers(query, max_results=10):
-    """Fetch papers from both arXiv and PubMed"""
-    arxiv_papers = fetch_arxiv_papers(query, max_results=max_results)
-    sleep(1) # Respect rate limits
-    pubmed_papers = fetch_pubmed_papers(query, max_results=max_results)
-
-    # Combine and deduplicate papers based on title similarity
-    all_papers = arxiv_papers + pubmed_papers
-    unique_papers = []
-    seen_titles = set()
-
-    for paper in all_papers:
-        title_lower = paper['title'].lower()
-        if not any(title_lower in seen_title or seen_title in title_lower for seen_title in seen_titles):
-            unique_papers.append(paper)
-            seen_titles.add(title_lower)
-
-    # Sort by relevance (papers with 'autism' in title first)
-    unique_papers.sort(key=lambda x: 'autism' in x['title'].lower(), reverse=True)
-    return unique_papers[:max_results]
-
-def build_faiss_index(papers, dataset_dir=DATASET_DIR):
-    """Build and save dataset with FAISS index for RAG"""
-    if not papers:
-        logging.warning("No papers found. Creating empty dataset.")
-        # Create an empty dataset with the expected structure
-        dataset = Dataset.from_dict({
-            "text": [],
-            "embeddings": [],
-            "title": []
-        })
-        os.makedirs(dataset_dir, exist_ok=True)
-        dataset.save_to_disk(os.path.join(dataset_dir, "dataset"))
-        return dataset_dir
-
-    # Initialize smaller DPR encoder
-    ctx_encoder = DPRContextEncoder.from_pretrained(
-        "facebook/dpr-ctx_encoder-single-nq-base",
-        torch_dtype=torch.float16,
-        low_cpu_mem_usage=True
-    )
-    ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
-
-    # Create embeddings with smaller batches and memory optimization
-    texts = [p["text"] for p in papers]
-    embeddings = []
-    batch_size = 4 # Smaller batch size
-
-    with torch.inference_mode():
-        for i in range(0, len(texts), batch_size):
-            batch_texts = texts[i:i + batch_size]
-            inputs = ctx_tokenizer(
-                batch_texts,
-                max_length=256, # Reduced from default
-                padding=True,
-                truncation=True,
-                return_tensors="pt"
-            )
-            outputs = ctx_encoder(**inputs)
-            embeddings.extend(outputs.pooler_output.cpu().numpy())
-
-            # Clear memory
-            del outputs
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
-
-    # Convert to numpy array and build FAISS index
-    embeddings = np.array(embeddings, dtype=np.float32) # Ensure float32 type
-    dimension = embeddings.shape[1]
-
-    # Normalize the vectors manually
-    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
-    embeddings = embeddings / norms
-
-    # Create FAISS index
-    index = faiss.IndexFlatIP(dimension)
-    index.add(embeddings)
-
-    # Create and save the dataset
-    dataset = Dataset.from_dict({
-        "text": texts,
-        "embeddings": embeddings.tolist(), # Convert to list for storage
-        "title": [p["title"] for p in papers]
-    })
-
-    # Create directory if it doesn't exist
-    os.makedirs(dataset_dir, exist_ok=True)
-
-    # Save dataset
-    dataset.save_to_disk(os.path.join(dataset_dir, "dataset"))
-    logging.info(f"Dataset saved to {dataset_dir}")
-    return dataset_dir
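For reference, the deleted index.py L2-normalizes the DPR embeddings and adds them to a faiss.IndexFlatIP, so the inner-product search it enables is equivalent to cosine similarity. A minimal standalone sketch of that retrieval pattern, using toy random vectors rather than the project's data:

# Standalone sketch of the normalized inner-product (cosine) search that the
# deleted index.py set up; vectors here are random placeholders, not real embeddings.
import numpy as np
import faiss

rng = np.random.default_rng(0)
doc_embeddings = rng.standard_normal((100, 768)).astype(np.float32)
faiss.normalize_L2(doc_embeddings)  # in-place L2 normalization

index = faiss.IndexFlatIP(doc_embeddings.shape[1])
index.add(doc_embeddings)

query = rng.standard_normal((1, 768)).astype(np.float32)
faiss.normalize_L2(query)
scores, ids = index.search(query, 3)  # inner product == cosine after normalization
print(ids[0], scores[0])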