bilca's picture
Upload 56 files
352fb85 verified
raw
history blame
22.2 kB
import SortWorker from "web-worker:../utils/SortWorker.ts";
import { ShaderProgram } from "./ShaderProgram";
import { ShaderPass } from "../passes/ShaderPass";
import { RenderData } from "../utils/RenderData";
import { Color32 } from "../../../math/Color32";
import { ObjectAddedEvent, ObjectChangedEvent, ObjectRemovedEvent } from "../../../events/Events";
import { Splat } from "../../../splats/Splat";
import { WebGLRenderer } from "../../WebGLRenderer";
import { Scene } from "../../../core/Scene"
/**
 * Vertex shader (GLSL ES 3.00) that expands one splat instance into a screen-space quad.
 *
 * Per instance it reads the splat centre and packed covariance/colour from `u_texture`,
 * applies the per-splat transform and colour transform looked up through the index
 * textures, projects the 3D covariance to 2D and emits the quad corner selected by
 * the `position` attribute along the resulting major/minor axes.
 *
 * Splats outside the (slightly enlarged) view frustum, and splats whose projected
 * covariance has a negative eigenvalue, are moved outside the clip volume.
 */
const vertexShaderSource = /* glsl */ `#version 300 es
precision highp float;
precision highp int;

uniform highp usampler2D u_texture;
uniform highp sampler2D u_transforms;
uniform highp usampler2D u_transformIndices;
uniform highp sampler2D u_colorTransforms;
uniform highp usampler2D u_colorTransformIndices;
uniform mat4 projection, view;
uniform vec2 focal;
uniform vec2 viewport;
uniform bool useDepthFade;
uniform float depthFade;

in vec2 position;
in int index;

out vec4 vColor;
out vec2 vPosition;
out float vSize;
out float vSelected;

void main () {
    // Each splat uses two adjacent texels in a row of 1024 splats: this is the first one (centre).
    uvec4 cen = texelFetch(u_texture, ivec2((uint(index) & 0x3ffu) << 1, uint(index) >> 10), 0);
    // The top byte of the fourth word carries the per-splat selection flag.
    float selected = float((cen.w >> 24) & 0xffu);

    uint transformIndex = texelFetch(u_transformIndices, ivec2(uint(index) & 0x3ffu, uint(index) >> 10), 0).x;
    mat4 transform = mat4(
        texelFetch(u_transforms, ivec2(0, transformIndex), 0),
        texelFetch(u_transforms, ivec2(1, transformIndex), 0),
        texelFetch(u_transforms, ivec2(2, transformIndex), 0),
        texelFetch(u_transforms, ivec2(3, transformIndex), 0)
    );

    // Fall back to the selection value stored in the fifth texel of the transform row.
    if (selected < 0.5) {
        selected = texelFetch(u_transforms, ivec2(4, transformIndex), 0).x;
    }

    mat4 viewTransform = view * transform;
    vec4 cam = viewTransform * vec4(uintBitsToFloat(cen.xyz), 1);
    vec4 pos2d = projection * cam;

    // Cull against the frustum, with a 20% margin in x/y.
    float clip = 1.2 * pos2d.w;
    if (pos2d.z < -pos2d.w || pos2d.z > pos2d.w || pos2d.x < -clip || pos2d.x > clip || pos2d.y < -clip || pos2d.y > clip) {
        gl_Position = vec4(0.0, 0.0, 2.0, 1.0);
        return;
    }

    // Second texel of the splat: three packed half-float pairs (covariance) and RGBA8 colour.
    uvec4 cov = texelFetch(u_texture, ivec2(((uint(index) & 0x3ffu) << 1) | 1u, uint(index) >> 10), 0);
    vec2 u1 = unpackHalf2x16(cov.x), u2 = unpackHalf2x16(cov.y), u3 = unpackHalf2x16(cov.z);
    mat3 Vrk = mat3(u1.x, u1.y, u2.x, u1.y, u2.y, u3.x, u2.x, u3.x, u3.y);

    mat3 J = mat3(
        focal.x / cam.z, 0., -(focal.x * cam.x) / (cam.z * cam.z),
        0., -focal.y / cam.z, (focal.y * cam.y) / (cam.z * cam.z),
        0., 0., 0.
    );

    mat3 T = transpose(mat3(viewTransform)) * J;
    mat3 cov2d = transpose(T) * Vrk * T;

    //ref: https://github.com/graphdeco-inria/diff-gaussian-rasterization/blob/main/cuda_rasterizer/forward.cu#L110-L111
    cov2d[0][0] += 0.3;
    cov2d[1][1] += 0.3;

    // Eigen-decomposition of the symmetric 2x2 covariance.
    float mid = (cov2d[0][0] + cov2d[1][1]) / 2.0;
    float radius = length(vec2((cov2d[0][0] - cov2d[1][1]) / 2.0, cov2d[0][1]));
    float lambda1 = mid + radius, lambda2 = mid - radius;

    if (lambda2 < 0.0) {
        // Degenerate covariance: move the vertex outside the clip volume rather than
        // returning with gl_Position unwritten, which would leave it undefined.
        gl_Position = vec4(0.0, 0.0, 2.0, 1.0);
        return;
    }

    vec2 diagonalVector = normalize(vec2(cov2d[0][1], lambda1 - cov2d[0][0]));
    vec2 majorAxis = min(sqrt(2.0 * lambda1), 1024.0) * diagonalVector;
    vec2 minorAxis = min(sqrt(2.0 * lambda2), 1024.0) * vec2(diagonalVector.y, -diagonalVector.x);

    uint colorTransformIndex = texelFetch(u_colorTransformIndices, ivec2(uint(index) & 0x3ffu, uint(index) >> 10), 0).x;
    mat4 colorTransform = mat4(
        texelFetch(u_colorTransforms, ivec2(0, colorTransformIndex), 0),
        texelFetch(u_colorTransforms, ivec2(1, colorTransformIndex), 0),
        texelFetch(u_colorTransforms, ivec2(2, colorTransformIndex), 0),
        texelFetch(u_colorTransforms, ivec2(3, colorTransformIndex), 0)
    );

    vec4 color = vec4((cov.w) & 0xffu, (cov.w >> 8) & 0xffu, (cov.w >> 16) & 0xffu, (cov.w >> 24) & 0xffu) / 255.0;
    vColor = colorTransform * color;
    vPosition = position;
    vSize = length(majorAxis);
    vSelected = selected;

    // Optional depth fade: scale the quad between 0 and 1 around the depthFade threshold.
    float scalingFactor = 1.0;

    if (useDepthFade) {
        float depthNorm = (pos2d.z / pos2d.w + 1.0) / 2.0;
        float near = 0.1; float far = 100.0;
        float normalizedDepth = (2.0 * near) / (far + near - depthNorm * (far - near));
        float start = max(normalizedDepth - 0.1, 0.0);
        float end = min(normalizedDepth + 0.1, 1.0);
        scalingFactor = clamp((depthFade - start) / (end - start), 0.0, 1.0);
    }

    vec2 vCenter = vec2(pos2d) / pos2d.w;
    gl_Position = vec4(
        vCenter
        + position.x * majorAxis * scalingFactor / viewport
        + position.y * minorAxis * scalingFactor / viewport, 0.0, 1.0);
}
`;
/**
 * Fragment shader (GLSL ES 3.00) that shades one splat quad.
 *
 * `vPosition` is the interpolated quad coordinate; fragments whose squared distance
 * from the quad centre exceeds 4 are discarded. Unselected splats (`vSelected < 0.5`)
 * output the colour weighted by `exp(-|vPosition|^2) * alpha`, with the same weight as
 * the output alpha. Selected splats use the same shading except in the outer ring below
 * `outlineThreshold` (derived from `outlineThickness / vSize`), where `outlineColor` is
 * written instead.
 */
