LukasHug commited on
Commit
e4484f6
·
1 Parent(s): 88c2435

allow multiple rules

Browse files
VerifiableRewardsForScalableLogicalReasoning.py CHANGED
@@ -91,7 +91,7 @@ Args:
91
  references (`list` of `dict`): Each reference should contain:
92
  - 'validation_program' (`str`): Background knowledge in Prolog syntax
93
  - 'evaluation_config' (`dict`, optional): Configuration of predicates to use for evaluation.
94
- Define: positive_predicate, and negative_predicate, the positive one should match the head of the rule to evaluate.
95
  Returns:
96
  accuracy (`float`): The proportion of predictions that correctly solve all examples. Value is between 0 and 1.
97
  partial_score (`float`): Average proportion of correctly classified examples across all predictions. Value is between 0 and 1.
@@ -130,9 +130,10 @@ def _evaluate_with_prolog(prediction, validation_program, eval_config, timeout=5
130
  # Extract configuration
131
  positive_pred = eval_config.get("positive_predicate", "eastbound")
132
  negative_pred = eval_config.get("negative_predicate", "westbound")
 
133
 
134
  # extract predicate from rule_to_evaluate
135
- rule_to_evaluate = extract_ilp_from_text_v2(prediction, positive_pred)
136
  if positive_pred not in rule_to_evaluate:
137
  logger.warning(f"Rule '{rule_to_evaluate}' does not contain positive predicate '{positive_pred}'")
138
  return {
@@ -241,16 +242,16 @@ def extract_ilp_from_text(text):
241
  return p_code
242
 
243
 
244
- def extract_ilp_from_text_v2(text, target_predicate=None):
245
  text = re.sub(r'%.*?(?=\n|$)', '', text) # remove comments
246
  # Pre-process: collapse code blocks to single lines
247
  text = re.sub(r'\n\s*', ' ', text) # crude: flatten all to one line
248
  # Rule pattern, across newlines
249
  rule_pattern = re.compile(rf'({target_predicate}\([^()]*\)\s*:-.*?\.)')
250
  rules = list(rule_pattern.findall(text))
251
- if len(rules) > 1:
252
- logger.warning(f"Found multiple rules in text: {rules}. Using only the first one.")
253
- rules = rules[:1] # Use only the first match
254
  # Remove rules that are also captured as facts
255
  p_code = ''
256
  for rule in rules:
 
91
  references (`list` of `dict`): Each reference should contain:
92
  - 'validation_program' (`str`): Background knowledge in Prolog syntax
93
  - 'evaluation_config' (`dict`, optional): Configuration of predicates to use for evaluation.
94
+ Define: positive_predicate, and negative_predicate, the positive one should match the head of the rule to evaluate.
95
  Returns:
96
  accuracy (`float`): The proportion of predictions that correctly solve all examples. Value is between 0 and 1.
97
  partial_score (`float`): Average proportion of correctly classified examples across all predictions. Value is between 0 and 1.
 
130
  # Extract configuration
131
  positive_pred = eval_config.get("positive_predicate", "eastbound")
132
  negative_pred = eval_config.get("negative_predicate", "westbound")
133
+ allow_multiple_rules = eval_config.get("allow_multiple_rules", True)
134
 
135
  # extract predicate from rule_to_evaluate
136
+ rule_to_evaluate = extract_ilp_from_text_v2(prediction, positive_pred, allow_multiple_rules)
137
  if positive_pred not in rule_to_evaluate:
138
  logger.warning(f"Rule '{rule_to_evaluate}' does not contain positive predicate '{positive_pred}'")
139
  return {
 
242
  return p_code
243
 
244
 
245
+ def extract_ilp_from_text_v2(text, target_predicate=None, allow_multiple_rules=False):
246
  text = re.sub(r'%.*?(?=\n|$)', '', text) # remove comments
247
  # Pre-process: collapse code blocks to single lines
248
  text = re.sub(r'\n\s*', ' ', text) # crude: flatten all to one line
249
  # Rule pattern, across newlines
250
  rule_pattern = re.compile(rf'({target_predicate}\([^()]*\)\s*:-.*?\.)')
251
  rules = list(rule_pattern.findall(text))
252
+ if len(rules) > 1 and not allow_multiple_rules:
253
+ logger.warning(f"Found multiple rules in text, but allow_multiple_rules is set to False. Using only the first match.")
254
+ rules = rules[:1]
255
  # Remove rules that are also captured as facts
256
  p_code = ''
257
  for rule in rules: