|
import numpy as np |
|
from torchvision import transforms |
|
from PIL import Image |
|
import logging |
|
|
|
def _as_height_width(value):
    """Return ``(height, width)`` for an int, a 2-sequence, or a height/width dict; else ``None``."""
    if isinstance(value, int) and not isinstance(value, bool):
        return (value, value)
    if isinstance(value, (list, tuple)) and len(value) == 2:
        return (int(value[0]), int(value[1]))
    if isinstance(value, dict) and 'height' in value and 'width' in value:
        return (int(value['height']), int(value['width']))
    return None


def preprocess_onnx_input(image, preprocessor_config):
    """Convert a PIL image into a normalized, batched array for ONNX inference.

    The image is converted to RGB if needed, resized, center-cropped,
    converted to a float32 tensor (``ToTensor`` scales pixels to [0, 1]) and
    normalized with the configured per-channel mean/std.

    Args:
        image: PIL image to preprocess.
        preprocessor_config: Preprocessor settings mapping (``None`` is
            treated as empty). Keys read:
            - ``size``: int (square), ``(height, width)``, or a dict with
              either ``height``/``width`` or ``shortest_edge``.
              Defaults to 224x224.
            - ``crop_size``: int, ``(height, width)`` or a ``height``/``width``
              dict. Defaults to the resize target: no effective crop for a
              height/width resize, a square crop for ``shortest_edge``.
            - ``image_mean`` / ``image_std``: normalization values; default to
              the standard ImageNet statistics.
            Other keys (e.g. ``do_resize``, ``crop_pct``) are ignored.

    Returns:
        numpy array of shape ``(1, 3, crop_height, crop_width)``.

    Raises:
        ValueError: if ``size`` or ``crop_size`` has an unsupported form.
    """
    config = preprocessor_config or {}

    if image.mode != 'RGB':
        image = image.convert('RGB')

    size = config.get('size', {'height': 224, 'width': 224})
    if (
        isinstance(size, dict)
        and 'shortest_edge' in size
        and not ('height' in size and 'width' in size)
    ):
        # An int passed to torchvision's Resize matches the shorter edge and
        # preserves aspect ratio; crop a square of that edge by default.
        edge = int(size['shortest_edge'])
        resize_arg = edge
        default_crop = (edge, edge)
    else:
        resize_arg = _as_height_width(size)
        if resize_arg is None:
            raise ValueError(f"Unsupported 'size' in preprocessor config: {size!r}")
        # Defaulting the crop to the full resize target keeps non-square
        # resizes intact (cropping to the height alone would trim the width).
        default_crop = resize_arg

    crop_setting = config.get('crop_size')
    crop_arg = default_crop if crop_setting is None else _as_height_width(crop_setting)
    if crop_arg is None:
        raise ValueError(f"Unsupported 'crop_size' in preprocessor config: {crop_setting!r}")

    mean = config.get('image_mean', [0.485, 0.456, 0.406])
    std = config.get('image_std', [0.229, 0.224, 0.225])

    transform = transforms.Compose([
        transforms.Resize(resize_arg),
        transforms.CenterCrop(crop_arg),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])

    input_tensor = transform(image)
    # Add the leading batch axis expected by the ONNX model: (C, H, W) -> (1, C, H, W).
    return input_tensor.unsqueeze(0).cpu().numpy()
|
|
|
def postprocess_onnx_output(onnx_output, model_config):
    """Map an ONNX classifier's probability output to a ``{label: probability}`` dict.

    Args:
        onnx_output: Mapping holding the model's probabilities under the
            ``"probabilities"`` key, either as a 1-D sequence or as a batch of
            one (e.g. shape ``(1, N)``).
        model_config: Model configuration mapping (``None`` is treated as
            empty). ``id2label`` (index -> label, keys may be ints or numeric
            strings) supplies the label names when present; otherwise the
            labels are ``['Fake', 'Real']``. With ``num_classes == 1`` and two
            probabilities, index 0/1 are reported under the first two labels
            (falling back to ``Fake``/``Real`` if fewer than two are defined).

    Returns:
        dict mapping each label to its probability as a Python ``float``.
        If the probabilities are missing or do not match the labels, a
        warning is logged and every label maps to ``0.0``.
    """
    logger = logging.getLogger(__name__)
    config = model_config or {}
    default_names = ['Fake', 'Real']

    id2label = config.get('id2label')
    if id2label:
        # JSON configs store id2label keys as strings ("0", ..., "10"); sort
        # numerically so "10" does not land between "1" and "2".
        try:
            ordered_keys = sorted(id2label, key=int)
        except (TypeError, ValueError):
            ordered_keys = sorted(id2label)
        class_names = [id2label[k] for k in ordered_keys]
    else:
        class_names = list(default_names)

    probabilities = onnx_output.get("probabilities")
    if probabilities is None:
        logger.warning("ONNX post-processing failed: 'probabilities' key not found in output.")
        return {name: 0.0 for name in class_names}

    probs = np.asarray(probabilities, dtype=np.float64)
    # Drop leading batch axes of size 1, e.g. (1, N) -> (N,).
    while probs.ndim > 1 and probs.shape[0] == 1:
        probs = probs[0]
    probs = np.atleast_1d(probs)

    if config.get('num_classes') == 1 and probs.shape == (2,):
        labels = class_names[:2] if len(class_names) >= 2 else default_names
        return {labels[0]: float(probs[0]), labels[1]: float(probs[1])}

    if probs.ndim == 1 and len(probs) == len(class_names):
        return {name: float(p) for name, p in zip(class_names, probs)}

    logger.warning("ONNX post-processing: Probabilities length mismatch with class names.")
    return {name: 0.0 for name in class_names}
|
|