const fragmentShaderSource = /* glsl */ `#version 300 es
precision highp float;
uniform float outlineThickness;
uniform vec4 outlineColor;
in vec4 vColor;
in vec2 vPosition;
in float vSize;
in float vSelected;
out vec4 fragColor;
void main () {
float A = -dot(vPosition, vPosition);
if (A < -4.0) discard;
if (vSelected < 0.5) {
float B = exp(A) * vColor.a;
fragColor = vec4(B * vColor.rgb, B);
return;
}
float outlineThreshold = -4.0 + (outlineThickness / vSize);
if (A < outlineThreshold) {
fragColor = outlineColor;
}
else {
float B = exp(A) * vColor.a;
fragColor = vec4(B * vColor.rgb, B);
}
}
`;
/**
 * Shader program that draws the scene's splats as instanced quads.
 *
 * Splat data, transforms and colour transforms held by {@link RenderData} are uploaded
 * as textures on units 0-4 whenever the render data reports a change. A `SortWorker`
 * receives the positions/transforms plus the camera's view-projection matrix and posts
 * back a `depthIndex` array, which is used as the per-instance `index` attribute.
 */
class RenderProgram extends ShaderProgram {
    private _outlineThickness: number = 10.0;
    private _outlineColor: Color32 = new Color32(255, 165, 0, 255);
    private _renderData: RenderData | null = null;
    private _depthIndex: Uint32Array = new Uint32Array();
    private _splatTexture: WebGLTexture | null = null;
    private _worker: Worker | null = null;

