|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" Verifiable Logic Rewards """ |
|
|
|
import os |
|
import subprocess |
|
import tempfile |
|
import evaluate |
|
import logging |
|
import datasets |
|
from tqdm import tqdm |
|
import time |
|
import multiprocessing as mp |
|
import re |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
_CITATION = """\ |
|
@misc{anonymous2025slr, |
|
author = {Anonymous}, |
|
title = {SLR: Anonymous}, |
|
year = {2025} |
|
} |
|
""" |
|
|
|
_DESCRIPTION = """\ |
|
Verifiable Rewards for Scalable Logical Reasoning (**SLR**) provides verifiable rewards via logic program execution.

It deterministically evaluates candidate hypotheses by executing them against the validation program and verifying that all positive examples ($E^+$) are entailed and all negative examples ($E^-$) are not entailed.

Evaluations are fully verifiable and grounded in formal logic, providing an automatic, transparent, and reproducible standard for evaluation and reward in both supervised and reinforcement learning settings.
|
How it Works: |
|
- Input: A candidate hypothesis (logic rule) and an executable validation program containing background knowledge and examples. |
|
- Execution: The candidate rule is executed against the validation program using a Prolog interpreter. |
|
- Correctness Criteria: The rule is considered correct if it entails all positive examples and rejects all negative examples. |
|
- Metrics: We provide a range of evaluation metrics (detailed below). |
|
- Usage: see **Documentation tab** for details on how to use Verifiable Rewards for Scalable Logical Reasoning in your own projects. |
|
|
|
Example usage: |
|
from evaluate import load |
|
|
|
symbolic_judge = load("LG-Anonym/VerifiableRewardsForScalableLogicalReasoning") |
|
|
|
validation_program = \"\"\" |
|
eastbound(train0). |
|
has_car(train0, car0_1). |
|
car_num(car0_1, 1). |
|
car_color(car0_1, white). |
|
car_len(car0_1, short). |
|
has_wall(car0_1, full). |
|
|
|
westbound(train1). |
|
has_car(train1, car1_1). |
|
car_num(car1_1, 1). |
|
car_color(car1_1, yellow). |
|
car_len(car1_1, short). |
|
has_wall(car1_1, full). |
|
\"\"\" |
|
|
|
predicted_rule = "eastbound(Train):- has_car(Train, Car1), car_color(Car1, white)." |
|
|
|
results = symbolic_judge.compute( |
|
predictions=[predicted_rule], |
|
references=[{"validation_program": validation_program, |
|
"evaluation_config": { |
|
"positive_predicate": "eastbound", |
|
"negative_predicate": "westbound" |
|
}}] |
|
) |
|
|
|
Note: A local SWI-Prolog interpreter (`swipl`) is required to execute validation programs.
|
""" |
|
|
|
_KWARGS_DESCRIPTION = """ |
|
Args: |
|
predictions (`list` of `str`): Each prediction should be a Prolog rule like "pred(T) :- Body." |
|
references (`list` of `dict`): Each reference should contain: |
|
- 'validation_program' (`str`): Background knowledge in Prolog syntax |
|
- 'evaluation_config' (`dict`, optional): Configuration of predicates to use for evaluation. |
|
        Defines 'positive_predicate' and 'negative_predicate'; the positive predicate must match the head of the rule being evaluated.
|
Returns: |
|
accuracy (`float`): The proportion of predictions that correctly solve all examples. Value is between 0 and 1. |
|
partial_score (`float`): Average proportion of correctly classified examples across all predictions. Value is between 0 and 1. |
|
syntax_score (`float`): Proportion of rules with valid syntax. Value is between 0 and 1. |
|
detailed_results (`list` of `dict`): Per-example results including correctness, partial score, execution time, and any errors encountered. |
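
Illustrative output for a single prediction that correctly classifies all examples:
    {'accuracy': 1.0, 'partial_score': 1.0, 'syntax_score': 1.0,
     'detailed_results': [{'is_correct': True, 'partial_score': 1.0, 'syntax_valid': True, 'error': None, ...}]}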
|
""" |
|
|
|
|
|
def validate_rule_no_hardcoded_cars(prediction): |
|
"""Reject rules that hardcode specific car identifiers""" |
|
    # `re` is imported at module level; this pattern matches has_car/2 literals whose
    # second argument is a lowercase constant, i.e. a hardcoded car identifier.
    hardcoded_pattern = r'has_car\([^,]+,\s*([a-z][a-z0-9_]*)\)'
|
matches = re.findall(hardcoded_pattern, prediction) |
|
|
|
if matches: |
|
return False, f"Cars must be variables: {matches[0]}" |
|
|
|
return True, "Rule is valid" |
|
|
|
|
|
def _evaluate_with_prolog(prediction, validation_program, eval_config, timeout=5): |
|
""" |
|
Evaluates a predicted rule against the validation program using Prolog. |
|
""" |
|
|
|
|
|
|
|
positive_pred = eval_config.get("positive_predicate", "eastbound") |
|
negative_pred = eval_config.get("negative_predicate", "westbound") |
|
allow_multiple_rules = eval_config.get("allow_multiple_rules", False) |
|
|
|
|
|
rule_to_evaluate = extract_ilp_from_text_v2(prediction, positive_pred, allow_multiple_rules) |
|
|
|
is_valid, validation_msg = validate_rule_no_hardcoded_cars(rule_to_evaluate) |
|
if not is_valid: |
|
return { |
|
"is_correct": False, |
|
"partial_score": 0.0, |
|
"syntax_valid": False, |
|
"error": f"Rule validation failed: {validation_msg}" |
|
} |
|
|
|
if positive_pred not in rule_to_evaluate: |
|
p = prediction.replace('\n', ' ') |
|
return { |
|
"is_correct": False, |
|
"partial_score": 0.0, |
|
"syntax_valid": False, |
|
"error": f"Invalid Syntax: Logic Rule not found for symbol '{positive_pred}': {p}" |
|
} |
|
|
|
pos_examples = re.findall(rf'{positive_pred}\(([^)]+)\)', validation_program) |
|
neg_examples = re.findall(rf'{negative_pred}\(([^)]+)\)', validation_program) |
|
|
|
|
|
arity = 1 |
|
if pos_examples: |
|
arity = pos_examples[0].count(',') + 1 |
|
elif neg_examples: |
|
arity = neg_examples[0].count(',') + 1 |
|
|
|
|
|
    vars = ", ".join([f"X{i}" for i in range(1, arity + 1)])  # e.g. "X1" for arity 1, "X1, X2" for arity 2
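    # Prolog harness (built below): the labelled examples are later renamed to pos/neg,
    # so check/N succeeds when a positive example is covered by the candidate rule or a
    # negative example is rejected by it; check_count/1 counts correctly classified
    # examples and check_all/0 succeeds only if every example is classified correctly.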
|
|
|
symbolic_judge = f""" |
|
% Dynamic evaluation predicates |
|
check({vars}) :- pos({vars}), {positive_pred}({vars}). % positive covered |
|
check({vars}) :- neg({vars}), \\+ {positive_pred}({vars}). % negative rejected |
|
|
|
% Count successful checks |
|
check_count(Count) :- |
|
(setof(({vars}), ((pos({vars}); neg({vars})), check({vars})), CorrectExamples) -> |
|
length(CorrectExamples, Count) |
|
; |
|
Count = 0 |
|
). |
|
|
|
check_all :- forall((pos({vars});neg({vars})), check({vars})). |
|
""" |
|
|
|
validation_program = re.sub(rf'\b{positive_pred}\b', 'pos', validation_program) |
|
validation_program = re.sub(rf'\b{negative_pred}\b', 'neg', validation_program) |
|
|
|
pos_negs = validation_program.count("pos(") + validation_program.count("neg(") |
|
validation_program = '\n'.join(sorted(validation_program.splitlines())) |
|
full_program = validation_program + "\n\n" + symbolic_judge + "\n\n" + rule_to_evaluate + "\n\n" |
|
|
|
with tempfile.NamedTemporaryFile(suffix='.pl', mode='w', delete=False) as f: |
|
f.write(full_program) |
|
temp_file = f.name |
|
|
|
    result = None  # so the generic exception handler below can reference `result` safely

    try:
|
eval_start_time = time.time() |
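        # Run SWI-Prolog non-interactively: load the temporary program (-s), execute the
        # counting goal (-g), then halt instead of entering the interactive toplevel (-t halt).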
|
|
|
cmd = ['swipl', '-s', temp_file, '-g', 'check_count(Count), writeln(Count)', '-t', 'halt'] |
|
result = subprocess.run( |
|
cmd, |
|
stdout=subprocess.PIPE, |
|
stderr=subprocess.PIPE, |
|
timeout=timeout, |
|
text=True |
|
) |
|
partial_score = 0.0 if result.stdout.strip() == '' else int(result.stdout.strip()) |
|
|
|
partial_score = partial_score / pos_negs if pos_negs > 0 else 0.0 |
|
|
|
        is_correct = partial_score == 1.0
|
|
|
error = f'{result.stderr} -> Eval Rule "{rule_to_evaluate}"' if result.stderr else None |
|
t1 = time.time() |
|
|
|
return { |
|
"is_correct": is_correct, |
|
"partial_score": partial_score, |
|
"syntax_valid": True, |
|
"error": error, |
|
"exec_time1": t1 - eval_start_time, |
|
} |
|
|
|
except subprocess.TimeoutExpired: |
|
r = rule_to_evaluate.replace('\n', ' ') |
|
        return {"is_correct": False, "partial_score": 0.0, "syntax_valid": False,
                "error": f"Evaluation timed out after {timeout} seconds for rule: '{r}'"}
|
except Exception as e: |
|
        return {"is_correct": False, "partial_score": 0.0, "syntax_valid": False,
                "error": f"Error evaluating rule '{rule_to_evaluate}': Prolog output '{result.stdout.strip() if result else 'none'}', exception: {e}"}
|
finally: |
|
if os.path.exists(temp_file): |
|
os.remove(temp_file) |
|
|
|
def extract_ilp_from_text(text):
    """Extract every Prolog rule of the form 'head(...) :- body.' found in free-form text."""
    rule_patterns = [
        r'([a-zA-Z_][a-zA-Z0-9_]*\([^)]*\)\s*:-[^.]*\.)',
    ]
|
p_code = '' |
|
for pattern in rule_patterns: |
|
matches = re.findall(pattern, text) |
|
for match in matches: |
|
|
|
statement = match.strip() |
|
if not statement.endswith('.'): |
|
statement += '.' |
|
p_code += statement + '\n' |
|
return p_code |
|
|
|
|
|
def extract_ilp_from_text_v2(text, target_predicate=None, allow_multiple_rules=False):
    """Extract the rule(s) whose head matches `target_predicate` from free-form model output."""
    # Strip Prolog line comments.
    text = re.sub(r'%.*?(?=\n|$)', '', text)

    # Collapse newlines so that multi-line rules are matched as a single statement.
    text = re.sub(r'\n\s*', ' ', text)

    rule_pattern = re.compile(rf'({target_predicate}\([^()]*\)\s*:-.*?\.)')
|
rules = list(rule_pattern.findall(text)) |
|
    if len(rules) > 1 and not allow_multiple_rules:
        # Multiple candidate rules found but only one is allowed: keep the last one.
        rules = rules[-1:]
|
|
|
p_code = '' |
|
for rule in rules: |
|
|
|
statement = rule.strip() |
|
if not statement.endswith('.'): |
|
statement += '.' |
|
p_code += statement + '\n' |
|
return p_code.strip() |
|
|
|
|
|
@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) |
|
class VerifiableRewardsForScalableLogicalReasoning(evaluate.Metric): |
|
def __init__(self, config_name=None, **kwargs): |
|
""" |
|
        Initializes the VerifiableRewardsForScalableLogicalReasoning metric.
|
|
|
Args: |
|
config_name (str, optional): Name of the configuration to use. |
|
**kwargs: Additional keyword arguments. |
|
""" |
|
super().__init__(config_name=config_name, **kwargs) |
|
self.config_name = config_name or "default" |
|
self._info = self._info() |
|
self._download_and_prepare(dl_manager=None) |
|
|
|
def _info(self): |
|
return evaluate.MetricInfo( |
|
description=_DESCRIPTION, |
|
citation=_CITATION, |
|
inputs_description=_KWARGS_DESCRIPTION, |
|
features=datasets.Features({ |
|
'predictions': datasets.Value('string'), |
|
'references': { |
|
'validation_program': datasets.Value('string'), |
|
'evaluation_config': { |
|
'positive_predicate': datasets.Value('string'), |
|
'negative_predicate': datasets.Value('string') |
|
} |
|
}, |
|
}), |
|
codebase_urls=["https://github.com/LG-Anonym/SLR-Bench"], |
|
reference_urls=["https://huggingface.co/datasets/LG-Anonym/SLR-Bench"] |
|
) |
|
|
|
def _download_and_prepare(self, dl_manager): |
|
"""Checks if SWI-Prolog is installed or warns the user.""" |
|
try: |
|
subprocess.run( |
|
["swipl", "--version"], |
|
stdout=subprocess.PIPE, |
|
stderr=subprocess.PIPE, |
|
check=True |
|
) |
|
except (subprocess.CalledProcessError, FileNotFoundError): |
|
logger.warning( |
|
"SWI-Prolog not found. Please install it:\n" |
|
"Ubuntu/Debian: sudo apt-get install swi-prolog\n" |
|
"macOS: brew install swi-prolog\n" |
|
"Windows: download from https://www.swi-prolog.org/download/stable" |
|
) |
|
|
|
    def _compute(self, predictions: list, references: list):
        """Evaluates each prediction against its validation program with Prolog, in parallel for large batches."""
|
if not isinstance(predictions, list): |
|
predictions = [predictions] |
|
|
|
if len(predictions) != len(references): |
|
            raise ValueError(
                f"Number of predictions ({len(predictions)}) and references ({len(references)}) don't match")
|
|
|
TIMEOUT = 10 if len(predictions) > 500 else 5 |
|
|
|
eval_inputs = [] |
|
for i, (prediction, reference) in enumerate(zip(predictions, references)): |
|
validation_program = reference.get("validation_program", reference.get("validation program")) |
|
|
|
|
|
|
|
eval_config = reference.get("evaluation_config", { |
|
"positive_predicate": "eastbound", |
|
"negative_predicate": "westbound" |
|
}) |
|
|
|
if not validation_program: |
|
                raise ValueError(f"Example {i} does not contain a 'validation_program' field")
|
|
|
eval_inputs.append((prediction, validation_program, eval_config, TIMEOUT)) |
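        # Large batches are evaluated in parallel across available CPU cores; smaller ones run sequentially.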
|
|
|
|
|
if len(eval_inputs) > 500: |
|
|
|
num_cpus = max(1, mp.cpu_count() - 1) |
|
with mp.Pool(processes=num_cpus) as pool: |
|
results = list(tqdm( |
|
pool.starmap(_evaluate_with_prolog, eval_inputs), |
|
total=len(eval_inputs), |
|
desc=f"Evaluating rules (parallel processing with {num_cpus} CPUs)" |
|
)) |
|
else: |
|
|
|
results = [] |
|
for prediction, validation_program, eval_config, t in tqdm(eval_inputs, total=len(predictions), desc="Evaluating rules"): |
|
results.append(_evaluate_with_prolog(prediction, validation_program, eval_config, timeout=t)) |
|
|
|
|
|
partial_scores = [result["partial_score"] for result in results] |
|
correct_count = sum(1 for result in results if result["is_correct"]) |
|
syntax_valid_count = sum(1 for result in results if result["syntax_valid"]) |
|
|
|
accuracy = correct_count / len(predictions) if predictions else 0 |
|
partial_score = sum(partial_scores) / len(predictions) if partial_scores else 0 |
|
syntax_score = syntax_valid_count / len(predictions) if predictions else 0 |
|
|
|
return { |
|
"accuracy": accuracy, |
|
"partial_score": partial_score, |
|
"syntax_score": syntax_score, |
|
"detailed_results": results |
|
} |
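

if __name__ == "__main__":
    # Minimal local smoke test (a sketch, not part of the metric itself).
    # Assumes SWI-Prolog (`swipl`) is installed and on PATH; mirrors the usage
    # example from the module docstring above.
    example_program = """
eastbound(train0).
has_car(train0, car0_1).
car_color(car0_1, white).

westbound(train1).
has_car(train1, car1_1).
car_color(car1_1, yellow).
"""
    example_rule = "eastbound(Train) :- has_car(Train, Car1), car_color(Car1, white)."

    judge = VerifiableRewardsForScalableLogicalReasoning()
    scores = judge.compute(
        predictions=[example_rule],
        references=[{
            "validation_program": example_program,
            "evaluation_config": {
                "positive_predicate": "eastbound",
                "negative_predicate": "westbound",
            },
        }],
    )
    print(scores)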
|
|