Spaces:
Running
Running
File size: 4,245 Bytes
07f408f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.
import { InferenceSession, Tensor } from "onnxruntime-web";
import React, { useContext, useEffect, useState } from "react";
import "./assets/scss/App.scss";
import { handleImageScale } from "./components/helpers/scaleHelper";
import { modelScaleProps } from "./components/helpers/Interfaces";
import { onnxMaskToImage } from "./components/helpers/maskUtils";
import { modelData } from "./components/helpers/onnxModelAPI";
import Stage from "./components/Stage";
import AppContext from "./components/hooks/createContext";
const ort = require("onnxruntime-web");
/* @ts-ignore */
import npyjs from "npyjs";
// Paths of the static assets that App loads once when it mounts.
// Image to segment; App resolves this against location.origin before loading it.
const IMAGE_PATH = "/assets/data/dogs.jpg";
// NumPy (.npy) file that App decodes with npyjs into a "float32" tensor
// (the pre-computed Segment Anything image embedding for the image above).
const IMAGE_EMBEDDING = "/assets/data/dogs_embedding.npy";
// Path of the ONNX model file handed to InferenceSession.create().
// NOTE: despite the name, this is a file path, not a directory.
const MODEL_DIR = "/model/sam_onnx_quantized_example.onnx";
/**
 * Root component of the demo.
 *
 * On mount it loads three assets in parallel: the ONNX model (MODEL_DIR),
 * the image (IMAGE_PATH) and the pre-computed image embedding
 * (IMAGE_EMBEDDING). Whenever the clicks stored in AppContext change, or one
 * of those assets finishes loading, it runs the model and publishes the
 * predicted mask through the context's `maskImg` setter.
 *
 * Failures while loading or running are logged to the console and otherwise
 * ignored, so the UI keeps rendering <Stage /> without a mask.
 */
const App = () => {
  const {
    clicks: [clicks],
    image: [, setImage],
    maskImg: [, setMaskImg],
  } = useContext(AppContext)!;

  const [model, setModel] = useState<InferenceSession | null>(null); // ONNX model
  const [tensor, setTensor] = useState<Tensor | null>(null); // Image embedding tensor

  // The ONNX model expects the input to be rescaled to 1024.
  // The modelScale state variable keeps track of the scale values.
  const [modelScale, setModelScale] = useState<modelScaleProps | null>(null);

  /**
   * Load the image at `url`, record its scale information and publish it
   * through the context.
   *
   * Image loading is event-driven, so errors are handled inside the event
   * handlers: a try/catch around the handler registration could never
   * observe a failed load or an exception thrown by the onload callback.
   */
  const loadImage = (url: URL) => {
    const img = new Image();
    img.onload = () => {
      try {
        const { height, width, samScale } = handleImageScale(img);
        setModelScale({
          height, // original image height
          width, // original image width
          samScale, // scaling factor for image which has been resized to longest side 1024
        });
        img.width = width;
        img.height = height;
        setImage(img);
      } catch (error) {
        console.log(error);
      }
    };
    img.onerror = () => {
      console.log(`Failed to load image: ${url.href}`);
    };
    // Assign src only after the handlers are attached so no event is missed.
    img.src = url.href;
  };

  /**
   * Decode a Numpy (.npy) file into an ONNX Runtime tensor.
   *
   * @param tensorFile URL of the .npy file to fetch.
   * @param dType ONNX Runtime element type of the resulting tensor.
   * @returns A tensor built from the file's data and shape. The returned
   *   promise rejects if the file cannot be loaded or decoded.
   */
  const loadNpyTensor = async (tensorFile: string, dType: string) => {
    const npLoader = new npyjs();
    const npArray = await npLoader.load(tensorFile);
    return new ort.Tensor(dType, npArray.data, npArray.shape);
  };

  /**
   * Run the SAM ONNX model for the current clicks and publish the mask.
   * Does nothing until the model, the clicks, the embedding tensor and the
   * scale information are all available.
   */
  const runONNX = async () => {
    try {
      if (
        model === null ||
        clicks === null ||
        tensor === null ||
        modelScale === null
      ) {
        return;
      }
      // Prepare the model input in the correct format for SAM.
      // The modelData function is from onnxModelAPI.tsx.
      const feeds = modelData({
        clicks,
        tensor,
        modelScale,
      });
      if (feeds === undefined) return;
      // Run the SAM ONNX model with the feeds returned from modelData()
      const results = await model.run(feeds);
      const output = results[model.outputNames[0]];
      // The predicted mask returned from the ONNX model is an array which is
      // rendered as an HTML image using onnxMaskToImage() from maskUtils.tsx.
      setMaskImg(onnxMaskToImage(output.data, output.dims[2], output.dims[3]));
    } catch (e) {
      console.log(e);
    }
  };

  // Initialize the ONNX model, load the image, and load the SAM
  // pre-computed image embedding. Runs once, on mount.
  useEffect(() => {
    // Initialize the ONNX model
    const initModel = async () => {
      try {
        const session = await InferenceSession.create(MODEL_DIR);
        setModel(session);
      } catch (e) {
        console.log(e);
      }
    };
    void initModel();

    // Load the image
    loadImage(new URL(IMAGE_PATH, location.origin));

    // Load the Segment Anything pre-computed embedding. The rejection is
    // handled here so a failed download does not surface as an unhandled
    // promise rejection.
    loadNpyTensor(IMAGE_EMBEDDING, "float32")
      .then((embedding) => setTensor(embedding))
      .catch((e) => console.log(e));
  }, []);

  // Run the ONNX model every time clicks has changed. model, tensor and
  // modelScale are dependencies as well, so clicks made before those assets
  // finished loading still produce a mask once everything is ready.
  useEffect(() => {
    void runONNX();
  }, [clicks, model, tensor, modelScale]);

  return <Stage />;
};
export default App;
|