Upload 7 files
- README.md +3 -8
- __init__.py +1 -0
- app.py +60 -0
- model.py +80 -0
- rag.configs.yml +7 -0
- requirements.txt +178 -0
- search.py +148 -0
README.md
CHANGED
@@ -1,13 +1,8 @@
 ---
-title: NBPlatina
-emoji: 🐢
-colorFrom: indigo
-colorTo: red
 sdk: streamlit
-sdk_version: 1.
+sdk_version: 1.33.0
 app_file: app.py
-
-license: mit
+license: mit
 ---
 
-
+Run nltk.download("punkt") once before starting the app.
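The note above matters because search.py chunks pages with nltk.tokenize.sent_tokenize, which needs the punkt data on disk. A one-time setup sketch:

import nltk

nltk.download("punkt")  # fetches the sentence-tokenizer data used by search.py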
__init__.py
ADDED
import search
app.py
ADDED
from search import SemanticSearch, GoogleSearch, Document
import streamlit as st
from model import RAGModel, load_configs


def run_on_start():
    # Load the configs and the generation model once per session.
    if "configs" not in st.session_state:
        st.session_state.configs = load_configs(config_file="rag.configs.yml")
    if "model" not in st.session_state:
        st.session_state.model = RAGModel(st.session_state.configs)


run_on_start()


def search(query):
    # Scrape Google for the query and chunk the result pages into a document.
    g = GoogleSearch(query)
    data = g.all_page_data
    d = Document(data, min_char_len=st.session_state.configs["document"]["min_char_length"])
    st.session_state.doc = d.doc()


st.title("Search Here Instead of Google")

if "messages" not in st.session_state:
    st.session_state.messages = []

if "doc" not in st.session_state:
    st.session_state.doc = None

if "refresh" not in st.session_state:
    st.session_state.refresh = True

# Replay the chat history on every Streamlit rerun.
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])


if prompt := st.chat_input("Search here instead of Google"):
    st.chat_message("user").markdown(prompt)
    st.session_state.messages.append({"role": "user", "content": prompt})

    configs = st.session_state.configs
    if st.session_state.refresh:
        # Only scrape once per session; later prompts reuse the cached document.
        st.session_state.refresh = False
        search(prompt)

    s = SemanticSearch(
        st.session_state.doc,
        configs["model"]["embeding_model"],
        configs["model"]["device"],
    )
    topk, u = s.semantic_search(query=prompt, k=32)
    response = st.session_state.model.answer_query(query=prompt, topk_items=topk)
    with st.chat_message("assistant"):
        st.markdown(response)

    st.session_state.messages.append({"role": "assistant", "content": response})
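Since the Space's README sets app_file: app.py, Hugging Face launches the UI automatically; to try it locally, the standard Streamlit entry point applies:

streamlit run app.py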
model.py
ADDED
from search import SemanticSearch, GoogleSearch, Document
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import BitsAndBytesConfig
from transformers.utils import is_flash_attn_2_available
import yaml
import torch
import nltk


def load_configs(config_file: str) -> dict:
    with open(config_file, "r") as f:
        configs = yaml.safe_load(f)

    return configs


class RAGModel:
    def __init__(self, configs) -> None:
        self.configs = configs
        self.device = configs["model"]["device"]
        model_url = configs["model"]["genration_model"]
        # quantization_config = BitsAndBytesConfig(
        #     load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16
        # )

        self.model = AutoModelForCausalLM.from_pretrained(
            model_url,
            torch_dtype=torch.float16,
            # quantization_config=quantization_config,
            low_cpu_mem_usage=False,
            attn_implementation="sdpa",
        ).to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(model_url)

    def create_prompt(self, query, topk_items: list[str]):
        # Join the retrieved chunks into a single context block.
        context = "\n-".join(topk_items)

        base_prompt = f"""You are an alternative to Google Search. Your job is to answer the user's query in as much detail as possible.
You have access to the internet and other relevant data related to the user's question.
Take time to read the context and the user query, extract the relevant data, and then answer the query.
Make sure your answer is as detailed as possible.
Do not return your thinking process; just return the answer.
Structure the output as a Wikipedia article.
Now use the following context items to answer the user query.
context: {context}
user query: {query}
"""

        dialog_template = [{"role": "user", "content": base_prompt}]

        prompt = self.tokenizer.apply_chat_template(
            conversation=dialog_template, tokenize=False, add_generation_prompt=True
        )
        return prompt

    def answer_query(self, query: str, topk_items: list[str]):
        prompt = self.create_prompt(query, topk_items)
        input_ids = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        output = self.model.generate(**input_ids, temperature=0.7, max_new_tokens=512, do_sample=True)
        # Strip the prompt and Gemma's special tokens from the decoded output.
        text = self.tokenizer.decode(output[0])
        text = text.replace(prompt, "").replace("<bos>", "").replace("<eos>", "")

        return text


if __name__ == "__main__":
    configs = load_configs(config_file="rag.configs.yml")
    query = "The height of burj khalifa is 1000 meters and it was built in 2023. What is the height of burj khalifa"
    # g = GoogleSearch(query)
    # data = g.all_page_data
    # d = Document(data, 512)
    # doc_chunks = d.doc()
    # s = SemanticSearch(doc_chunks, "all-mpnet-base-v2", "mps")
    # topk, u = s.semantic_search(query=query, k=32)
    r = RAGModel(configs)
    output = r.answer_query(query=query, topk_items=[""])
    print(output)
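The commented-out lines above sketch an alternative 4-bit load path. A minimal sketch of enabling it, assuming a CUDA GPU and the bitsandbytes pin from requirements.txt (the quantized loader places the weights on the GPU itself, so the trailing .to(device) call is dropped):

from transformers import AutoModelForCausalLM, BitsAndBytesConfig
import torch

# Load the 7B weights in 4-bit; matrix multiplies still compute in float16.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16
)
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-7b-it",
    quantization_config=quantization_config,
    low_cpu_mem_usage=True,
)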
rag.configs.yml
ADDED
document:
  min_char_length: 512

model:
  embeding_model: all-mpnet-base-v2
  genration_model: google/gemma-7b-it
  device: cuda
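These keys are read verbatim by app.py and model.py, so the embeding_model / genration_model spellings must stay in sync with the code. A quick sketch of reading the file with the helper from model.py:

from model import load_configs

configs = load_configs("rag.configs.yml")
print(configs["model"]["genration_model"])     # google/gemma-7b-it
print(configs["document"]["min_char_length"])  # 512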
requirements.txt
ADDED
accelerate==0.29.2
albumentations==1.4.3
altair==5.3.0
attrs==23.2.0
beautifulsoup4==4.12.3
bitsandbytes==0.43.1
blinker==1.7.0
cachetools==5.3.3
certifi==2024.2.2
charset-normalizer==2.0.4
click==8.1.7
colorama==0.4.6
contourpy==1.2.1
cycler==0.12.1
filelock==3.13.1
fonttools==4.50.0
fsspec==2024.3.1
gitdb==4.0.11
GitPython==3.1.43
gmpy2==2.1.2
huggingface-hub==0.22.2
idna==3.4
imageio==2.34.0
importlib_resources==6.4.0
Jinja2==3.1.3
joblib==1.3.2
jsonschema==4.21.1
jsonschema-specifications==2023.12.1
kiwisolver==1.4.5
lazy_loader==0.4
markdown-it-py==3.0.0
MarkupSafe==2.1.3
matplotlib==3.8.4
mdurl==0.1.2
mkl-fft==1.3.8
mkl-random==1.2.4
mkl-service==2.4.0
mpmath==1.3.0
networkx==3.1
nltk==3.8.1
numpy==1.26.4
opencv-python-headless==4.9.0.80
packaging==24.0
pandas==2.2.2
pillow==10.2.0
pip==23.3.1
protobuf==4.25.3
psutil==5.9.8
pyarrow==16.0.0
pydeck==0.8.1b0
Pygments==2.17.2
pyparsing==3.1.2
python-dateutil==2.9.0.post0
pytz==2024.1
PyYAML==6.0.1
referencing==0.34.0
regex==2024.4.16
requests==2.31.0
rich==13.7.1
rpds-py==0.18.0
safetensors==0.4.3
scikit-image==0.22.0
scikit-learn==1.4.1.post1
scipy==1.13.0
sentence-transformers==2.7.0
setuptools==68.2.2
six==1.16.0
smmap==5.0.1
soupsieve==2.5
streamlit==1.33.0
sympy==1.12
tenacity==8.2.3
threadpoolctl==3.4.0
tifffile==2024.2.12
tokenizers==0.15.2
toml==0.10.2
toolz==0.12.1
torch==2.2.2
torchaudio==2.2.2
torchvision==0.17.2
tornado==6.4
tqdm==4.66.2
transformers==4.39.3
typing_extensions==4.9.0
tzdata==2024.1
urllib3==2.1.0
watchdog==4.0.0
wheel==0.41.2
zipp==3.18.1
search.py
ADDED
from bs4 import BeautifulSoup
import urllib.parse
import requests
import nltk
import torch
from typing import Union
from sentence_transformers import SentenceTransformer, util
from concurrent.futures import ThreadPoolExecutor, as_completed


class GoogleSearch:
    def __init__(self, query: str) -> None:
        self.query = query
        escaped_query = urllib.parse.quote_plus(query)
        self.URL = f"https://www.google.com/search?q={escaped_query}"

        self.headers = {
            "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/72.0.3538.102 Safari/537.36"
        }
        self.links = self.get_initial_links()
        self.all_page_data = self.all_pages()

    def clean_urls(self, anchors: list) -> list[str]:
        # Pull target URLs out of Google's redirect hrefs ("...url=http...&...").
        links: list[str] = []
        for a in anchors:
            links.append(
                list(filter(lambda l: l.startswith("url=http"), a["href"].split("&")))
            )

        links = [
            link.split("url=")[-1]
            for sublist in links
            for link in sublist
            if len(link) > 0
        ]

        return links

    def read_url_page(self, url: str) -> str:
        response = requests.get(url, headers=self.headers)
        response.raise_for_status()
        soup = BeautifulSoup(response.text, "html.parser")
        return soup.get_text(strip=True)

    def get_initial_links(self) -> list[str]:
        """
        Scrape Google for the query with a keyword-based search.
        """
        print("Searching Google...")
        response = requests.get(self.URL, headers=self.headers)
        soup = BeautifulSoup(response.text, "html.parser")
        anchors = soup.find_all("a", href=True)
        return self.clean_urls(anchors)

    def all_pages(self) -> list[tuple[str, str]]:
        # Fetch every result page concurrently; skip pages that raise an HTTP error.
        data: list[tuple[str, str]] = []
        with ThreadPoolExecutor(max_workers=4) as executor:
            future_to_url = {
                executor.submit(self.read_url_page, url): url for url in self.links
            }
            for future in as_completed(future_to_url):
                url = future_to_url[future]
                try:
                    output = future.result()
                    data.append((url, output))
                except requests.exceptions.HTTPError as e:
                    print(e)

        return data


class Document:

    def __init__(self, data: list[tuple[str, str]], min_char_len: int) -> None:
        """
        data : list[tuple[str, str]]
            URL and page data
        """
        self.data = data
        self.min_char_len = min_char_len

    def make_min_len_chunk(self):
        raise NotImplementedError

    def chunk_page(
        self,
        page_text: str,
    ) -> list[str]:
        # Accumulate sentences until a chunk reaches min_char_len characters.
        min_len_chunks: list[str] = []
        sentence: str = ""
        for sent in nltk.tokenize.sent_tokenize(page_text):
            sentence += sent
            if len(sentence) > self.min_char_len:
                min_len_chunks.append(sentence)
                sentence = ""
        return min_len_chunks

    def doc(self) -> tuple[list[str], list[str]]:
        print("Creating Document...")
        chunked_data: list[str] = []
        urls: list[str] = []
        for url, dataitem in self.data:
            data = self.chunk_page(dataitem)
            chunked_data.append(data)
            urls.append(url)

        chunked_data = [chunk for sublist in chunked_data for chunk in sublist]
        return chunked_data, urls


class SemanticSearch:
    def __init__(
        self, doc_chunks: tuple[list, list], model_path: str, device: str
    ) -> None:

        self.doc_chunks, self.urls = doc_chunks
        self.st = SentenceTransformer(
            model_path,
            device=device,
        )

    def semantic_search(self, query: str, k: int = 10):
        # Embed the query and every chunk, then return the k best-scoring chunks.
        print("Searching Top k in document...")
        query_embeding = self.get_embeding(query)
        doc_embeding = self.get_embeding(self.doc_chunks)
        scores = util.dot_score(a=query_embeding, b=doc_embeding)[0]

        top_k = torch.topk(scores, k=k)[1].cpu().tolist()
        return [self.doc_chunks[i] for i in top_k], self.urls

    def get_embeding(self, text: Union[list[str], str]):
        en = self.st.encode(text)
        return en
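Putting the pieces together, the commented-out block in model.py's __main__ shows the intended retrieval pipeline; a minimal sketch of running it on its own (assumes network access and the punkt data from the README note):

from search import GoogleSearch, Document, SemanticSearch

query = "height of burj khalifa"
g = GoogleSearch(query)                          # scrape Google, fetch result pages
doc_chunks = Document(g.all_page_data, min_char_len=512).doc()  # (chunks, urls)
s = SemanticSearch(doc_chunks, "all-mpnet-base-v2", "cuda")
topk, urls = s.semantic_search(query=query, k=10)
print(topk[0])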