Spaces:
Running
Running
import json | |
import traceback | |
import warnings | |
import mlcroissant as mlc | |
import func_timeout | |
# Timeout, in seconds, that validate_records passes to func_timeout when
# fetching the first record of a record set.
WAIT_TIME = 5 * 60  # seconds
def validate_json(file_path):
    """Validate that the file is proper JSON.

    Args:
        file_path: Path of the file to read and parse.

    Returns:
        A ``(ok, message, json_data)`` tuple. ``ok`` is True when the file
        was parsed, ``message`` is a human-readable status line and
        ``json_data`` is the parsed content (``None`` on failure). Any
        warnings emitted while reading are appended to ``message`` under a
        "Warnings:" heading.
    """
    json_data = None
    with warnings.catch_warnings(record=True) as caught_warnings:
        warnings.simplefilter("always")  # Ensure all warnings are captured
        try:
            # Decode explicitly as UTF-8 so the result does not depend on the
            # platform's default locale encoding.
            with open(file_path, "r", encoding="utf-8") as f:
                json_data = json.load(f)
        except json.JSONDecodeError as e:
            ok, message = False, f"❌ Invalid JSON format: {e}"
        except Exception as e:
            # Top-level boundary: any other failure (missing file, bad
            # encoding, permissions, ...) is reported instead of raised.
            ok, message = False, f"❌ Error reading file: {e}"
        else:
            ok, message = True, "✅ The file is valid JSON."
        warning_msgs = [str(w.message) for w in caught_warnings]

    if warning_msgs:
        message += "\n\nWarnings:\n" + "\n".join(warning_msgs)
    return ok, message, json_data
def validate_croissant(json_data):
    """Validate that the JSON follows Croissant schema.

    Args:
        json_data: Parsed JSON-LD content, passed to
            ``mlc.Dataset(jsonld=...)``.

    Returns:
        A ``(ok, message)`` tuple. ``ok`` is True when the dataset object
        could be constructed without an exception. On failure ``message``
        contains the exception text followed by the formatted traceback.
        Any warnings emitted during construction are appended to
        ``message`` under a "Warnings:" heading.
    """
    with warnings.catch_warnings(record=True) as caught_warnings:
        warnings.simplefilter("always")  # Ensure all warnings are captured
        try:
            # Only the side effect of construction matters here (it raises
            # on invalid metadata); the resulting object is not used.
            mlc.Dataset(jsonld=json_data)
        except mlc.ValidationError as e:
            ok = False
            message = f"❌ Validation failed: {e}\n\n{traceback.format_exc()}"
        except Exception as e:
            # Top-level boundary: report anything unexpected to the caller
            # rather than letting it propagate.
            ok = False
            message = (
                f"❌ Unexpected error during validation: {e}"
                f"\n\n{traceback.format_exc()}"
            )
        else:
            ok = True
            message = "✅ The dataset passes Croissant validation."
        warning_msgs = [str(w.message) for w in caught_warnings]

    if warning_msgs:
        message += "\n\nWarnings:\n" + "\n".join(warning_msgs)
    return ok, message
def _check_record_set(dataset, record_set):
    """Try to fetch the first record of one record set.

    Returns a ``(message, warning_msgs)`` tuple. On success the captured
    warnings are returned separately so the caller can list them at the end
    of the report; on timeout or error they are embedded in ``message`` and
    the returned list is empty.
    """
    with warnings.catch_warnings(record=True) as record_warnings:
        warnings.simplefilter("always")
        try:
            records = dataset.records(record_set=record_set.uuid)
            print(f"Attempting to validate record set: {record_set.uuid}")
            func_timeout.func_timeout(WAIT_TIME, lambda: next(iter(records)))
        except func_timeout.exceptions.FunctionTimedOut:
            # Derive the limit from WAIT_TIME so the text cannot drift from
            # the timeout that was actually applied.
            message = (
                f"⚠️ Record set '{record_set.uuid}' generation took too long "
                f"(>{WAIT_TIME}s)"
            )
        except Exception as e:
            message = (
                f"⚠️ Record set '{record_set.uuid}' encountered an issue: {e}"
                f"\n\n{traceback.format_exc()}"
            )
        else:
            passed = f"✅ Record set '{record_set.uuid}' passed validation."
            return passed, [str(w.message) for w in record_warnings]
        warning_msgs = [str(w.message) for w in record_warnings]

    if warning_msgs:
        message += "\n\nWarnings:\n" + "\n".join(warning_msgs)
    return message, []


def validate_records(json_data):
    """Validate that records can be generated within the time limit.

    For every record set in the dataset metadata, the first record is
    requested with a timeout of ``WAIT_TIME`` seconds.

    Args:
        json_data: Parsed JSON-LD content, passed to
            ``mlc.Dataset(jsonld=...)``.

    Returns:
        A ``(ok, message)`` tuple. ``ok`` is always True: timeouts and
        errors are reported as warning lines inside ``message`` rather than
        as a failure. ``message`` holds one line per record set, followed by
        a "Warnings:" section with the warnings captured for record sets
        that passed and then those captured while loading the dataset.
    """
    with warnings.catch_warnings(record=True) as caught_warnings:
        warnings.simplefilter("always")  # Ensure all warnings are captured
        try:
            dataset = mlc.Dataset(jsonld=json_data)
            record_sets = dataset.metadata.record_sets

            if not record_sets:
                message = "✅ No record sets found to validate."
                setup_warnings = [str(w.message) for w in caught_warnings]
                if setup_warnings:
                    message += "\n\nWarnings:\n" + "\n".join(setup_warnings)
                return True, message

            results = []
            all_warnings = []
            for record_set in record_sets:
                # Each record set captures its own warnings separately.
                result, record_warnings = _check_record_set(dataset, record_set)
                results.append(result)
                all_warnings.extend(record_warnings)

            # Warnings from the initial setup go last.
            all_warnings.extend(str(w.message) for w in caught_warnings)

            final_message = "\n".join(results)
            if all_warnings:
                final_message += "\n\nWarnings:\n" + "\n".join(all_warnings)
            return True, final_message
        except Exception as e:
            # Best-effort check: even an unexpected failure is surfaced as a
            # warning message and still reported with ok=True.
            error_message = (
                f"⚠️ Unexpected error during records validation: {e}"
                f"\n\n{traceback.format_exc()}"
            )
            setup_warnings = [str(w.message) for w in caught_warnings]
            if setup_warnings:
                error_message += "\n\nWarnings:\n" + "\n".join(setup_warnings)
            return True, error_message