VisuTwin Canvas
C++ 3D Engine — Metal Backend
Loading...
Searching...
No Matches
metalDofBlurPass.cpp
Go to the documentation of this file.
1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2025-2026 Arnis Lektauers
3//
4// DOF (Depth of Field) blur pass implementation.
5// Samples the scene in concentric rings whose radius and sample weights are driven by the Circle of Confusion (CoC).
6//
7#include "metalDofBlurPass.h"
8
9#include "metalComposePass.h"
10#include "metalGraphicsDevice.h"
11#include "metalRenderPipeline.h"
12#include "metalTexture.h"
13#include "metalVertexBuffer.h"
22#include "spdlog/spdlog.h"
23
24namespace visutwin::canvas
25{
26 namespace
27 {
28 // DOF Blur shader — disc blur / bokeh approximation.
29 // Samples far texture in concentric rings, weighted by CoC value.
// Binding contract (MetalDofBlurPass::execute must bind exactly this):
//   stage_in    : ComposeVertexIn, the compose pass's shared full-screen vertex layout
//   texture(0)  : far texture
//   texture(1)  : CoC texture (.r scales/weights the far blur, .g the near blur)
//   texture(2)  : near texture, sampled whenever the centre pixel's CoC .g > 0
//   sampler(0)  : sampler used for every read
//   buffer(5)   : DofBlurUniforms, 24 bytes (2 floats, a float2 at offset 8, 2 ints)
// Ring k (1..blurRings) places k * blurRingPoints samples at k / blurRings of the
// CoC-scaled blur radius; offsets are multiplied by invResolution to convert to UV
// (presumably the radius is in pixels — confirm against the caller's params).
30 constexpr const char* DOF_BLUR_SOURCE = R"(
31#include <metal_stdlib>
32using namespace metal;
33
34struct ComposeVertexIn {
35 float3 position [[attribute(0)]];
36 float3 normal [[attribute(1)]];
37 float2 uv0 [[attribute(2)]];
38 float4 tangent [[attribute(3)]];
39 float2 uv1 [[attribute(4)]];
40};
41
42struct DofBlurVarying {
43 float4 position [[position]];
44 float2 uv;
45};
46
47struct DofBlurUniforms {
48 float blurRadiusNear; // offset 0
49 float blurRadiusFar; // offset 4
50 float2 invResolution; // offset 8 (Metal float2: 8-byte aligned)
51 int blurRings; // offset 16
52 int blurRingPoints; // offset 20
53};
54
55vertex DofBlurVarying dofBlurVertex(ComposeVertexIn in [[stage_in]])
56{
57 DofBlurVarying out;
58 out.position = float4(in.position, 1.0);
59 out.uv = in.uv0;
60 return out;
61}
62
63fragment float4 dofBlurFragment(
64 DofBlurVarying in [[stage_in]],
65 texture2d<float> farTexture [[texture(0)]],
66 texture2d<float> cocTexture [[texture(1)]],
67 texture2d<float> nearTexture [[texture(2)]],
68 sampler linearSampler [[sampler(0)]],
69 constant DofBlurUniforms& uniforms [[buffer(5)]])
70{
71 float2 uv = clamp(in.uv, float2(0.0), float2(1.0));
72 float2 coc = cocTexture.sample(linearSampler, uv).rg;
73
74 float3 farColor = float3(0.0);
75 float farWeight = 0.0;
76
77 // Concentric disc sampling for far blur
78 float blurRadius = coc.r * uniforms.blurRadiusFar;
79 int rings = max(uniforms.blurRings, 1);
80 int ringPoints = max(uniforms.blurRingPoints, 1);
81
82 for (int ring = 1; ring <= rings; ring++) {
83 float ringRadius = float(ring) / float(rings);
84 int pointsInRing = ring * ringPoints;
85 for (int p = 0; p < pointsInRing; p++) {
86 float angle = float(p) * 6.283185 / float(pointsInRing);
87 float2 offset = float2(cos(angle), sin(angle)) * ringRadius * blurRadius;
88 float2 sampleUV = uv + offset * uniforms.invResolution;
89 sampleUV = clamp(sampleUV, float2(0.0), float2(1.0));
90 float sampleCoc = cocTexture.sample(linearSampler, sampleUV).r;
91 float w = sampleCoc; // weight by CoC at sample location
92 farColor += farTexture.sample(linearSampler, sampleUV).rgb * w;
93 farWeight += w;
94 }
95 }
96
97 if (farWeight > 0.0) {
98 farColor /= farWeight;
99 } else {
100 farColor = farTexture.sample(linearSampler, uv).rgb;
101 }
102
103 // Near blur (optional, blended when CoC near channel > 0)
104 float3 result = farColor;
105 if (coc.g > 0.0) {
106 float3 nearColor = float3(0.0);
107 float nearWeight = 0.0;
108 float nearBlurRadius = coc.g * uniforms.blurRadiusNear;
109
110 for (int ring = 1; ring <= rings; ring++) {
111 float ringRadius = float(ring) / float(rings);
112 int pointsInRing = ring * ringPoints;
113 for (int p = 0; p < pointsInRing; p++) {
114 float angle = float(p) * 6.283185 / float(pointsInRing);
115 float2 offset = float2(cos(angle), sin(angle)) * ringRadius * nearBlurRadius;
116 float2 sampleUV = uv + offset * uniforms.invResolution;
117 sampleUV = clamp(sampleUV, float2(0.0), float2(1.0));
118 float sampleCocNear = cocTexture.sample(linearSampler, sampleUV).g;
119 float w = sampleCocNear;
120 nearColor += nearTexture.sample(linearSampler, sampleUV).rgb * w;
121 nearWeight += w;
122 }
123 }
124
125 if (nearWeight > 0.0) {
126 nearColor /= nearWeight;
127 result = mix(result, nearColor, coc.g);
128 }
129 }
130
131 return float4(result, 1.0);
132}
133)";
134 }
135
137 : _device(device), _composePass(composePass)
138 {
139 }
140
142 {
143 if (_depthStencilState) {
144 _depthStencilState->release();
145 _depthStencilState = nullptr;
146 }
147 }
148
149 void MetalDofBlurPass::ensureResources()
150 {
151 // Ensure the compose pass's shared vertex buffer/format are created first
152 _composePass->ensureResources();
153
154 if (_shader && _composePass->vertexBuffer() && _composePass->vertexFormat() &&
155 _blendState && _depthState && _depthStencilState) {
156 return;
157 }
158
159 if (!_shader) {
160 ShaderDefinition definition;
161 definition.name = "DofBlurPass";
162 definition.vshader = "dofBlurVertex";
163 definition.fshader = "dofBlurFragment";
164 _shader = createShader(_device, definition, DOF_BLUR_SOURCE);
165 }
166
167 if (!_blendState) {
168 _blendState = std::make_shared<BlendState>();
169 }
170 if (!_depthState) {
171 _depthState = std::make_shared<DepthState>();
172 }
173 if (!_depthStencilState && _device->raw()) {
174 auto* depthDesc = MTL::DepthStencilDescriptor::alloc()->init();
175 depthDesc->setDepthCompareFunction(MTL::CompareFunctionAlways);
176 depthDesc->setDepthWriteEnabled(false);
177 _depthStencilState = _device->raw()->newDepthStencilState(depthDesc);
178 depthDesc->release();
179 }
180 }
181
182 void MetalDofBlurPass::execute(MTL::RenderCommandEncoder* encoder,
183 const DofBlurPassParams& params,
184 MetalRenderPipeline* pipeline, const std::shared_ptr<RenderTarget>& renderTarget,
185 const std::vector<std::shared_ptr<MetalBindGroupFormat>>& bindGroupFormats,
186 MTL::SamplerState* defaultSampler, MTL::DepthStencilState* defaultDepthStencilState)
187 {
188 if (!encoder || !params.farTexture || !params.cocTexture) {
189 return;
190 }
191
192 ensureResources();
193 if (!_shader || !_composePass->vertexBuffer() || !_composePass->vertexFormat() || !_blendState || !_depthState) {
194 spdlog::warn("[executeDofBlurPass] missing DOF blur resources");
195 return;
196 }
197
198 Primitive primitive;
199 primitive.type = PRIMITIVE_TRIANGLES;
200 primitive.base = 0;
201 primitive.count = 3;
202 primitive.indexed = false;
203
204 auto pipelineState = pipeline->get(primitive, _composePass->vertexFormat(), nullptr, -1, _shader, renderTarget,
205 bindGroupFormats, _blendState, _depthState, CullMode::CULLFACE_NONE, false, nullptr, nullptr);
206 if (!pipelineState) {
207 spdlog::warn("[executeDofBlurPass] failed to get pipeline state");
208 return;
209 }
210
211 auto* vb = dynamic_cast<MetalVertexBuffer*>(_composePass->vertexBuffer().get());
212 if (!vb || !vb->raw()) {
213 spdlog::warn("[executeDofBlurPass] missing vertex buffer");
214 return;
215 }
216
217 encoder->setRenderPipelineState(pipelineState);
218 encoder->setCullMode(MTL::CullModeNone);
219 encoder->setDepthStencilState(_depthStencilState ? _depthStencilState : defaultDepthStencilState);
220 encoder->setVertexBuffer(vb->raw(), 0, 0);
221
222 // Bind textures: slot 0 = far, slot 1 = CoC, slot 2 = near
223 auto* farHw = dynamic_cast<gpu::MetalTexture*>(params.farTexture->impl());
224 encoder->setFragmentTexture(farHw ? farHw->raw() : nullptr, 0);
225
226 auto* cocHw = dynamic_cast<gpu::MetalTexture*>(params.cocTexture->impl());
227 encoder->setFragmentTexture(cocHw ? cocHw->raw() : nullptr, 1);
228
229 if (params.nearTexture) {
230 auto* nearHw = dynamic_cast<gpu::MetalTexture*>(params.nearTexture->impl());
231 encoder->setFragmentTexture(nearHw ? nearHw->raw() : nullptr, 2);
232 }
233
234 if (defaultSampler) {
235 encoder->setFragmentSamplerState(defaultSampler, 0);
236 }
237
238 // DofBlurUniforms — float2 invResolution has 8-byte alignment, so pad after blurRadiusFar.
239 struct alignas(16) DofBlurUniforms
240 {
241 float blurRadiusNear; // offset 0
242 float blurRadiusFar; // offset 4
243 float invResolution[2]; // offset 8 (matches Metal float2)
244 int32_t blurRings; // offset 16
245 int32_t blurRingPoints; // offset 20
246 } uniforms{};
247
248 uniforms.blurRadiusNear = params.blurRadiusNear;
249 uniforms.blurRadiusFar = params.blurRadiusFar;
250 uniforms.invResolution[0] = params.invResolutionX;
251 uniforms.invResolution[1] = params.invResolutionY;
252 uniforms.blurRings = params.blurRings;
253 uniforms.blurRingPoints = params.blurRingPoints;
254 encoder->setFragmentBytes(&uniforms, sizeof(DofBlurUniforms), 5);
255
256 encoder->drawPrimitives(MTL::PrimitiveTypeTriangle, static_cast<NS::UInteger>(0),
257 static_cast<NS::UInteger>(3));
258 _device->recordDrawCall();
259 }
260}
std::shared_ptr< VertexFormat > vertexFormat() const
Shared vertex format (full-screen triangle, 14 floats per vertex).
std::shared_ptr< VertexBuffer > vertexBuffer() const
Shared vertex buffer (3-vertex full-screen triangle).
MetalDofBlurPass(MetalGraphicsDevice *device, MetalComposePass *composePass)
void execute(MTL::RenderCommandEncoder *encoder, const DofBlurPassParams &params, MetalRenderPipeline *pipeline, const std::shared_ptr< RenderTarget > &renderTarget, const std::vector< std::shared_ptr< MetalBindGroupFormat > > &bindGroupFormats, MTL::SamplerState *defaultSampler, MTL::DepthStencilState *defaultDepthStencilState)
Execute the DOF blur pass on the active render command encoder.
MTL::RenderPipelineState * get(const Primitive &primitive, const std::shared_ptr< VertexFormat > &vertexFormat0, const std::shared_ptr< VertexFormat > &vertexFormat1, int ibFormat, const std::shared_ptr< Shader > &shader, const std::shared_ptr< RenderTarget > &renderTarget, const std::vector< std::shared_ptr< MetalBindGroupFormat > > &bindGroupFormats, const std::shared_ptr< BlendState > &blendState, const std::shared_ptr< DepthState > &depthState, CullMode cullMode, bool stencilEnabled, const std::shared_ptr< StencilParameters > &stencilFront, const std::shared_ptr< StencilParameters > &stencilBack, const std::shared_ptr< VertexFormat > &instancingFormat=nullptr)
gpu::HardwareTexture * impl() const
Definition texture.h:101
std::shared_ptr< Shader > createShader(GraphicsDevice *graphicsDevice, const ShaderDefinition &definition, const std::string &sourceCode)
Definition shader.cpp:39
@ PRIMITIVE_TRIANGLES
Definition mesh.h:23
Describes how vertex and index data should be interpreted for a draw call.
Definition mesh.h:33
PrimitiveType type
Definition mesh.h:34