VisuTwin Canvas
C++ 3D Engine — Metal Backend
Loading...
Searching...
No Matches
metalComposePass.cpp
Go to the documentation of this file.
1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2025-2026 Arnis Lektauers
3//
4// Compose post-processing pass implementation.
5// Extracted from MetalGraphicsDevice.
6//
7#include "metalComposePass.h"
8
9#include <cstring>
10#include "metalGraphicsDevice.h"
11#include "metalRenderPipeline.h"
12#include "metalTexture.h"
13#include "metalVertexBuffer.h"
21#include "spdlog/spdlog.h"
22
23namespace visutwin::canvas
24{
namespace
{
    // Metal Shading Language source for the full-screen compose pass.
    //   vertex   composeVertex   - passes clip-space position and uv0 through.
    //   fragment composeFragment - CAS -> SSAO -> DOF -> Bloom -> tone mapping -> vignette -> gamma.
    // The MSL ComposeUniforms struct must stay field-for-field compatible with the host-side
    // ComposeUniforms struct that MetalComposePass::execute() uploads to fragment buffer slot 5.
    constexpr const char* COMPOSE_SOURCE = R"(
#include <metal_stdlib>
using namespace metal;

struct ComposeVertexIn {
    float3 position [[attribute(0)]];
    float3 normal   [[attribute(1)]];
    float2 uv0      [[attribute(2)]];
    float4 tangent  [[attribute(3)]];
    float2 uv1      [[attribute(4)]];
};

struct ComposeVarying {
    float4 position [[position]];
    float2 uv;
};

// Must match the host-side ComposeUniforms struct field for field.
struct ComposeUniforms {
    uint dofEnabled;
    uint taaEnabled;
    uint ssaoEnabled;
    uint bloomEnabled;
    uint blurTextureUpscale;
    float bloomIntensity;
    float dofIntensity;
    float sharpness;          // user-facing CAS amount in [0, 1]; 0 disables CAS
    uint tonemapMode;
    float exposure;
    float2 sceneTextureInvRes;
    // Single-pass DOF parameters
    float dofFocusDistance;
    float dofFocusRange;
    float dofBlurRadius;
    float dofCameraNear;
    float dofCameraFar;
    float _pad0;              // explicit padding, mirrored in the host struct
    // Vignette (color passed as three scalars so no float3/float4 alignment rules apply)
    uint vignetteEnabled;
    float vignetteInner;
    float vignetteOuter;
    float vignetteCurvature;
    float vignetteIntensity;
    float vignetteColorR;
    float vignetteColorG;
    float vignetteColorB;
};

float3 toneMapLinear(float3 color, float exposure) {
    return color * exposure;
}

float3 toneMapAces(float3 color, float exposure) {
    const float tA = 2.51;
    const float tB = 0.03;
    const float tC = 2.43;
    const float tD = 0.59;
    const float tE = 0.14;
    float3 x = color * exposure;
    return (x * (tA * x + tB)) / (x * (tC * x + tD) + tE);
}

// https://modelviewer.dev/examples/tone-mapping
float3 toneMapNeutral(float3 color, float exposure) {
    color *= exposure;

    float startCompression = 0.8 - 0.04;
    float desaturation = 0.15;

    float x = min(color.r, min(color.g, color.b));
    float offset = x < 0.08 ? x - 6.25 * x * x : 0.04;
    color -= offset;

    float peak = max(color.r, max(color.g, color.b));
    if (peak < startCompression) return color;

    float d = 1.0 - startCompression;
    float newPeak = 1.0 - d * d / (peak + d - startCompression);
    color *= newPeak / peak;

    float g = 1.0 - 1.0 / (desaturation * (peak - newPeak) + 1.0);
    return mix(color, float3(newPeak), g);
}

// Reversible HDR <-> SDR mapping used so CAS operates on bounded values.
float maxComp(float x, float y, float z) { return max(x, max(y, z)); }
float3 toSDR(float3 c) { return c / (1.0 + maxComp(c.r, c.g, c.b)); }
float3 toHDR(float3 c) { return c / (1.0 - maxComp(c.r, c.g, c.b)); }

// Contrast Adaptive Sharpening on a 5-tap cross.
// 'sharpness' is the [0, 1] user amount. As in AMD FidelityFX CAS it is mapped to a
// negative neighbour weight -1/lerp(8, 5, sharpness), i.e. [-0.125, -0.2]; a positive
// weight would turn this filter into a blur.
float3 applyCas(float3 color, float2 uv, float sharpness,
                texture2d<float> sceneTexture, sampler s, float2 invRes) {
    float3 a = toSDR(sceneTexture.sample(s, uv + float2(0.0, -invRes.y)).rgb);
    float3 b = toSDR(sceneTexture.sample(s, uv + float2(-invRes.x, 0.0)).rgb);
    float3 c = toSDR(color);
    float3 d = toSDR(sceneTexture.sample(s, uv + float2(invRes.x, 0.0)).rgb);
    float3 e = toSDR(sceneTexture.sample(s, uv + float2(0.0, invRes.y)).rgb);

    float min_g = min(a.g, min(b.g, min(c.g, min(d.g, e.g))));
    float max_g = max(a.g, max(b.g, max(c.g, max(d.g, e.g))));
    // Guard the ratio against 0/0 (NaN) in black neighbourhoods and negative inputs.
    float sharpening_amount = sqrt(saturate(min(1.0 - max_g, min_g) / max(max_g, 1e-5)));
    float w = sharpening_amount * mix(-0.125, -0.2, saturate(sharpness));
    float3 res = (w * (a + b + d + e) + c) / (4.0 * w + 1.0);
    // Negative lobes can overshoot; stay inside the invertible SDR range [0, 1).
    return toHDR(clamp(res, float3(0.0), float3(0.9999)));
}

vertex ComposeVarying composeVertex(ComposeVertexIn in [[stage_in]])
{
    ComposeVarying out;
    out.position = float4(in.position, 1.0);
    out.uv = in.uv0;
    return out;
}

// Vignette: darken edges with configurable curvature and inner/outer radii
float3 applyVignette(float3 color, float2 uv, float inner, float outer,
                     float curvature, float intensity, float3 vigColor) {
    float2 curve = pow(abs(uv * 2.0 - 1.0), float2(1.0 / curvature));
    float edge = pow(length(curve), curvature);
    float vignette = 1.0 - intensity * smoothstep(inner, outer, edge);
    return mix(vigColor, color, vignette);
}

// Converts a [0, 1] depth-buffer value to view distance (0 -> near, 1 -> far).
float linearizeDepth(float rawDepth, float cameraNear, float cameraFar) {
    return (cameraNear * cameraFar) / (cameraFar - rawDepth * (cameraFar - cameraNear));
}

// Far-field circle of confusion in [0, 1]; 0 means in focus.
float farCoc(float linearDepth, float farStart, float invRange) {
    return clamp((linearDepth - farStart) * invRange, 0.0, 1.0);
}

// Single-pass DOF using depth buffer
float3 applyDofSinglePass(float3 sharpColor, float2 uv, float2 invRes,
                          texture2d<float> sceneTexture, depth2d<float> depthTexture, sampler s,
                          float focusDistance, float focusRange, float blurRadius,
                          float cameraNear, float cameraFar)
{
    // PlayCanvas-style CoC: far range starts at focusDistance + focusRange/2
    float farStart = focusDistance + focusRange * 0.5;
    float invRange = 1.0 / max(focusRange, 0.001);
    float cocCenter = farCoc(linearizeDepth(depthTexture.sample(s, uv), cameraNear, cameraFar),
                             farStart, invRange);

    if (cocCenter < 0.005) return sharpColor; // early out for in-focus pixels

    // Disc blur with 12 taps (Poisson-like distribution)
    const float2 offsets[12] = {
        float2(-0.326, -0.406), float2(-0.840, -0.074), float2(-0.696, 0.457),
        float2(-0.203, 0.621), float2( 0.962, -0.195), float2( 0.473, -0.480),
        float2( 0.519, 0.767), float2( 0.185, -0.893), float2( 0.507, 0.064),
        float2(-0.321, -0.882), float2(-0.860, 0.370), float2( 0.871, 0.414)
    };

    float2 tapScale = cocCenter * blurRadius * invRes;
    float3 sum = float3(0.0);
    float totalWeight = 0.0;

    for (int i = 0; i < 12; i++) {
        float2 sampleUV = clamp(uv + offsets[i] * tapScale, float2(0.0), float2(1.0));
        // Weight each tap by its own CoC so in-focus (sharp) pixels do not leak into the blur.
        float w = farCoc(linearizeDepth(depthTexture.sample(s, sampleUV), cameraNear, cameraFar),
                         farStart, invRange);
        sum += sceneTexture.sample(s, sampleUV).rgb * w;
        totalWeight += w;
    }

    float3 blurColor = (totalWeight > 0.0) ? sum / totalWeight : sharpColor;
    return mix(sharpColor, blurColor, cocCenter);
}

// Compose pass order: CAS -> SSAO -> DOF -> Bloom -> ToneMap -> Vignette -> Gamma
fragment float4 composeFragment(
    ComposeVarying in [[stage_in]],
    texture2d<float> sceneTexture [[texture(0)]],
    texture2d<float> bloomTexture [[texture(1)]],
    texture2d<float> cocTexture [[texture(2)]],
    texture2d<float> blurTexture [[texture(3)]],
    texture2d<float> ssaoTexture [[texture(4)]],
    depth2d<float> depthTexture [[texture(5)]],
    sampler linearSampler [[sampler(0)]],
    constant ComposeUniforms& uniforms [[buffer(5)]])
{
    const float2 uv = clamp(in.uv, float2(0.0), float2(1.0));
    float3 result = sceneTexture.sample(linearSampler, uv).rgb;

    // 1. CAS (Contrast Adaptive Sharpening)
    if (uniforms.sharpness > 0.0) {
        result = applyCas(result, uv, uniforms.sharpness, sceneTexture, linearSampler, uniforms.sceneTextureInvRes);
    }

    // 2. SSAO
    if (uniforms.ssaoEnabled != 0u && ssaoTexture.get_width() > 0) {
        const float ssao = clamp(ssaoTexture.sample(linearSampler, uv).r, 0.0, 1.0);
        result *= ssao;
    }

    // 3. DOF (single-pass from depth buffer; skipped when no depth texture is bound)
    if (uniforms.dofEnabled != 0u && depthTexture.get_width() > 0) {
        result = applyDofSinglePass(result, uv, uniforms.sceneTextureInvRes,
                                    sceneTexture, depthTexture, linearSampler,
                                    uniforms.dofFocusDistance, uniforms.dofFocusRange, uniforms.dofBlurRadius,
                                    uniforms.dofCameraNear, uniforms.dofCameraFar);
    }
    // Legacy multi-pass DOF (kept as dead code for future use):
    // if (uniforms.dofEnabled != 0u && cocTexture.get_width() > 0 && blurTexture.get_width() > 0) {
    //     const float2 coc = cocTexture.sample(linearSampler, uv).rg;
    //     const float cocAmount = clamp(max(coc.r, coc.g), 0.0, 1.0);
    //     const float3 blurColor = blurTexture.sample(linearSampler, uv).rgb;
    //     result = mix(result, blurColor, cocAmount * clamp(uniforms.dofIntensity, 0.0, 1.0));
    // }

    // 4. Bloom
    if (uniforms.bloomEnabled != 0u && bloomTexture.get_width() > 0) {
        const float3 bloomColor = bloomTexture.sample(linearSampler, uv).rgb;
        result += bloomColor * max(uniforms.bloomIntensity, 0.0);
    }

    // 5. Tonemapping dispatch (any mode without a dedicated branch falls back to linear)
    result = max(result, float3(0.0));
    if (uniforms.tonemapMode == 3u) {            // TONEMAP_ACES
        result = toneMapAces(result, uniforms.exposure);
    } else if (uniforms.tonemapMode == 5u) {     // TONEMAP_NEUTRAL
        result = toneMapNeutral(result, uniforms.exposure);
    } else if (uniforms.tonemapMode == 6u) {     // TONEMAP_NONE
        // no-op
    } else {                                     // TONEMAP_LINEAR (default)
        result = toneMapLinear(result, uniforms.exposure);
    }

    // 6. Vignette (applied in tonemapped linear space, before gamma)
    if (uniforms.vignetteEnabled != 0u) {
        float3 vigColor = float3(uniforms.vignetteColorR, uniforms.vignetteColorG, uniforms.vignetteColorB);
        result = applyVignette(result, uv, uniforms.vignetteInner, uniforms.vignetteOuter,
                               uniforms.vignetteCurvature, uniforms.vignetteIntensity,
                               vigColor);
    }

    // 7. Gamma correction (gammaCorrectOutput)
    // The back buffer is BGRA8Unorm (not sRGB), so we must apply gamma in the shader.
    result = pow(max(result, float3(0.0)) + 0.0000001, float3(1.0 / 2.2));

    return float4(result, 1.0);
}
)";
}
269
// Stores the graphics device pointer only; all GPU resources for the pass are created on demand, not here.
271 : _device(device)
272 {
273 }
274
// Releases the depth-stencil state obtained from newDepthStencilState(); metal-cpp objects
// returned by new*/alloc methods are owned by the caller and must be released explicitly.
276 {
277 if (_depthStencilState) {
278 _depthStencilState->release();
279 _depthStencilState = nullptr;
280 }
281 }
282
// Creates, on first use, the GPU objects execute() relies on: the shader built from
// COMPOSE_SOURCE, the full-screen vertex buffer and its vertex format, the blend and depth
// states, and the Metal depth-stencil state. Each resource is created only while it is still
// missing, so repeated calls are cheap and a later call retries anything that was not created.
284 {
285 if (_shader && _vertexBuffer && _vertexFormat && _blendState &&
286 _depthState && _depthStencilState) {
287 return;
288 }
289
290 if (!_shader) {
291 ShaderDefinition definition;
292 definition.name = "ComposePass";
293 definition.vshader = "composeVertex";
294 definition.fshader = "composeFragment";
295 _shader = createShader(_device, definition, COMPOSE_SOURCE);
296 }
297
298 if (!_vertexFormat) {
// 14 * sizeof(float): one vertex = position(3) + normal(3) + uv0(2) + tangent(4) + uv1(2)
// floats, the attribute layout declared by ComposeVertexIn in COMPOSE_SOURCE.
299 _vertexFormat = std::make_shared<VertexFormat>(static_cast<int>(14 * sizeof(float)), true, false);
300 }
301
302 if (!_vertexBuffer && _vertexFormat) {
303 // DEVIATION: Metal/WebGPU texture UV origin is top-left (V=0 at top).
304 // Upstream handles this via getImageEffectUV() Y-flip in shader.
305 // We flip UV.y here: clip Y=-1 (bottom) -> UV.y=1 (bottom of texture),
306 // clip Y=+1 (top) -> UV.y=0 (top of texture).
// One oversized triangle with vertices (-1,-1), (3,-1), (-1,3) covers the whole clip-space
// square [-1,1]^2; the excess is clipped, so no index buffer or second triangle is needed.
// Only position and uv0 are read by composeVertex; the remaining attributes are filler.
307 constexpr float vertexData[3 * 14] = {
308 // pos.xyz normal.xyz uv0.xy tangent.xyzw uv1.xy
309 -1.0f, -1.0f, 0.0f, 0, 0, 1, 0.0f, 1.0f, 1, 0, 0, 1, 0.0f, 1.0f,
310 3.0f, -1.0f, 0.0f, 0, 0, 1, 2.0f, 1.0f, 1, 0, 0, 1, 0.0f, 1.0f,
311 -1.0f, 3.0f, 0.0f, 0, 0, 1, 0.0f,-1.0f, 1, 0, 0, 1, 0.0f,-1.0f
312 };
313 VertexBufferOptions options;
314 options.usage = BUFFER_STATIC;
315 options.data.resize(sizeof(vertexData));
316 std::memcpy(options.data.data(), vertexData, sizeof(vertexData));
317 _vertexBuffer = _device->createVertexBuffer(_vertexFormat, 3, options);
318 }
319
320 if (!_blendState) {
321 _blendState = std::make_shared<BlendState>();
322 }
323 if (!_depthState) {
324 _depthState = std::make_shared<DepthState>();
325 }
// Always-pass depth test with depth writes disabled: the full-screen compose draw neither
// depends on nor modifies whatever depth attachment is bound. Requires the raw Metal device.
326 if (!_depthStencilState && _device->raw()) {
327 auto* depthDesc = MTL::DepthStencilDescriptor::alloc()->init();
328 depthDesc->setDepthCompareFunction(MTL::CompareFunctionAlways);
329 depthDesc->setDepthWriteEnabled(false);
330 _depthStencilState = _device->raw()->newDepthStencilState(depthDesc);
331 depthDesc->release();
332 }
333 }
334
335 void MetalComposePass::execute(MTL::RenderCommandEncoder* encoder, const ComposePassParams& params,
336 MetalRenderPipeline* pipeline, const std::shared_ptr<RenderTarget>& renderTarget,
337 const std::vector<std::shared_ptr<MetalBindGroupFormat>>& bindGroupFormats,
338 MTL::SamplerState* defaultSampler)
339 {
340 if (!encoder || !params.sceneTexture) {
341 return;
342 }
344 if (!_shader || !_vertexBuffer || !_vertexFormat || !_blendState || !_depthState) {
345 return;
346 }
347
348 Primitive primitive;
349 primitive.type = PRIMITIVE_TRIANGLES;
350 primitive.base = 0;
351 primitive.count = 3;
352 primitive.indexed = false;
353
354 auto pipelineState = pipeline->get(primitive, _vertexFormat, nullptr, -1, _shader, renderTarget,
355 bindGroupFormats, _blendState, _depthState, CullMode::CULLFACE_NONE, false, nullptr, nullptr);
356 if (!pipelineState) {
357 return;
358 }
359
360 auto* vb = dynamic_cast<MetalVertexBuffer*>(_vertexBuffer.get());
361 if (!vb || !vb->raw()) {
362 return;
363 }
364
365 encoder->setRenderPipelineState(pipelineState);
366 encoder->setCullMode(MTL::CullModeNone);
367 encoder->setDepthStencilState(_depthStencilState);
368 encoder->setVertexBuffer(vb->raw(), 0, 0);
369
370 auto* sceneHw = dynamic_cast<gpu::MetalTexture*>(params.sceneTexture->impl());
371 auto* bloomHw = params.bloomTexture ? dynamic_cast<gpu::MetalTexture*>(params.bloomTexture->impl()) : nullptr;
372 auto* cocHw = params.cocTexture ? dynamic_cast<gpu::MetalTexture*>(params.cocTexture->impl()) : nullptr;
373 auto* blurHw = params.blurTexture ? dynamic_cast<gpu::MetalTexture*>(params.blurTexture->impl()) : nullptr;
374 auto* ssaoHw = params.ssaoTexture ? dynamic_cast<gpu::MetalTexture*>(params.ssaoTexture->impl()) : nullptr;
375
376 auto* depthHw = params.depthTexture ? dynamic_cast<gpu::MetalTexture*>(params.depthTexture->impl()) : nullptr;
377
378 encoder->setFragmentTexture(sceneHw ? sceneHw->raw() : nullptr, 0);
379 encoder->setFragmentTexture(bloomHw ? bloomHw->raw() : nullptr, 1);
380 encoder->setFragmentTexture(cocHw ? cocHw->raw() : nullptr, 2);
381 encoder->setFragmentTexture(blurHw ? blurHw->raw() : nullptr, 3);
382 encoder->setFragmentTexture(ssaoHw ? ssaoHw->raw() : nullptr, 4);
383 encoder->setFragmentTexture(depthHw ? depthHw->raw() : nullptr, 5);
384 if (defaultSampler) {
385 encoder->setFragmentSamplerState(defaultSampler, 0);
386 }
387
388 struct alignas(16) ComposeUniforms
389 {
390 uint32_t dofEnabled = 0u;
391 uint32_t taaEnabled = 0u;
392 uint32_t ssaoEnabled = 0u;
393 uint32_t bloomEnabled = 0u;
394 uint32_t blurTextureUpscale = 0u;
395 float bloomIntensity = 0.01f;
396 float dofIntensity = 1.0f;
397 float sharpness = 0.0f;
398 uint32_t tonemapMode = 0u;
399 float exposure = 1.0f;
400 float sceneTextureInvRes[2] = {0.0f, 0.0f};
401 // Single-pass DOF parameters
402 float dofFocusDistance = 1.0f;
403 float dofFocusRange = 0.5f;
404 float dofBlurRadius = 3.0f;
405 float dofCameraNear = 0.01f;
406 float dofCameraFar = 100.0f;
407 float _pad0 = 0.0f; // padding to maintain alignment
408 // Vignette
409 uint32_t vignetteEnabled = 0u;
410 float vignetteInner = 0.5f;
411 float vignetteOuter = 1.0f;
412 float vignetteCurvature = 0.5f;
413 float vignetteIntensity = 0.3f;
414 float vignetteColorR = 0.0f;
415 float vignetteColorG = 0.0f;
416 float vignetteColorB = 0.0f;
417 } uniforms;
418 uniforms.dofEnabled = params.dofEnabled ? 1u : 0u;
419 uniforms.taaEnabled = params.taaEnabled ? 1u : 0u;
420 uniforms.ssaoEnabled = params.ssaoTexture ? 1u : 0u;
421 uniforms.bloomEnabled = params.bloomTexture ? 1u : 0u;
422 uniforms.blurTextureUpscale = params.blurTextureUpscale ? 1u : 0u;
423 uniforms.bloomIntensity = params.bloomIntensity;
424 uniforms.dofIntensity = params.dofIntensity;
425 uniforms.sharpness = params.sharpness;
426 uniforms.tonemapMode = static_cast<uint32_t>(params.toneMapping);
427 uniforms.exposure = params.exposure;
428 if (params.sceneTexture && params.sceneTexture->width() > 0 && params.sceneTexture->height() > 0) {
429 uniforms.sceneTextureInvRes[0] = 1.0f / static_cast<float>(params.sceneTexture->width());
430 uniforms.sceneTextureInvRes[1] = 1.0f / static_cast<float>(params.sceneTexture->height());
431 }
432 // Single-pass DOF
433 uniforms.dofFocusDistance = params.dofFocusDistance;
434 uniforms.dofFocusRange = params.dofFocusRange;
435 uniforms.dofBlurRadius = params.dofBlurRadius;
436 uniforms.dofCameraNear = params.dofCameraNear;
437 uniforms.dofCameraFar = params.dofCameraFar;
438
439 // Vignette
440 uniforms.vignetteEnabled = params.vignetteEnabled ? 1u : 0u;
441 uniforms.vignetteInner = params.vignetteInner;
442 uniforms.vignetteOuter = params.vignetteOuter;
443 uniforms.vignetteCurvature = params.vignetteCurvature;
444 uniforms.vignetteIntensity = params.vignetteIntensity;
445 uniforms.vignetteColorR = params.vignetteColor[0];
446 uniforms.vignetteColorG = params.vignetteColor[1];
447 uniforms.vignetteColorB = params.vignetteColor[2];
448
449 encoder->setFragmentBytes(&uniforms, sizeof(ComposeUniforms), 5);
450 encoder->drawPrimitives(MTL::PrimitiveTypeTriangle, static_cast<NS::UInteger>(0), static_cast<NS::UInteger>(3));
451 _device->recordDrawCall();
452 }
453}
MetalComposePass(MetalGraphicsDevice *device)
void execute(MTL::RenderCommandEncoder *encoder, const ComposePassParams &params, MetalRenderPipeline *pipeline, const std::shared_ptr< RenderTarget > &renderTarget, const std::vector< std::shared_ptr< MetalBindGroupFormat > > &bindGroupFormats, MTL::SamplerState *defaultSampler)
Execute the compose pass on the active render command encoder.
MTL::RenderPipelineState * get(const Primitive &primitive, const std::shared_ptr< VertexFormat > &vertexFormat0, const std::shared_ptr< VertexFormat > &vertexFormat1, int ibFormat, const std::shared_ptr< Shader > &shader, const std::shared_ptr< RenderTarget > &renderTarget, const std::vector< std::shared_ptr< MetalBindGroupFormat > > &bindGroupFormats, const std::shared_ptr< BlendState > &blendState, const std::shared_ptr< DepthState > &depthState, CullMode cullMode, bool stencilEnabled, const std::shared_ptr< StencilParameters > &stencilFront, const std::shared_ptr< StencilParameters > &stencilBack, const std::shared_ptr< VertexFormat > &instancingFormat=nullptr)
uint32_t width() const
Definition texture.h:63
uint32_t height() const
Definition texture.h:65
gpu::HardwareTexture * impl() const
Definition texture.h:101
std::shared_ptr< Shader > createShader(GraphicsDevice *graphicsDevice, const ShaderDefinition &definition, const std::string &sourceCode)
Definition shader.cpp:39
@ PRIMITIVE_TRIANGLES
Definition mesh.h:23
Describes how vertex and index data should be interpreted for a draw call.
Definition mesh.h:33
PrimitiveType type
Definition mesh.h:34