|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import { InferenceSession, Tensor } from "onnxruntime-web"; |
|
|
import React, { useContext, useEffect, useState } from "react"; |
|
|
import "./assets/scss/App.scss"; |
|
|
import { handleImageScale } from "./components/helpers/scaleHelper"; |
|
|
import { modelScaleProps } from "./components/helpers/Interfaces"; |
|
|
import { onnxMaskToImage } from "./components/helpers/maskUtils"; |
|
|
import { modelData } from "./components/helpers/onnxModelAPI"; |
|
|
import Stage from "./components/Stage"; |
|
|
import AppContext from "./components/hooks/createContext"; |
|
|
const ort = require("onnxruntime-web"); |
|
|
|
|
|
import npyjs from "npyjs"; |
|
|
|
|
|
|
|
|
// Demo image. App resolves this root-relative path against `location.origin`
// before loading it into an HTMLImageElement.
const IMAGE_PATH = "/assets/data/dogs.jpg";

// Precomputed embedding for IMAGE_PATH, stored as a NumPy `.npy` file.
// App reads it with npyjs and wraps it in an ONNX Runtime tensor of type "float32".
const IMAGE_EMBEDDING = "/assets/data/dogs_embedding.npy";

// Path to the ONNX model passed to `InferenceSession.create`.
// NOTE(review): despite the name, this is a single `.onnx` file, not a directory.
const MODEL_DIR = "/model/sam_onnx_quantized_example.onnx";
|
|
|
|
|
/**
 * Root component of the demo.
 *
 * On mount it:
 *  - creates the ONNX inference session from MODEL_DIR;
 *  - loads the image at IMAGE_PATH and publishes it (with its scale info) via AppContext;
 *  - loads the precomputed image embedding from IMAGE_EMBEDDING.
 *
 * Whenever the click prompts in AppContext change, and again once the model,
 * embedding and scale info have all become available, it runs the model and
 * publishes the rendered first output as the mask image via AppContext.
 */
const App = () => {
  const {
    clicks: [clicks],
    image: [, setImage],
    maskImg: [, setMaskImg],
  } = useContext(AppContext)!;
  const [model, setModel] = useState<InferenceSession | null>(null);
  const [tensor, setTensor] = useState<Tensor | null>(null);
  // `height`, `width` and `samScale` as returned by handleImageScale for the loaded image.
  const [modelScale, setModelScale] = useState<modelScaleProps | null>(null);

  /**
   * Loads the image at `url`. Once decoded, it stores the scale info returned by
   * handleImageScale, resizes the element to the scaled width/height and
   * publishes it via setImage.
   */
  const loadImage = (url: URL) => {
    const img = new Image();
    // Register handlers before assigning `src` so no load/error event can be missed.
    img.onload = () => {
      const { height, width, samScale } = handleImageScale(img);
      setModelScale({ height, width, samScale });
      img.width = width;
      img.height = height;
      setImage(img);
    };
    // Network/decode failures are reported asynchronously through this event;
    // they are never thrown synchronously, so a try/catch cannot observe them.
    img.onerror = () => {
      console.error(`Failed to load image: ${url.href}`);
    };
    img.src = url.href;
  };

  /**
   * Fetches a `.npy` file and wraps its data and shape in an ONNX Runtime tensor.
   *
   * @param tensorFile URL of the `.npy` file.
   * @param dType ONNX Runtime element type name passed to the Tensor constructor (e.g. "float32").
   * @returns The tensor; the promise rejects if fetching or parsing fails.
   */
  const loadNpyTensor = async (tensorFile: string, dType: string): Promise<Tensor> => {
    const npLoader = new npyjs();
    const npArray = await npLoader.load(tensorFile);
    return new ort.Tensor(dType, npArray.data, npArray.shape);
  };

  // One-time initialisation: inference session, source image and image embedding.
  useEffect(() => {
    const initModel = async () => {
      try {
        const session = await InferenceSession.create(MODEL_DIR);
        setModel(session);
      } catch (e) {
        console.error("Failed to create ONNX inference session:", e);
      }
    };
    void initModel();

    loadImage(new URL(IMAGE_PATH, location.origin));

    loadNpyTensor(IMAGE_EMBEDDING, "float32")
      .then((embedding) => setTensor(embedding))
      .catch((e: unknown) => {
        console.error("Failed to load image embedding:", e);
      });
  }, []);

  // Run the model whenever the prompts change. Also run it once the model,
  // embedding and scale become available, so that prompts issued while they were
  // still loading are not silently dropped.
  useEffect(() => {
    // Set by the cleanup when a newer run (or unmount) supersedes this one, so a
    // slow, outdated inference cannot overwrite the mask of a newer one.
    let superseded = false;

    const runONNX = async () => {
      if (model === null || clicks === null || tensor === null || modelScale === null) {
        return;
      }
      try {
        const feeds = modelData({ clicks, tensor, modelScale });
        if (feeds === undefined) return;

        const results = await model.run(feeds);
        if (superseded) return;

        const output = results[model.outputNames[0]];
        // NOTE(review): assumes a 4-D output whose dims[2]/dims[3] are the mask's
        // spatial size, in the order onnxMaskToImage expects. Confirm against the model.
        setMaskImg(onnxMaskToImage(output.data, output.dims[2], output.dims[3]));
      } catch (e) {
        console.error("ONNX inference failed:", e);
      }
    };
    void runONNX();

    return () => {
      superseded = true;
    };
    // setMaskImg is presumed to be a stable state setter from AppContext, so it is
    // intentionally not listed as a dependency.
  }, [clicks, model, tensor, modelScale]);

  return <Stage />;
};

export default App;
|
|
|