Spaces:
Running
Running
// Copyright (c) Meta Platforms, Inc. and affiliates. | |
// All rights reserved. | |
// This source code is licensed under the license found in the | |
// LICENSE file in the root directory of this source tree. | |
import { InferenceSession, Tensor } from "onnxruntime-web"; | |
import React, { useContext, useEffect, useState } from "react"; | |
import "./assets/scss/App.scss"; | |
import { handleImageScale } from "./components/helpers/scaleHelper"; | |
import { modelScaleProps } from "./components/helpers/Interfaces"; | |
import { onnxMaskToImage } from "./components/helpers/maskUtils"; | |
import { modelData } from "./components/helpers/onnxModelAPI"; | |
import Stage from "./components/Stage"; | |
import AppContext from "./components/hooks/createContext"; | |
const ort = require("onnxruntime-web"); | |
/* @ts-ignore */ | |
import npyjs from "npyjs"; | |
/** URL path of the demo image that is loaded on mount. */
const IMAGE_PATH = "/assets/data/dogs.jpg";

/** URL path of the pre-computed Segment Anything image embedding (NumPy `.npy` file). */
const IMAGE_EMBEDDING = "/assets/data/dogs_embedding.npy";

/**
 * URL path of the quantized SAM ONNX model.
 * Note: despite the `_DIR` suffix, this names a single `.onnx` file, not a directory.
 */
const MODEL_DIR = "/model/sam_onnx_quantized_example.onnx";
/**
 * Decode a NumPy `.npy` file into an ONNX Runtime tensor.
 *
 * Kept at module scope because it uses no component state.
 *
 * @param tensorFile Path/URL of the `.npy` file, handed to npyjs's `load`.
 * @param dType ONNX Runtime element type passed to the `Tensor` constructor (e.g. "float32").
 * @returns A tensor built from the file's data and shape.
 */
const loadNpyTensor = async (
  tensorFile: string,
  dType: string
): Promise<Tensor> => {
  const npLoader = new npyjs();
  const npArray = await npLoader.load(tensorFile);
  return new ort.Tensor(dType, npArray.data, npArray.shape);
};

/**
 * Root component of the Segment Anything demo.
 *
 * On mount it creates the ONNX inference session, loads the demo image (recording
 * its SAM scaling info), and loads the pre-computed image embedding. Whenever the
 * clicks in AppContext change (or a required resource finishes loading), it runs the
 * ONNX model and publishes the resulting mask image back to the context.
 */
const App = () => {
  const {
    clicks: [clicks],
    image: [, setImage],
    maskImg: [, setMaskImg],
  } = useContext(AppContext)!;
  const [model, setModel] = useState<InferenceSession | null>(null); // ONNX model
  const [tensor, setTensor] = useState<Tensor | null>(null); // Image embedding tensor
  // The ONNX model expects the input to be rescaled to 1024.
  // The modelScale state variable keeps track of the scale values.
  const [modelScale, setModelScale] = useState<modelScaleProps | null>(null);

  /**
   * Load the image at `url`, record its dimensions and SAM scale factor, and
   * publish it to the context.
   *
   * Load failures happen asynchronously, so a surrounding try/catch cannot see
   * them; they are reported through `onerror` instead.
   */
  const loadImage = (url: URL) => {
    const img = new Image();
    img.onload = () => {
      const { height, width, samScale } = handleImageScale(img);
      setModelScale({
        height, // original image height
        width, // original image width
        samScale, // scaling factor for image which has been resized to longest side 1024
      });
      img.width = width;
      img.height = height;
      setImage(img);
    };
    img.onerror = (event) => {
      console.log(`Failed to load image: ${url.href}`, event);
    };
    img.src = url.href;
  };

  // On mount: initialize the ONNX model, load the image, and load the SAM
  // pre-computed image embedding. Each load reports its own failures.
  useEffect(() => {
    const initModel = async () => {
      try {
        setModel(await InferenceSession.create(MODEL_DIR));
      } catch (e) {
        console.log(e);
      }
    };
    void initModel();

    loadImage(new URL(IMAGE_PATH, location.origin));

    // Without .catch, a rejected embedding load would become an unhandled rejection.
    loadNpyTensor(IMAGE_EMBEDDING, "float32")
      .then((embedding) => setTensor(embedding))
      .catch((e) => console.log(e));
  }, []);

  /**
   * Run the SAM ONNX model on the current clicks and publish the predicted mask.
   * Returns early until the model, embedding, scale and clicks are all available.
   */
  const runONNX = async () => {
    try {
      if (
        model === null ||
        clicks === null ||
        tensor === null ||
        modelScale === null
      )
        return;
      // Prepare the model input in the correct format for SAM.
      // The modelData function is from onnxModelAPI.tsx.
      const feeds = modelData({ clicks, tensor, modelScale });
      if (feeds === undefined) return;
      // Run the SAM ONNX model with the feeds returned from modelData()
      const results = await model.run(feeds);
      const output = results[model.outputNames[0]];
      // The predicted mask returned from the ONNX model is an array which is
      // rendered as an HTML image using onnxMaskToImage() from maskUtils.tsx.
      setMaskImg(onnxMaskToImage(output.data, output.dims[2], output.dims[3]));
    } catch (e) {
      console.log(e);
    }
  };

  // Re-run whenever the clicks change, and also when the model, embedding or
  // scale finish loading, so clicks made before loading completed are not dropped.
  useEffect(() => {
    void runONNX();
  }, [clicks, model, tensor, modelScale]);

  return <Stage />;
};
export default App; | |