Spaces: Running on Zero
Commit 7177172 · 1 Parent(s): d5b5c7a
Refactor embedding model integration and update find_synergistic_papers calls to use dataset parameter
Browse files
app.py
CHANGED
@@ -18,7 +18,7 @@ from arxiv_stuff import ARXIV_CATEGORIES_FLAT
|
|
18 |
HF_TOKEN = os.getenv("HF_TOKEN")
|
19 |
|
20 |
# Login to Hugging Face Hub
|
21 |
-
hf_hub_login(token=HF_TOKEN, add_to_git_credential=True)
|
22 |
|
23 |
# Check if using persistent storage
|
24 |
persistent_storage = os.path.exists("/data")
|
@@ -142,6 +142,16 @@ def init_embedding_model(model_name_or_path: str, model_revision: str = None) ->
|
|
142 |
)
|
143 |
|
144 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
145 |
def init_reasoning_model(model_name: str) -> InferenceClient:
|
146 |
global reasoning_model
|
147 |
reasoning_model = InferenceClient(
|
@@ -271,16 +281,6 @@ def generate(messages: list[dict[str, str]]) -> str:
|
|
271 |
return output
|
272 |
|
273 |
|
274 |
-
@spaces.GPU
|
275 |
-
def embed_text(text: str | list[str]) -> torch.Tensor:
|
276 |
-
global embedding_model
|
277 |
-
|
278 |
-
# Strip any leading/trailing whitespace
|
279 |
-
text = text.strip() if isinstance(text, str) else [t.strip() for t in text]
|
280 |
-
embed_text = embedding_model.encode(text, normalize_embeddings=True) # Ensure vectors are normalized
|
281 |
-
return embed_text
|
282 |
-
|
283 |
-
|
284 |
def analyse_abstracts(query_abstract: str, compare_abstract: dict) -> str:
|
285 |
"""Analyze the relationship between two abstracts and return formatted analysis"""
|
286 |
global reasoning_model
|
@@ -464,7 +464,7 @@ def find_synergistic_papers(abstract: str, limit=25) -> list[dict]:
|
|
464 |
def format_search_results_json(abstract: str) -> str:
|
465 |
"""Format search results as JSON for display"""
|
466 |
# Find papers synergistic with the given abstract
|
467 |
-
papers = find_synergistic_papers(abstract)
|
468 |
|
469 |
# Convert to JSON for display
|
470 |
json_output = json.dumps(papers, indent=2)
|
@@ -475,7 +475,8 @@ def format_search_results_json(abstract: str) -> str:
|
|
475 |
def format_search_results(abstract: str) -> tuple[pd.DataFrame, list[dict]]:
|
476 |
"""Format search results as a DataFrame for display"""
|
477 |
# Find papers synergistic with the given abstract
|
478 |
-
papers = find_synergistic_papers(abstract)
|
|
|
479 |
|
480 |
# Convert to DataFrame for display
|
481 |
df = pd.DataFrame(
|
|
|
18 |
HF_TOKEN = os.getenv("HF_TOKEN")
|
19 |
|
20 |
# Login to Hugging Face Hub
|
21 |
+
# hf_hub_login(token=HF_TOKEN, add_to_git_credential=True)
|
22 |
|
23 |
# Check if using persistent storage
|
24 |
persistent_storage = os.path.exists("/data")
|
|
|
142 |
)
|
143 |
|
144 |
|
145 |
+
@spaces.GPU
|
146 |
+
def embed_text(text: str | list[str]) -> torch.Tensor:
|
147 |
+
global embedding_model
|
148 |
+
|
149 |
+
# Strip any leading/trailing whitespace
|
150 |
+
text = text.strip() if isinstance(text, str) else [t.strip() for t in text]
|
151 |
+
embed_text = embedding_model.encode(text, normalize_embeddings=True) # Ensure vectors are normalized
|
152 |
+
return embed_text
|
153 |
+
|
154 |
+
|
155 |
def init_reasoning_model(model_name: str) -> InferenceClient:
|
156 |
global reasoning_model
|
157 |
reasoning_model = InferenceClient(
|
|
|
281 |
return output
|
282 |
|
283 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
284 |
def analyse_abstracts(query_abstract: str, compare_abstract: dict) -> str:
|
285 |
"""Analyze the relationship between two abstracts and return formatted analysis"""
|
286 |
global reasoning_model
|
|
|
464 |
def format_search_results_json(abstract: str) -> str:
|
465 |
"""Format search results as JSON for display"""
|
466 |
# Find papers synergistic with the given abstract
|
467 |
+
papers = embedding_model.find_synergistic_papers(dataset, abstract)
|
468 |
|
469 |
# Convert to JSON for display
|
470 |
json_output = json.dumps(papers, indent=2)
|
|
|
475 |
def format_search_results(abstract: str) -> tuple[pd.DataFrame, list[dict]]:
|
476 |
"""Format search results as a DataFrame for display"""
|
477 |
# Find papers synergistic with the given abstract
|
478 |
+
# papers = embedding_model.find_synergistic_papers(abstract)
|
479 |
+
papers = embedding_model.find_synergistic_papers(dataset, abstract)
|
480 |
|
481 |
# Convert to DataFrame for display
|
482 |
df = pd.DataFrame(
|