refactor: using smol
- .DS_Store +0 -0
- _old_app.py +0 -203
- app.py +13 -14
- faiss_index/__init__.py +0 -1
- faiss_index/index.py +0 -232
.DS_Store
ADDED
Binary file (6.15 kB)
_old_app.py
DELETED
@@ -1,203 +0,0 @@
-import streamlit as st
-from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
-import os
-from datasets import load_from_disk, Dataset
-import torch
-import logging
-import pandas as pd
-import arxiv
-import requests
-import xml.etree.ElementTree as ET
-from agno.embedder.huggingface import HuggingfaceCustomEmbedder
-from agno.vectordb.lancedb import LanceDb, SearchType
-
-# Configure logging
-logging.basicConfig(level=logging.INFO)
-
-# Define data paths and constants
-DATA_DIR = "/data" if os.path.exists("/data") else "."
-DATASET_DIR = os.path.join(DATA_DIR, "rag_dataset")
-DATASET_PATH = os.path.join(DATASET_DIR, "dataset")
-MODEL_PATH = "google/flan-t5-base"  # Lighter model
-
-@st.cache_resource
-def load_local_model():
-    """Load the local Hugging Face model"""
-    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
-    model = AutoModelForSeq2SeqLM.from_pretrained(
-        MODEL_PATH,
-        torch_dtype=torch.float32,  # Using float32 for CPU compatibility
-        device_map="auto"
-    )
-    return model, tokenizer
-
-def fetch_arxiv_papers(query, max_results=5):
-    """Fetch papers from arXiv"""
-    client = arxiv.Client()
-
-    # Clean and prepare the search query
-    search_query = f"ti:{query} OR abs:{query} AND cat:q-bio"
-
-    # Search arXiv
-    search = arxiv.Search(
-        query=search_query,
-        max_results=max_results,
-        sort_by=arxiv.SortCriterion.Relevance
-    )
-
-    papers = []
-    for result in client.results(search):
-        papers.append({
-            'title': result.title,
-            'abstract': result.summary,
-            'url': result.pdf_url,
-            'published': result.published
-        })
-
-    return papers
-
-def fetch_pubmed_papers(query, max_results=5):
-    """Fetch papers from PubMed"""
-    base_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils"
-
-    # Search for papers
-    search_url = f"{base_url}/esearch.fcgi"
-    search_params = {
-        'db': 'pubmed',
-        'term': query,
-        'retmax': max_results,
-        'sort': 'relevance',
-        'retmode': 'xml'
-    }
-
-    papers = []
-    try:
-        # Get paper IDs
-        response = requests.get(search_url, params=search_params)
-        root = ET.fromstring(response.content)
-        id_list = [id_elem.text for id_elem in root.findall('.//Id')]
-
-        if not id_list:
-            return papers
-
-        # Fetch paper details
-        fetch_url = f"{base_url}/efetch.fcgi"
-        fetch_params = {
-            'db': 'pubmed',
-            'id': ','.join(id_list),
-            'retmode': 'xml'
-        }
-
-        response = requests.get(fetch_url, params=fetch_params)
-        articles = ET.fromstring(response.content)
-
-        for article in articles.findall('.//PubmedArticle'):
-            title = article.find('.//ArticleTitle')
-            abstract = article.find('.//Abstract/AbstractText')
-
-            papers.append({
-                'title': title.text if title is not None else 'No title available',
-                'abstract': abstract.text if abstract is not None else 'No abstract available',
-                'url': f"https://pubmed.ncbi.nlm.nih.gov/{article.find('.//PMID').text}/",
-                'published': article.find('.//PubDate/Year').text if article.find('.//PubDate/Year') is not None else 'Unknown'
-            })
-
-    except Exception as e:
-        st.error(f"Error fetching PubMed papers: {str(e)}")
-
-    return papers
-
-def search_research_papers(query):
-    """Search both arXiv and PubMed for papers"""
-    arxiv_papers = fetch_arxiv_papers(query)
-    pubmed_papers = fetch_pubmed_papers(query)
-
-    # Combine and format papers
-    all_papers = []
-    for paper in arxiv_papers + pubmed_papers:
-        all_papers.append({
-            'title': paper['title'],
-            'text': f"Title: {paper['title']}\nAbstract: {paper['abstract']}",
-            'url': paper['url'],
-            'published': paper['published']
-        })
-
-    return pd.DataFrame(all_papers)
-
-def generate_answer(question, context, max_length=512):
-    """Generate a comprehensive answer using the local model"""
-    model, tokenizer = load_local_model()
-
-    # Format the context as a structured query
-    prompt = f"""Based on the following research papers about autism, provide a detailed answer:
-
-Question: {question}
-
-Research Context:
-{context}
-
-Please analyze:
-1. Main findings
-2. Research methods
-3. Clinical implications
-4. Limitations
-
-Answer:"""
-
-    # Generate response
-    inputs = tokenizer(prompt, return_tensors="pt", max_length=max_length, truncation=True)
-
-    with torch.inference_mode():
-        outputs = model.generate(
-            **inputs,
-            max_length=max_length,
-            num_beams=4,
-            temperature=0.7,
-            top_p=0.9,
-            repetition_penalty=1.2,
-            early_stopping=True
-        )
-
-    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-    # Format the response
-    formatted_response = response.replace(". ", ".\n").replace("• ", "\n• ")
-
-    return formatted_response
-
-# Streamlit App
-st.title("🧩 AMA Autism")
-st.write("This app searches through scientific papers to answer your questions about autism. For best results, be specific in your questions.")
-query = st.text_input("Please ask me anything about autism ✨")
-
-if query:
-    with st.status("Searching for answers...") as status:
-        # Search for papers
-        df = search_research_papers(query)
-        st.write("Searching for data in PubMed and arXiv...")
-        st.write("Data found!")
-
-        # Get relevant context
-        context = "\n".join([
-            f"{text[:1000]}" for text in df['text'].head(3)
-        ])
-
-        # Generate answer
-        answer = generate_answer(query, context)
-        st.write("Generating answer...")
-        status.update(
-            label="Search complete!", state="complete", expanded=False
-        )
-    if answer and not answer.isspace():
-        st.success("Answer found!")
-        st.write(answer)
-
-        st.write("### Sources used:")
-        for _, row in df.head(3).iterrows():
-            st.markdown(f"**[{row['title']}]({row['url']})** ({row['published']})")
-            st.write(f"**Summary:** {row['text'][:200]}...")
-            st.write("---")
-    else:
-        st.warning("I couldn't find a specific answer in the research papers. Try rephrasing your question.")
-    if df.empty:
-        st.warning("I couldn't find any relevant research papers about this topic. Please try rephrasing your question or ask something else about autism.")
app.py
CHANGED
@@ -17,7 +17,7 @@ DATA_DIR = "/data" if os.path.exists("/data") else "."
DATASET_DIR = os.path.join(DATA_DIR, "rag_dataset")
DATASET_PATH = os.path.join(DATASET_DIR, "dataset")
TOKENIZER_MODEL = "google/flan-t5-small"
-SUMMARIZATION_MODEL= "
+SUMMARIZATION_MODEL= "HuggingFaceTB/SmolVLM-256M-Instruct"
# SUMMARIZATION_MODEL="rhaymison/t5-portuguese-small-summarization"

@st.cache_resource
@@ -211,26 +211,25 @@ def generate_answer(question, context, max_length=512):

    # Format the input for T5 (it expects a specific format)
    input_text = f"""Objective:
-
+Provide a clear, simple, and well-structured answer about autism that is easy to understand for a general audience. Use the provided research papers as references.

Question: {clean_question}
-
Research Papers:
{clean_context}

Instructions:
-
-
-- Clearly define what autism is in an easy-to-understand way, avoiding overly complex technical terms.
+Start with a simple definition
+- Explain what autism is in a short and clear way, avoiding technical terms.
- Use real-life examples
-
-
-- Instead of just
-- Avoid
-- If a
-
-
-
+- Give practical and relatable examples to help illustrate key points.
+- Explain research in simple words
+- Instead of just citing studies, summarize their key findings in a way that anyone can understand. Example: "A study from X University found that..."
+- Avoid complex words
+- If a scientific term is needed, provide a short and simple explanation.
+- Use clear formatting
+- Write in short paragraphs, bullet points, or numbered lists to improve readability.
+- Keep a friendly tone
+- Make the response engaging and easy to follow, so people without prior knowledge can understand."""

    try:
        # T5 expects a specific format for the input
faiss_index/__init__.py
DELETED
@@ -1 +0,0 @@
-# This file makes the faiss_index directory a Python package
faiss_index/index.py
DELETED
@@ -1,232 +0,0 @@
-import numpy as np
-import faiss
-import arxiv
-from datasets import Dataset
-import os
-from transformers import DPRContextEncoder, DPRContextEncoderTokenizer
-import torch
-import logging
-import requests
-from datetime import datetime
-import xml.etree.ElementTree as ET
-from time import sleep
-
-# Configure logging
-logging.basicConfig(level=logging.INFO)
-
-# Define data paths
-DATA_DIR = os.getenv("DATA_DIR", "/data" if os.path.exists("/data") else ".")
-DATASET_DIR = os.path.join(DATA_DIR, "rag_dataset")
-
-def fetch_arxiv_papers(query, max_results=10):
-    """Fetch papers from arXiv and format them for RAG"""
-    client = arxiv.Client()
-
-    # Clean and prepare the search query
-    query = query.replace('and', '').strip()
-    terms = [term.strip() for term in query.split() if term.strip()]
-
-    # Always include autism in the search
-    if 'autism' not in [t.lower() for t in terms]:
-        terms.insert(0, 'autism')
-
-    # Create search query with required autism term
-    term_queries = []
-    for term in terms:
-        if term.lower() != "autism":
-            term_queries.append(f'abs:"{term}" OR ti:"{term}"')
-
-    search_query = '(abs:"autism" OR ti:"autism")'
-    if term_queries:
-        search_query += f' AND ({" OR ".join(term_queries)})'
-    search_query = f'({search_query}) AND (cat:q-bio* OR cat:med*)'
-
-    logging.info(f"Searching arXiv with query: {search_query}")
-
-    search = arxiv.Search(
-        query=search_query,
-        max_results=max_results * 2,
-        sort_by=arxiv.SortCriterion.Relevance
-    )
-
-    try:
-        results = list(client.results(search))
-        papers = []
-
-        for i, result in enumerate(results):
-            # Only include papers that mention autism
-            text = (result.title + " " + result.summary).lower()
-            if 'autism' in text:
-                papers.append({
-                    "id": f"arxiv_{i}",
-                    "text": result.summary,
-                    "title": result.title,
-                    "url": result.entry_id,
-                    "published": result.published.strftime("%Y-%m-%d")
-                })
-            if len(papers) >= max_results:
-                break
-
-        logging.info(f"Found {len(papers)} relevant papers about autism from arXiv")
-        return papers
-    except Exception as e:
-        logging.error(f"Error fetching papers from arXiv: {str(e)}")
-        return []
-
-def fetch_pubmed_papers(query, max_results=10):
-    """Fetch papers from PubMed using E-utilities"""
-    base_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils"
-
-    # Search for papers
-    search_url = f"{base_url}/esearch.fcgi"
-    search_params = {
-        'db': 'pubmed',
-        'term': f"{query} AND autism",
-        'retmax': max_results,
-        'sort': 'relevance',
-        'retmode': 'xml'
-    }
-
-    try:
-        # Get paper IDs
-        response = requests.get(search_url, params=search_params)
-        root = ET.fromstring(response.content)
-        id_list = [id_elem.text for id_elem in root.findall('.//Id')]
-
-        if not id_list:
-            return []
-
-        # Fetch paper details
-        fetch_url = f"{base_url}/efetch.fcgi"
-        fetch_params = {
-            'db': 'pubmed',
-            'id': ','.join(id_list),
-            'retmode': 'xml'
-        }
-
-        response = requests.get(fetch_url, params=fetch_params)
-        root = ET.fromstring(response.content)
-        papers = []
-
-        for article in root.findall('.//PubmedArticle'):
-            try:
-                # Extract article information
-                title = article.find('.//ArticleTitle').text
-                abstract = article.find('.//Abstract/AbstractText')
-                abstract = abstract.text if abstract is not None else ""
-
-                if 'autism' in (title + abstract).lower():
-                    pmid = article.find('.//PMID').text
-                    date = article.find('.//PubDate')
-                    year = date.find('Year').text if date.find('Year') is not None else "Unknown"
-
-                    papers.append({
-                        "id": f"pubmed_{pmid}",
-                        "text": abstract,
-                        "title": title,
-                        "url": f"https://pubmed.ncbi.nlm.nih.gov/{pmid}/",
-                        "published": year
-                    })
-            except Exception as e:
-                logging.warning(f"Error processing PubMed article: {str(e)}")
-                continue
-
-        logging.info(f"Found {len(papers)} relevant papers from PubMed")
-        return papers
-
-    except Exception as e:
-        logging.error(f"Error fetching papers from PubMed: {str(e)}")
-        return []
-
-def fetch_papers(query, max_results=10):
-    """Fetch papers from both arXiv and PubMed"""
-    arxiv_papers = fetch_arxiv_papers(query, max_results=max_results)
-    sleep(1)  # Respect rate limits
-    pubmed_papers = fetch_pubmed_papers(query, max_results=max_results)
-
-    # Combine and deduplicate papers based on title similarity
-    all_papers = arxiv_papers + pubmed_papers
-    unique_papers = []
-    seen_titles = set()
-
-    for paper in all_papers:
-        title_lower = paper['title'].lower()
-        if not any(title_lower in seen_title or seen_title in title_lower for seen_title in seen_titles):
-            unique_papers.append(paper)
-            seen_titles.add(title_lower)
-
-    # Sort by relevance (papers with 'autism' in title first)
-    unique_papers.sort(key=lambda x: 'autism' in x['title'].lower(), reverse=True)
-    return unique_papers[:max_results]
-
-def build_faiss_index(papers, dataset_dir=DATASET_DIR):
-    """Build and save dataset with FAISS index for RAG"""
-    if not papers:
-        logging.warning("No papers found. Creating empty dataset.")
-        # Create an empty dataset with the expected structure
-        dataset = Dataset.from_dict({
-            "text": [],
-            "embeddings": [],
-            "title": []
-        })
-        os.makedirs(dataset_dir, exist_ok=True)
-        dataset.save_to_disk(os.path.join(dataset_dir, "dataset"))
-        return dataset_dir
-
-    # Initialize smaller DPR encoder
-    ctx_encoder = DPRContextEncoder.from_pretrained(
-        "facebook/dpr-ctx_encoder-single-nq-base",
-        torch_dtype=torch.float16,
-        low_cpu_mem_usage=True
-    )
-    ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
-
-    # Create embeddings with smaller batches and memory optimization
-    texts = [p["text"] for p in papers]
-    embeddings = []
-    batch_size = 4  # Smaller batch size
-
-    with torch.inference_mode():
-        for i in range(0, len(texts), batch_size):
-            batch_texts = texts[i:i + batch_size]
-            inputs = ctx_tokenizer(
-                batch_texts,
-                max_length=256,  # Reduced from default
-                padding=True,
-                truncation=True,
-                return_tensors="pt"
-            )
-            outputs = ctx_encoder(**inputs)
-            embeddings.extend(outputs.pooler_output.cpu().numpy())
-
-            # Clear memory
-            del outputs
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
-
-    # Convert to numpy array and build FAISS index
-    embeddings = np.array(embeddings, dtype=np.float32)  # Ensure float32 type
-    dimension = embeddings.shape[1]
-
-    # Normalize the vectors manually
-    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
-    embeddings = embeddings / norms
-
-    # Create FAISS index
-    index = faiss.IndexFlatIP(dimension)
-    index.add(embeddings)
-
-    # Create and save the dataset
-    dataset = Dataset.from_dict({
-        "text": texts,
-        "embeddings": embeddings.tolist(),  # Convert to list for storage
-        "title": [p["title"] for p in papers]
-    })
-
-    # Create directory if it doesn't exist
-    os.makedirs(dataset_dir, exist_ok=True)
-
-    # Save dataset
-    dataset.save_to_disk(os.path.join(dataset_dir, "dataset"))
-    logging.info(f"Dataset saved to {dataset_dir}")
-    return dataset_dir
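As an aside on the removed index builder: because build_faiss_index() L2-normalized the DPR embeddings before adding them to an IndexFlatIP, inner-product search over that index is equivalent to cosine similarity. A minimal, self-contained sketch of querying such an index follows; random vectors stand in for real embeddings, and the 768 dimension is the usual DPR output size, assumed here for illustration.

# Illustrative sketch only: querying an inner-product FAISS index like the one
# the deleted build_faiss_index() produced. With L2-normalized vectors,
# inner product equals cosine similarity.
import numpy as np
import faiss

dim = 768  # assumed DPR embedding size
docs = np.random.rand(100, dim).astype(np.float32)
docs /= np.linalg.norm(docs, axis=1, keepdims=True)  # normalize rows

index = faiss.IndexFlatIP(dim)  # exact inner-product index
index.add(docs)

query = np.random.rand(1, dim).astype(np.float32)
query /= np.linalg.norm(query, axis=1, keepdims=True)

scores, ids = index.search(query, 5)  # top-5 documents by cosine similarity
print(ids[0], scores[0])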