VisuTwin Canvas
C++ 3D Engine — Metal Backend
Loading...
Searching...
No Matches
metalDepthAwareBlurPass.cpp
Go to the documentation of this file.
1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2025-2026 Arnis Lektauers
3//
4// Depth-aware bilateral blur pass implementation.
5// Shader ported from upstream scene/shader-lib/glsl/chunks/render-pass/frag/depthAwareBlur.js
6//
8
9#include "metalComposePass.h"
10#include "metalGraphicsDevice.h"
11#include "metalRenderPipeline.h"
12#include "metalTexture.h"
13#include "metalVertexBuffer.h"
22#include "spdlog/spdlog.h"
23
24namespace visutwin::canvas
25{
namespace
{
    // Depth-aware (bilateral) blur, ported from the upstream GLSL chunk to Metal.
    //
    // Each output texel is a Gaussian-weighted average of its neighbours along one axis; every
    // neighbour's Gaussian weight is additionally scaled by a depth-similarity term so that samples
    // across a depth discontinuity contribute little, which avoids halo artifacts around edges.
    // Only the .r channel of the source is blurred (the result is written to .r).
    //
    // The two sources below are byte-identical except that BLUR_SOURCE_HORIZONTAL starts with
    // `#define HORIZONTAL 1`, which makes blurFragment step along X; without it, it steps along Y.
    // They are kept as two complete literals so either can be handed to createShader() as-is.
    //
    // The Metal BlurUniforms layout must stay in sync with the BlurUniforms struct filled in by
    // MetalDepthAwareBlurPass::execute() and bound at fragment buffer index 5.
    //
    // NOTE(review): depth is sampled through the same sampler as the source texture; if that sampler
    // filters linearly, depths are blended across edges before the bilateral test — confirm intended.
    constexpr const char* BLUR_SOURCE_HORIZONTAL = R"(
#include <metal_stdlib>
using namespace metal;

#define HORIZONTAL 1

struct ComposeVertexIn {
    float3 position [[attribute(0)]];
    float3 normal [[attribute(1)]];
    float2 uv0 [[attribute(2)]];
    float4 tangent [[attribute(3)]];
    float2 uv1 [[attribute(4)]];
};

struct BlurVarying {
    float4 position [[position]];
    float2 uv;
};

// Must match the BlurUniforms struct written by MetalDepthAwareBlurPass::execute (buffer 5).
struct BlurUniforms {
    float2 sourceInvResolution;
    int filterSize;
    float cameraNear;
    float cameraFar;
};

vertex BlurVarying blurVertex(ComposeVertexIn in [[stage_in]])
{
    BlurVarying out;
    out.position = float4(in.position, 1.0);
    out.uv = in.uv0;
    return out;
}

// Maps a [0, 1] depth-buffer value to linear distance, assuming 0 = near plane and 1 = far plane.
static inline float getLinearDepth(float rawDepth, float cameraNear, float cameraFar)
{
    return (cameraNear * cameraFar) / (cameraFar - rawDepth * (cameraFar - cameraNear));
}

// 1 for equal depths, falling quadratically to 0 once the linear depths differ by 1 unit or more.
static inline float bilateralWeight(float depth, float sampleDepth)
{
    float diff = (sampleDepth - depth);
    return max(0.0, 1.0 - diff * diff);
}

// Accumulates one neighbour sample, weighted by its spatial weight times its depth similarity.
static inline void tap(thread float& sum, thread float& totalWeight, float weight, float depth,
                       float2 position, texture2d<float> sourceTexture, depth2d<float> depthTexture,
                       sampler linearSampler, float cameraNear, float cameraFar)
{
    float color = sourceTexture.sample(linearSampler, position).r;
    float textureDepth = getLinearDepth(depthTexture.sample(linearSampler, position), cameraNear, cameraFar);

    float bilateral = bilateralWeight(depth, textureDepth);
    bilateral *= weight;
    sum += color * bilateral;
    totalWeight += bilateral;
}

fragment float4 blurFragment(
    BlurVarying in [[stage_in]],
    texture2d<float> sourceTexture [[texture(0)]],
    depth2d<float> depthTexture [[texture(1)]],
    sampler linearSampler [[sampler(0)]],
    constant BlurUniforms& uniforms [[buffer(5)]])
{
    const float2 uv = clamp(in.uv, float2(0.0), float2(1.0));

    // The center texel is accumulated once, up front, with weight 1. Its depth is the reference
    // depth, so its bilateral term would always be 1 anyway.
    float depth = getLinearDepth(depthTexture.sample(linearSampler, uv), uniforms.cameraNear, uniforms.cameraFar);
    float totalWeight = 1.0;
    float color = sourceTexture.sample(linearSampler, uv).r;
    float sum = color * totalWeight;

    // Gaussian sigma: filterSize / 3 gives ~99.7% of the bell within the kernel
    float sigma = max(float(uniforms.filterSize) / 3.0, 1.0);
    float invSigma2 = 1.0 / (2.0 * sigma * sigma);

    for (int i = -uniforms.filterSize; i <= uniforms.filterSize; i++) {
        // Skip the center: it was already added above; tapping it again would double its weight.
        if (i == 0) {
            continue;
        }

        float weight = exp(-float(i * i) * invSigma2);

#ifdef HORIZONTAL
        float2 offset = float2(float(i), 0.0) * uniforms.sourceInvResolution;
#else
        float2 offset = float2(0.0, float(i)) * uniforms.sourceInvResolution;
#endif

        tap(sum, totalWeight, weight, depth, uv + offset, sourceTexture, depthTexture, linearSampler,
            uniforms.cameraNear, uniforms.cameraFar);
    }

    float ao = sum / totalWeight;
    return float4(ao, 0.0, 0.0, 1.0);
}
)";

    constexpr const char* BLUR_SOURCE_VERTICAL = R"(
#include <metal_stdlib>
using namespace metal;

struct ComposeVertexIn {
    float3 position [[attribute(0)]];
    float3 normal [[attribute(1)]];
    float2 uv0 [[attribute(2)]];
    float4 tangent [[attribute(3)]];
    float2 uv1 [[attribute(4)]];
};

struct BlurVarying {
    float4 position [[position]];
    float2 uv;
};

// Must match the BlurUniforms struct written by MetalDepthAwareBlurPass::execute (buffer 5).
struct BlurUniforms {
    float2 sourceInvResolution;
    int filterSize;
    float cameraNear;
    float cameraFar;
};

vertex BlurVarying blurVertex(ComposeVertexIn in [[stage_in]])
{
    BlurVarying out;
    out.position = float4(in.position, 1.0);
    out.uv = in.uv0;
    return out;
}

// Maps a [0, 1] depth-buffer value to linear distance, assuming 0 = near plane and 1 = far plane.
static inline float getLinearDepth(float rawDepth, float cameraNear, float cameraFar)
{
    return (cameraNear * cameraFar) / (cameraFar - rawDepth * (cameraFar - cameraNear));
}

// 1 for equal depths, falling quadratically to 0 once the linear depths differ by 1 unit or more.
static inline float bilateralWeight(float depth, float sampleDepth)
{
    float diff = (sampleDepth - depth);
    return max(0.0, 1.0 - diff * diff);
}

// Accumulates one neighbour sample, weighted by its spatial weight times its depth similarity.
static inline void tap(thread float& sum, thread float& totalWeight, float weight, float depth,
                       float2 position, texture2d<float> sourceTexture, depth2d<float> depthTexture,
                       sampler linearSampler, float cameraNear, float cameraFar)
{
    float color = sourceTexture.sample(linearSampler, position).r;
    float textureDepth = getLinearDepth(depthTexture.sample(linearSampler, position), cameraNear, cameraFar);

    float bilateral = bilateralWeight(depth, textureDepth);
    bilateral *= weight;
    sum += color * bilateral;
    totalWeight += bilateral;
}

fragment float4 blurFragment(
    BlurVarying in [[stage_in]],
    texture2d<float> sourceTexture [[texture(0)]],
    depth2d<float> depthTexture [[texture(1)]],
    sampler linearSampler [[sampler(0)]],
    constant BlurUniforms& uniforms [[buffer(5)]])
{
    const float2 uv = clamp(in.uv, float2(0.0), float2(1.0));

    // The center texel is accumulated once, up front, with weight 1. Its depth is the reference
    // depth, so its bilateral term would always be 1 anyway.
    float depth = getLinearDepth(depthTexture.sample(linearSampler, uv), uniforms.cameraNear, uniforms.cameraFar);
    float totalWeight = 1.0;
    float color = sourceTexture.sample(linearSampler, uv).r;
    float sum = color * totalWeight;

    // Gaussian sigma: filterSize / 3 gives ~99.7% of the bell within the kernel
    float sigma = max(float(uniforms.filterSize) / 3.0, 1.0);
    float invSigma2 = 1.0 / (2.0 * sigma * sigma);

    for (int i = -uniforms.filterSize; i <= uniforms.filterSize; i++) {
        // Skip the center: it was already added above; tapping it again would double its weight.
        if (i == 0) {
            continue;
        }

        float weight = exp(-float(i * i) * invSigma2);

#ifdef HORIZONTAL
        float2 offset = float2(float(i), 0.0) * uniforms.sourceInvResolution;
#else
        float2 offset = float2(0.0, float(i)) * uniforms.sourceInvResolution;
#endif

        tap(sum, totalWeight, weight, depth, uv + offset, sourceTexture, depthTexture, linearSampler,
            uniforms.cameraNear, uniforms.cameraFar);
    }

    float ao = sum / totalWeight;
    return float4(ao, 0.0, 0.0, 1.0);
}
)";
}
215
217 : _device(device), _composePass(composePass), _horizontal(horizontal)
218 {
219 }
220
222 {
223 if (_depthStencilState) {
224 _depthStencilState->release();
225 _depthStencilState = nullptr;
226 }
227 }
228
229 void MetalDepthAwareBlurPass::ensureResources()
230 {
231 // Ensure the compose pass's shared vertex buffer/format are created first
232 _composePass->ensureResources();
233
234 if (_shader && _composePass->vertexBuffer() && _composePass->vertexFormat() &&
235 _blendState && _depthState && _depthStencilState) {
236 return;
237 }
238
239 if (!_shader) {
240 ShaderDefinition definition;
241 definition.name = _horizontal ? "DepthAwareBlurHorizontalPass" : "DepthAwareBlurVerticalPass";
242 definition.vshader = "blurVertex";
243 definition.fshader = "blurFragment";
244 const char* source = _horizontal ? BLUR_SOURCE_HORIZONTAL : BLUR_SOURCE_VERTICAL;
245 _shader = createShader(_device, definition, source);
246 }
247
248 if (!_blendState) {
249 _blendState = std::make_shared<BlendState>();
250 }
251 if (!_depthState) {
252 _depthState = std::make_shared<DepthState>();
253 }
254 if (!_depthStencilState && _device->raw()) {
255 auto* depthDesc = MTL::DepthStencilDescriptor::alloc()->init();
256 depthDesc->setDepthCompareFunction(MTL::CompareFunctionAlways);
257 depthDesc->setDepthWriteEnabled(false);
258 _depthStencilState = _device->raw()->newDepthStencilState(depthDesc);
259 depthDesc->release();
260 }
261 }
262
263 void MetalDepthAwareBlurPass::execute(MTL::RenderCommandEncoder* encoder,
264 const DepthAwareBlurPassParams& params,
265 MetalRenderPipeline* pipeline, const std::shared_ptr<RenderTarget>& renderTarget,
266 const std::vector<std::shared_ptr<MetalBindGroupFormat>>& bindGroupFormats,
267 MTL::SamplerState* defaultSampler, MTL::DepthStencilState* defaultDepthStencilState)
268 {
269 if (!encoder || !params.sourceTexture || !params.depthTexture) {
270 return;
271 }
272
273 ensureResources();
274 if (!_shader || !_composePass->vertexBuffer() || !_composePass->vertexFormat() || !_blendState || !_depthState) {
275 spdlog::warn("[executeDepthAwareBlurPass] missing blur resources");
276 return;
277 }
278
279 Primitive primitive;
280 primitive.type = PRIMITIVE_TRIANGLES;
281 primitive.base = 0;
282 primitive.count = 3;
283 primitive.indexed = false;
284
285 auto pipelineState = pipeline->get(primitive, _composePass->vertexFormat(), nullptr, -1, _shader, renderTarget,
286 bindGroupFormats, _blendState, _depthState, CullMode::CULLFACE_NONE, false, nullptr, nullptr);
287 if (!pipelineState) {
288 spdlog::warn("[executeDepthAwareBlurPass] failed to get pipeline state");
289 return;
290 }
291
292 auto* vb = dynamic_cast<MetalVertexBuffer*>(_composePass->vertexBuffer().get());
293 if (!vb || !vb->raw()) {
294 spdlog::warn("[executeDepthAwareBlurPass] missing vertex buffer");
295 return;
296 }
297
298 encoder->setRenderPipelineState(pipelineState);
299 encoder->setCullMode(MTL::CullModeNone);
300 encoder->setDepthStencilState(_depthStencilState ? _depthStencilState : defaultDepthStencilState);
301 encoder->setVertexBuffer(vb->raw(), 0, 0);
302
303 auto* sourceHw = dynamic_cast<gpu::MetalTexture*>(params.sourceTexture->impl());
304 auto* depthHw = dynamic_cast<gpu::MetalTexture*>(params.depthTexture->impl());
305
306 encoder->setFragmentTexture(sourceHw ? sourceHw->raw() : nullptr, 0);
307 encoder->setFragmentTexture(depthHw ? depthHw->raw() : nullptr, 1);
308 if (defaultSampler) {
309 encoder->setFragmentSamplerState(defaultSampler, 0);
310 }
311
312 struct alignas(16) BlurUniforms
313 {
314 float sourceInvResolution[2];
315 int32_t filterSize;
316 float cameraNear;
317 float cameraFar;
318 } uniforms{};
319
320 uniforms.sourceInvResolution[0] = params.sourceInvResolutionX;
321 uniforms.sourceInvResolution[1] = params.sourceInvResolutionY;
322 uniforms.filterSize = params.filterSize;
323 uniforms.cameraNear = params.cameraNear;
324 uniforms.cameraFar = params.cameraFar;
325 encoder->setFragmentBytes(&uniforms, sizeof(BlurUniforms), 5);
326
327 encoder->drawPrimitives(MTL::PrimitiveTypeTriangle, static_cast<NS::UInteger>(0),
328 static_cast<NS::UInteger>(3));
329 _device->recordDrawCall();
330 }
331}
std::shared_ptr< VertexFormat > vertexFormat() const
Shared vertex format (full-screen triangle, 14 floats per vertex).
std::shared_ptr< VertexBuffer > vertexBuffer() const
Shared vertex buffer (3-vertex full-screen triangle).
MetalDepthAwareBlurPass(MetalGraphicsDevice *device, MetalComposePass *composePass, bool horizontal)
void execute(MTL::RenderCommandEncoder *encoder, const DepthAwareBlurPassParams &params, MetalRenderPipeline *pipeline, const std::shared_ptr< RenderTarget > &renderTarget, const std::vector< std::shared_ptr< MetalBindGroupFormat > > &bindGroupFormats, MTL::SamplerState *defaultSampler, MTL::DepthStencilState *defaultDepthStencilState)
Execute the blur pass on the active render command encoder.
MTL::RenderPipelineState * get(const Primitive &primitive, const std::shared_ptr< VertexFormat > &vertexFormat0, const std::shared_ptr< VertexFormat > &vertexFormat1, int ibFormat, const std::shared_ptr< Shader > &shader, const std::shared_ptr< RenderTarget > &renderTarget, const std::vector< std::shared_ptr< MetalBindGroupFormat > > &bindGroupFormats, const std::shared_ptr< BlendState > &blendState, const std::shared_ptr< DepthState > &depthState, CullMode cullMode, bool stencilEnabled, const std::shared_ptr< StencilParameters > &stencilFront, const std::shared_ptr< StencilParameters > &stencilBack, const std::shared_ptr< VertexFormat > &instancingFormat=nullptr)
gpu::HardwareTexture * impl() const
Definition texture.h:101
std::shared_ptr< Shader > createShader(GraphicsDevice *graphicsDevice, const ShaderDefinition &definition, const std::string &sourceCode)
Definition shader.cpp:39
@ PRIMITIVE_TRIANGLES
Definition mesh.h:23
Describes how vertex and index data should be interpreted for a draw call.
Definition mesh.h:33
PrimitiveType type
Definition mesh.h:34