nomadicsynth committed on
Commit
7177172
·
1 Parent(s): d5b5c7a

Refactor embedding model integration and update find_synergistic_papers calls to use dataset parameter

Browse files
Files changed (1) hide show
  1. app.py +14 -13
app.py CHANGED
@@ -18,7 +18,7 @@ from arxiv_stuff import ARXIV_CATEGORIES_FLAT
18
  HF_TOKEN = os.getenv("HF_TOKEN")
19
 
20
  # Login to Hugging Face Hub
21
- hf_hub_login(token=HF_TOKEN, add_to_git_credential=True)
22
 
23
  # Check if using persistent storage
24
  persistent_storage = os.path.exists("/data")
@@ -142,6 +142,16 @@ def init_embedding_model(model_name_or_path: str, model_revision: str = None) ->
142
  )
143
 
144
 
 
 
 
 
 
 
 
 
 
 
145
  def init_reasoning_model(model_name: str) -> InferenceClient:
146
  global reasoning_model
147
  reasoning_model = InferenceClient(
@@ -271,16 +281,6 @@ def generate(messages: list[dict[str, str]]) -> str:
271
  return output
272
 
273
 
274
- @spaces.GPU
275
- def embed_text(text: str | list[str]) -> torch.Tensor:
276
- global embedding_model
277
-
278
- # Strip any leading/trailing whitespace
279
- text = text.strip() if isinstance(text, str) else [t.strip() for t in text]
280
- embed_text = embedding_model.encode(text, normalize_embeddings=True) # Ensure vectors are normalized
281
- return embed_text
282
-
283
-
284
  def analyse_abstracts(query_abstract: str, compare_abstract: dict) -> str:
285
  """Analyze the relationship between two abstracts and return formatted analysis"""
286
  global reasoning_model
@@ -464,7 +464,7 @@ def find_synergistic_papers(abstract: str, limit=25) -> list[dict]:
464
  def format_search_results_json(abstract: str) -> str:
465
  """Format search results as JSON for display"""
466
  # Find papers synergistic with the given abstract
467
- papers = find_synergistic_papers(abstract)
468
 
469
  # Convert to JSON for display
470
  json_output = json.dumps(papers, indent=2)
@@ -475,7 +475,8 @@ def format_search_results_json(abstract: str) -> str:
475
  def format_search_results(abstract: str) -> tuple[pd.DataFrame, list[dict]]:
476
  """Format search results as a DataFrame for display"""
477
  # Find papers synergistic with the given abstract
478
- papers = find_synergistic_papers(abstract)
 
479
 
480
  # Convert to DataFrame for display
481
  df = pd.DataFrame(
 
18
  HF_TOKEN = os.getenv("HF_TOKEN")
19
 
20
  # Login to Hugging Face Hub
21
+ # hf_hub_login(token=HF_TOKEN, add_to_git_credential=True)
22
 
23
  # Check if using persistent storage
24
  persistent_storage = os.path.exists("/data")
 
142
  )
143
 
144
 
145
+ @spaces.GPU
146
+ def embed_text(text: str | list[str]) -> torch.Tensor:
147
+ global embedding_model
148
+
149
+ # Strip any leading/trailing whitespace
150
+ text = text.strip() if isinstance(text, str) else [t.strip() for t in text]
151
+ embed_text = embedding_model.encode(text, normalize_embeddings=True) # Ensure vectors are normalized
152
+ return embed_text
153
+
154
+
155
  def init_reasoning_model(model_name: str) -> InferenceClient:
156
  global reasoning_model
157
  reasoning_model = InferenceClient(
 
281
  return output
282
 
283
 
 
 
 
 
 
 
 
 
 
 
284
  def analyse_abstracts(query_abstract: str, compare_abstract: dict) -> str:
285
  """Analyze the relationship between two abstracts and return formatted analysis"""
286
  global reasoning_model
 
464
  def format_search_results_json(abstract: str) -> str:
465
  """Format search results as JSON for display"""
466
  # Find papers synergistic with the given abstract
467
+ papers = embedding_model.find_synergistic_papers(dataset, abstract)
468
 
469
  # Convert to JSON for display
470
  json_output = json.dumps(papers, indent=2)
 
475
  def format_search_results(abstract: str) -> tuple[pd.DataFrame, list[dict]]:
476
  """Format search results as a DataFrame for display"""
477
  # Find papers synergistic with the given abstract
478
+ # papers = embedding_model.find_synergistic_papers(abstract)
479
+ papers = embedding_model.find_synergistic_papers(dataset, abstract)
480
 
481
  # Convert to DataFrame for display
482
  df = pd.DataFrame(