    protected _initialize: () => void;
    protected _resize: () => void;
    protected _render: () => void;
    protected _dispose: () => void;

    private _setOutlineThickness: (value: number) => void;
    private _setOutlineColor: (value: Color32) => void;

    /**
     * @param renderer Renderer that supplies the canvas and the WebGL context.
     * @param passes Shader passes forwarded to the base program.
     */
    constructor(renderer: WebGLRenderer, passes: ShaderPass[]) {
        super(renderer, passes);

        const canvas = renderer.canvas;
        const gl = renderer.gl;

        // GL handles shared by the lifecycle closures below; assigned in _initialize/_resize.
        let u_projection: WebGLUniformLocation;
        let u_viewport: WebGLUniformLocation;
        let u_focal: WebGLUniformLocation;
        let u_view: WebGLUniformLocation;
        let u_texture: WebGLUniformLocation;
        let u_transforms: WebGLUniformLocation;
        let u_transformIndices: WebGLUniformLocation;
        let u_colorTransforms: WebGLUniformLocation;
        let u_colorTransformIndices: WebGLUniformLocation;
        let u_outlineThickness: WebGLUniformLocation;
        let u_outlineColor: WebGLUniformLocation;

        let positionAttribute: number;
        let indexAttribute: number;

        let transformsTexture: WebGLTexture;
        let transformIndicesTexture: WebGLTexture;
        let colorTransformsTexture: WebGLTexture;
        let colorTransformIndicesTexture: WebGLTexture;

        let vertexBuffer: WebGLBuffer;
        let indexBuffer: WebGLBuffer;

        /**
         * Binds `texture` on texture unit `unit`, applies clamp-to-edge wrapping and
         * nearest filtering, and (re)uploads `data` as mip level 0.
         */
        const uploadTexture = (
            unit: number,
            texture: WebGLTexture | null,
            internalFormat: number,
            width: number,
            height: number,
            format: number,
            type: number,
            data: ArrayBufferView,
        ) => {
            gl.activeTexture(unit);
            gl.bindTexture(gl.TEXTURE_2D, texture);
            gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE);
            gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE);
            gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MIN_FILTER, gl.NEAREST);
            gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MAG_FILTER, gl.NEAREST);
            gl.texImage2D(gl.TEXTURE_2D, 0, internalFormat, width, height, 0, format, type, data);
        };

        // Pushes the canvas size to the camera and refreshes the projection/viewport uniforms.
        this._resize = () => {
            if (!this._camera) return;

            this._camera.data.setSize(canvas.width, canvas.height);
            this._camera.update();

            u_projection = gl.getUniformLocation(this.program, "projection") as WebGLUniformLocation;
            gl.uniformMatrix4fv(u_projection, false, this._camera.data.projectionMatrix.buffer);

            u_viewport = gl.getUniformLocation(this.program, "viewport") as WebGLUniformLocation;
            gl.uniform2fv(u_viewport, new Float32Array([canvas.width, canvas.height]));
        };

        // Starts a sort worker; every depthIndex it posts is stored and uploaded to the index buffer.
        const createWorker = () => {
            this._worker = new SortWorker();
            this._worker.onmessage = (e) => {
                if (e.data.depthIndex) {
                    const { depthIndex } = e.data;
                    this._depthIndex = depthIndex;
                    gl.bindBuffer(gl.ARRAY_BUFFER, indexBuffer);
                    gl.bufferData(gl.ARRAY_BUFFER, depthIndex, gl.STATIC_DRAW);
                }
            };
        };

        // Subscribes to scene events, creates render data, uniforms, textures, buffers and the worker.
        this._initialize = () => {
            if (!this._scene || !this._camera) {
                console.error("Cannot render without scene and camera");
                return;
            }

            this._resize();

            this._scene.addEventListener("objectAdded", handleObjectAdded);
            this._scene.addEventListener("objectRemoved", handleObjectRemoved);
            for (const object of this._scene.objects) {
                if (object instanceof Splat) {
                    object.addEventListener("objectChanged", handleObjectChanged);
                }
            }

            this._renderData = new RenderData(this._scene);

            u_focal = gl.getUniformLocation(this.program, "focal") as WebGLUniformLocation;
            gl.uniform2fv(u_focal, new Float32Array([this._camera.data.fx, this._camera.data.fy]));

            u_view = gl.getUniformLocation(this.program, "view") as WebGLUniformLocation;
            gl.uniformMatrix4fv(u_view, false, this._camera.data.viewMatrix.buffer);

            u_outlineThickness = gl.getUniformLocation(this.program, "outlineThickness") as WebGLUniformLocation;
            gl.uniform1f(u_outlineThickness, this.outlineThickness);

            u_outlineColor = gl.getUniformLocation(this.program, "outlineColor") as WebGLUniformLocation;
            gl.uniform4fv(u_outlineColor, new Float32Array(this.outlineColor.flatNorm()));

            // Each sampler uniform is bound to a fixed texture unit (0-4).
            this._splatTexture = gl.createTexture() as WebGLTexture;
            u_texture = gl.getUniformLocation(this.program, "u_texture") as WebGLUniformLocation;
            gl.uniform1i(u_texture, 0);

            transformsTexture = gl.createTexture() as WebGLTexture;
            u_transforms = gl.getUniformLocation(this.program, "u_transforms") as WebGLUniformLocation;
            gl.uniform1i(u_transforms, 1);

            transformIndicesTexture = gl.createTexture() as WebGLTexture;
            u_transformIndices = gl.getUniformLocation(this.program, "u_transformIndices") as WebGLUniformLocation;
            gl.uniform1i(u_transformIndices, 2);

            colorTransformsTexture = gl.createTexture() as WebGLTexture;
            u_colorTransforms = gl.getUniformLocation(this.program, "u_colorTransforms") as WebGLUniformLocation;
            gl.uniform1i(u_colorTransforms, 3);

            colorTransformIndicesTexture = gl.createTexture() as WebGLTexture;
            u_colorTransformIndices = gl.getUniformLocation(
                this.program,
                "u_colorTransformIndices",
            ) as WebGLUniformLocation;
            gl.uniform1i(u_colorTransformIndices, 4);

            // Quad corners in splat-local coordinates.
            vertexBuffer = gl.createBuffer() as WebGLBuffer;
            gl.bindBuffer(gl.ARRAY_BUFFER, vertexBuffer);
            gl.bufferData(gl.ARRAY_BUFFER, new Float32Array([-2, -2, 2, -2, 2, 2, -2, 2]), gl.STATIC_DRAW);

            positionAttribute = gl.getAttribLocation(this.program, "position");
            gl.enableVertexAttribArray(positionAttribute);
            gl.vertexAttribPointer(positionAttribute, 2, gl.FLOAT, false, 0, 0);

            indexBuffer = gl.createBuffer() as WebGLBuffer;
            indexAttribute = gl.getAttribLocation(this.program, "index");
            gl.enableVertexAttribArray(indexAttribute);
            gl.bindBuffer(gl.ARRAY_BUFFER, indexBuffer);

            createWorker();
        };

        // Tracks changes on newly added splats and rebuilds the render data.
        const handleObjectAdded = (event: Event) => {
            const e = event as ObjectAddedEvent;

            if (e.object instanceof Splat) {
                e.object.addEventListener("objectChanged", handleObjectChanged);
            }

            resetSplatData();
        };

        // Stops tracking removed splats and rebuilds the render data.
        const handleObjectRemoved = (event: Event) => {
            const e = event as ObjectRemovedEvent;

            if (e.object instanceof Splat) {
                e.object.removeEventListener("objectChanged", handleObjectChanged);
            }

            resetSplatData();
        };

        // Marks a changed splat dirty in the current render data.
        const handleObjectChanged = (event: Event) => {
            const e = event as ObjectChangedEvent;

            if (e.object instanceof Splat && this._renderData) {
                this._renderData.markDirty(e.object);
            }
        };

        // Replaces the render data with a fresh instance for the scene and restarts the sort worker.
        const resetSplatData = () => {
            this._renderData?.dispose();
            this._renderData = new RenderData(this._scene as Scene);

            this._worker?.terminate();
            createWorker();
        };

        // Uploads any changed data, forwards the camera matrix to the worker and draws all splats.
        this._render = () => {
            const renderData = this._renderData;
            if (!this._scene || !this._camera || !renderData) {
                console.error("Cannot render without scene and camera");
                return;
            }

            if (renderData.needsRebuild) {
                renderData.rebuild();
            }

            if (renderData.dataChanged || renderData.transformsChanged || renderData.colorTransformsChanged) {
                if (renderData.dataChanged) {
                    uploadTexture(
                        gl.TEXTURE0,
                        this._splatTexture,
                        gl.RGBA32UI,
                        renderData.width,
                        renderData.height,
                        gl.RGBA_INTEGER,
                        gl.UNSIGNED_INT,
                        renderData.data,
                    );
                }

                if (renderData.transformsChanged) {
                    uploadTexture(
                        gl.TEXTURE1,
                        transformsTexture,
                        gl.RGBA32F,
                        renderData.transformsWidth,
                        renderData.transformsHeight,
                        gl.RGBA,
                        gl.FLOAT,
                        renderData.transforms,
                    );
                    uploadTexture(
                        gl.TEXTURE2,
                        transformIndicesTexture,
                        gl.R32UI,
                        renderData.transformIndicesWidth,
                        renderData.transformIndicesHeight,
                        gl.RED_INTEGER,
                        gl.UNSIGNED_INT,
                        renderData.transformIndices,
                    );
                }

                if (renderData.colorTransformsChanged) {
                    uploadTexture(
                        gl.TEXTURE3,
                        colorTransformsTexture,
                        gl.RGBA32F,
                        renderData.colorTransformsWidth,
                        renderData.colorTransformsHeight,
                        gl.RGBA,
                        gl.FLOAT,
                        renderData.colorTransforms,
                    );
                    uploadTexture(
                        gl.TEXTURE4,
                        colorTransformIndicesTexture,
                        gl.R32UI,
                        renderData.colorTransformIndicesWidth,
                        renderData.colorTransformIndicesHeight,
                        gl.RED_INTEGER,
                        gl.UNSIGNED_INT,
                        renderData.colorTransformIndices,
                    );
                }

                // Copies are transferred to the worker so the render data keeps its own buffers.
                const detachedPositions = new Float32Array(renderData.positions.slice().buffer);
                const detachedTransforms = new Float32Array(renderData.transforms.slice().buffer);
                const detachedTransformIndices = new Uint32Array(renderData.transformIndices.slice().buffer);
                this._worker?.postMessage(
                    {
                        sortData: {
                            positions: detachedPositions,
                            transforms: detachedTransforms,
                            transformIndices: detachedTransformIndices,
                            vertexCount: renderData.vertexCount,
                        },
                    },
                    [detachedPositions.buffer, detachedTransforms.buffer, detachedTransformIndices.buffer],
                );

                renderData.dataChanged = false;
                renderData.transformsChanged = false;
                renderData.colorTransformsChanged = false;
            }

            this._camera.update();
            this._worker?.postMessage({ viewProj: this._camera.data.viewProj.buffer });

            gl.viewport(0, 0, canvas.width, canvas.height);
            gl.clearColor(0, 0, 0, 0);
            gl.clear(gl.COLOR_BUFFER_BIT);

            gl.disable(gl.DEPTH_TEST);
            gl.enable(gl.BLEND);
            gl.blendFuncSeparate(gl.ONE_MINUS_DST_ALPHA, gl.ONE, gl.ONE_MINUS_DST_ALPHA, gl.ONE);
            gl.blendEquationSeparate(gl.FUNC_ADD, gl.FUNC_ADD);

            gl.uniformMatrix4fv(u_projection, false, this._camera.data.projectionMatrix.buffer);
            gl.uniformMatrix4fv(u_view, false, this._camera.data.viewMatrix.buffer);

            gl.bindBuffer(gl.ARRAY_BUFFER, vertexBuffer);
            gl.vertexAttribPointer(positionAttribute, 2, gl.FLOAT, false, 0, 0);

            // One index per instance: the divisor advances the attribute once per splat.
            gl.bindBuffer(gl.ARRAY_BUFFER, indexBuffer);
            gl.bufferData(gl.ARRAY_BUFFER, this.depthIndex, gl.STATIC_DRAW);
            gl.vertexAttribIPointer(indexAttribute, 1, gl.INT, 0, 0);
            gl.vertexAttribDivisor(indexAttribute, 1);

            gl.drawArraysInstanced(gl.TRIANGLE_FAN, 0, 4, this.depthIndex.length);
        };

        // Detaches listeners, stops the worker and releases render data and GL resources.
        this._dispose = () => {
            if (!this._scene || !this._camera || !this.renderData) {
                console.error("Cannot dispose without scene and camera");
                return;
            }

            this._scene.removeEventListener("objectAdded", handleObjectAdded);
            this._scene.removeEventListener("objectRemoved", handleObjectRemoved);
            for (const object of this._scene.objects) {
                if (object instanceof Splat) {
                    object.removeEventListener("objectChanged", handleObjectChanged);
                }
            }

            this._worker?.terminate();
            this.renderData.dispose();

            // Release every texture created in _initialize, including the colour-transform pair.
            gl.deleteTexture(this.splatTexture);
            gl.deleteTexture(transformsTexture);
            gl.deleteTexture(transformIndicesTexture);
            gl.deleteTexture(colorTransformsTexture);
            gl.deleteTexture(colorTransformIndicesTexture);

            gl.deleteBuffer(indexBuffer);
            gl.deleteBuffer(vertexBuffer);
        };

        // Stores the thickness and, once initialized, pushes it to the shader uniform.
        this._setOutlineThickness = (value: number) => {
            this._outlineThickness = value;
            if (this._initialized) {
                gl.uniform1f(u_outlineThickness, value);
            }
        };

        // Stores the colour and, once initialized, pushes its normalized components to the shader uniform.
        this._setOutlineColor = (value: Color32) => {
            this._outlineColor = value;
            if (this._initialized) {
                gl.uniform4fv(u_outlineColor, new Float32Array(value.flatNorm()));
            }
        };
    }

    /** Current render data, or `null` before initialization. */
    get renderData() {
        return this._renderData;
    }

    /** Most recent splat draw order received from the sort worker. */
    get depthIndex() {
        return this._depthIndex;
    }

    /** Texture holding the packed splat data, or `null` before initialization. */
    get splatTexture() {
        return this._splatTexture;
    }

    /** Outline thickness passed to the fragment shader for selected splats. */
    get outlineThickness() {
        return this._outlineThickness;
    }

    set outlineThickness(value: number) {
        this._setOutlineThickness(value);
    }

    /** Outline colour passed to the fragment shader for selected splats. */
    get outlineColor() {
        return this._outlineColor;
    }

    set outlineColor(value: Color32) {
        this._setOutlineColor(value);
    }

    /** Active sort worker, or `null` before initialization. */
    get worker() {
        return this._worker;
    }

    protected _getVertexSource() {
        return vertexShaderSource;
    }

    protected _getFragmentSource() {
        return fragmentShaderSource;
    }
}
export { RenderProgram };