astro-seg / model.py
rayh's picture
Upload model.py with huggingface_hub
543627a verified
raw
history blame
455 Bytes
import onnxruntime as ort
import numpy as np
from PIL import Image
class YOLOSegmentationModel:
    """Run a YOLO segmentation model exported to ONNX through ONNX Runtime.

    The wrapper performs the forward pass only: it does not resize the
    image to the model's input size and it returns the raw session outputs
    without decoding boxes or masks.
    """

    # Feed name used whenever the model exposes it; kept so that models
    # which worked with the previous hard-coded key behave identically.
    _DEFAULT_INPUT_NAME = "images"

    def __init__(self, model_path: str):
        """Create an inference session for the ONNX file at *model_path*."""
        self.session = ort.InferenceSession(model_path)

    @staticmethod
    def _wants_channels_first(declared, actual) -> bool:
        """Return True if a channels-last batch must be transposed.

        *declared* is the input shape reported by the session (entries are
        ints for fixed dimensions, str/None for dynamic ones); *actual* is
        the shape of the batch built from the image.
        """
        if declared is None or len(declared) != 4 or len(actual) != 4:
            return False
        channels = actual[-1]
        # Dynamic dimensions are str/None and never compare equal to an
        # int, so only a fixed channel count in position 1 triggers the
        # transpose, and a model that already declares channels last
        # (position 3) is left alone.
        return declared[1] == channels and declared[3] != channels

    def predict(self, image: "Image.Image", *, scale: float = 1.0) -> list:
        """Run the model on a single image and return the raw outputs.

        Args:
            image: A PIL image (anything ``numpy.asarray`` accepts works).
                NOTE(review): images with an alpha channel or a palette are
                not converted here; call ``image.convert("RGB")`` first if
                the model expects three channels.
            scale: Factor applied to the float32 pixel values before
                inference. The default of 1.0 leaves them unchanged.
                NOTE(review): YOLO exports commonly expect values in
                [0, 1], i.e. ``scale=1 / 255`` - confirm against the
                exported model.

        Returns:
            The list returned by ``InferenceSession.run`` for all outputs.
        """
        inputs = self.session.get_inputs()
        by_name = {node.name: node for node in inputs}
        # Prefer the conventional name, otherwise fall back to the model's
        # first input instead of failing on an unknown feed key.
        target = by_name.get(self._DEFAULT_INPUT_NAME, inputs[0])

        # astype() always copies, so the in-place scaling below never
        # touches the caller's data.
        tensor = np.asarray(image).astype(np.float32)
        if scale != 1.0:
            tensor *= np.float32(scale)
        tensor = np.expand_dims(tensor, axis=0)  # add batch dimension

        if self._wants_channels_first(target.shape, tensor.shape):
            # (N, H, W, C) -> (N, C, H, W), as declared by the model.
            tensor = np.ascontiguousarray(tensor.transpose(0, 3, 1, 2))

        return self.session.run(None, {target.name: tensor})