22#include "spdlog/spdlog.h"
// Metal shader source used when the blur pass runs in its horizontal
// configuration, kept as a raw string literal.
// Entry points: `blurVertex` (vertex stage) and `blurFragment` (fragment stage).
// Fragment-stage bindings: source texture at texture(0), depth texture at
// texture(1), sampler at sampler(0), BlurUniforms constant buffer at buffer(5).
// The fragment function reads only the red channel of the source texture and
// writes the filtered value to the red channel of its output (green and blue
// are 0, alpha is 1).
31 constexpr const char* BLUR_SOURCE_HORIZONTAL = R
"(
32#include <metal_stdlib>
37struct ComposeVertexIn {
38 float3 position [[attribute(0)]];
39 float3 normal [[attribute(1)]];
40 float2 uv0 [[attribute(2)]];
41 float4 tangent [[attribute(3)]];
42 float2 uv1 [[attribute(4)]];
46 float4 position [[position]];
51 float2 sourceInvResolution;
57vertex BlurVarying blurVertex(ComposeVertexIn in [[stage_in]])
60 out.position = float4(in.position, 1.0);
65static inline float getLinearDepth(float rawDepth, float cameraNear, float cameraFar)
67 return (cameraNear * cameraFar) / (cameraFar - rawDepth * (cameraFar - cameraNear));
70static inline float bilateralWeight(float depth, float sampleDepth)
72 float diff = (sampleDepth - depth);
73 return max(0.0, 1.0 - diff * diff);
76static inline void tap(thread float& sum, thread float& totalWeight, float weight, float depth,
77 float2 position, texture2d<float> sourceTexture, depth2d<float> depthTexture,
78 sampler linearSampler, float cameraNear, float cameraFar)
80 float color = sourceTexture.sample(linearSampler, position).r;
81 float textureDepth = getLinearDepth(depthTexture.sample(linearSampler, position), cameraNear, cameraFar);
83 float bilateral = bilateralWeight(depth, textureDepth);
85 sum += color * bilateral;
86 totalWeight += bilateral;
89fragment float4 blurFragment(
90 BlurVarying in [[stage_in]],
91 texture2d<float> sourceTexture [[texture(0)]],
92 depth2d<float> depthTexture [[texture(1)]],
93 sampler linearSampler [[sampler(0)]],
94 constant BlurUniforms& uniforms [[buffer(5)]])
96 const float2 uv = clamp(in.uv, float2(0.0), float2(1.0));
98 // handle the center pixel separately because it doesn't participate in bilateral filtering
99 float depth = getLinearDepth(depthTexture.sample(linearSampler, uv), uniforms.cameraNear, uniforms.cameraFar);
100 float totalWeight = 1.0;
101 float color = sourceTexture.sample(linearSampler, uv).r;
102 float sum = color * totalWeight;
104 // Gaussian sigma: filterSize / 3 gives ~99.7% of the bell within the kernel
105 float sigma = max(float(uniforms.filterSize) / 3.0, 1.0);
106 float invSigma2 = 1.0 / (2.0 * sigma * sigma);
108 for (int i = -uniforms.filterSize; i <= uniforms.filterSize; i++) {
109 float weight = exp(-float(i * i) * invSigma2);
112 float2 offset = float2(float(i), 0.0) * uniforms.sourceInvResolution;
114 float2 offset = float2(0.0, float(i)) * uniforms.sourceInvResolution;
117 tap(sum, totalWeight, weight, depth, uv + offset, sourceTexture, depthTexture, linearSampler,
118 uniforms.cameraNear, uniforms.cameraFar);
121 float ao = sum / totalWeight;
122 return float4(ao, 0.0, 0.0, 1.0);
// Metal shader source used when the blur pass runs in its vertical
// configuration, kept as a raw string literal.
// Entry points: `blurVertex` (vertex stage) and `blurFragment` (fragment stage).
// Fragment-stage bindings: source texture at texture(0), depth texture at
// texture(1), sampler at sampler(0), BlurUniforms constant buffer at buffer(5).
// Taps are stepped along the Y axis in units of `sourceInvResolution`; each tap
// contributes according to how close its linearized depth is to the center
// sample's, and the result is written to the red channel of the output.
126 constexpr const char* BLUR_SOURCE_VERTICAL = R
"(
127#include <metal_stdlib>
128using namespace metal;
130struct ComposeVertexIn {
131 float3 position [[attribute(0)]];
132 float3 normal [[attribute(1)]];
133 float2 uv0 [[attribute(2)]];
134 float4 tangent [[attribute(3)]];
135 float2 uv1 [[attribute(4)]];
139 float4 position [[position]];
144 float2 sourceInvResolution;
150vertex BlurVarying blurVertex(ComposeVertexIn in [[stage_in]])
153 out.position = float4(in.position, 1.0);
158static inline float getLinearDepth(float rawDepth, float cameraNear, float cameraFar)
160 return (cameraNear * cameraFar) / (cameraFar - rawDepth * (cameraFar - cameraNear));
163static inline float bilateralWeight(float depth, float sampleDepth)
165 float diff = (sampleDepth - depth);
166 return max(0.0, 1.0 - diff * diff);
169static inline void tap(thread float& sum, thread float& totalWeight, float weight, float depth,
170 float2 position, texture2d<float> sourceTexture, depth2d<float> depthTexture,
171 sampler linearSampler, float cameraNear, float cameraFar)
173 float color = sourceTexture.sample(linearSampler, position).r;
174 float textureDepth = getLinearDepth(depthTexture.sample(linearSampler, position), cameraNear, cameraFar);
176 float bilateral = bilateralWeight(depth, textureDepth);
178 sum += color * bilateral;
179 totalWeight += bilateral;
182fragment float4 blurFragment(
183 BlurVarying in [[stage_in]],
184 texture2d<float> sourceTexture [[texture(0)]],
185 depth2d<float> depthTexture [[texture(1)]],
186 sampler linearSampler [[sampler(0)]],
187 constant BlurUniforms& uniforms [[buffer(5)]])
189 const float2 uv = clamp(in.uv, float2(0.0), float2(1.0));
191 float depth = getLinearDepth(depthTexture.sample(linearSampler, uv), uniforms.cameraNear, uniforms.cameraFar);
192 float totalWeight = 1.0;
193 float color = sourceTexture.sample(linearSampler, uv).r;
194 float sum = color * totalWeight;
196 // Gaussian sigma: filterSize / 3 gives ~99.7% of the bell within the kernel
197 float sigma = max(float(uniforms.filterSize) / 3.0, 1.0);
198 float invSigma2 = 1.0 / (2.0 * sigma * sigma);
200 for (int i = -uniforms.filterSize; i <= uniforms.filterSize; i++) {
201 float weight = exp(-float(i * i) * invSigma2);
203 // Vertical: offset along Y axis
204 float2 offset = float2(0.0, float(i)) * uniforms.sourceInvResolution;
206 tap(sum, totalWeight, weight, depth, uv + offset, sourceTexture, depthTexture, linearSampler,
207 uniforms.cameraNear, uniforms.cameraFar);
210 float ao = sum / totalWeight;
211 return float4(ao, 0.0, 0.0, 1.0);
217 : _device(device), _composePass(composePass), _horizontal(horizontal)
// Release the Metal depth-stencil state if one was created, then null the
// pointer so it cannot be released a second time.
223 if (_depthStencilState) {
224 _depthStencilState->release();
225 _depthStencilState =
nullptr;
// Prepares the objects the blur pass draws with: the blur shader, a blend
// state, a depth state and a Metal depth-stencil state. The depth-stencil
// state is only built when it does not exist yet and the device exposes a raw
// Metal device.
229 void MetalDepthAwareBlurPass::ensureResources()
// NOTE(review): this condition tests that every resource already exists;
// presumably it guards an early return - confirm against the full function.
235 _blendState && _depthState && _depthStencilState) {
// Shader description: the name and the source string depend on the blur
// direction, while the entry point names are the same for both variants.
240 ShaderDefinition definition;
241 definition.name = _horizontal ?
"DepthAwareBlurHorizontalPass" :
"DepthAwareBlurVerticalPass";
242 definition.vshader =
"blurVertex";
243 definition.fshader =
"blurFragment";
244 const char* source = _horizontal ? BLUR_SOURCE_HORIZONTAL : BLUR_SOURCE_VERTICAL;
249 _blendState = std::make_shared<BlendState>();
252 _depthState = std::make_shared<DepthState>();
// Depth-stencil state for the pass: the depth comparison always passes and
// depth writes are off. The descriptor is released as soon as the state
// object has been created from it.
254 if (!_depthStencilState && _device->raw()) {
255 auto* depthDesc = MTL::DepthStencilDescriptor::alloc()->init();
256 depthDesc->setDepthCompareFunction(MTL::CompareFunctionAlways);
257 depthDesc->setDepthWriteEnabled(
false);
258 _depthStencilState = _device->raw()->newDepthStencilState(depthDesc);
259 depthDesc->release();
266 const std::vector<std::shared_ptr<MetalBindGroupFormat>>& bindGroupFormats,
267 MTL::SamplerState* defaultSampler, MTL::DepthStencilState* defaultDepthStencilState)
// Warn when the shader, the compose pass's vertex buffer or vertex format, or
// the blend/depth state is missing.
274 if (!_shader || !_composePass->vertexBuffer() || !_composePass->vertexFormat() || !_blendState || !_depthState) {
275 spdlog::warn(
"[executeDepthAwareBlurPass] missing blur resources");
// Look up the render pipeline state for the compose pass's vertex format
// combined with the blur shader and the current render target.
285 auto pipelineState = pipeline->
get(primitive, _composePass->vertexFormat(),
nullptr, -1, _shader, renderTarget,
287 if (!pipelineState) {
288 spdlog::warn(
"[executeDepthAwareBlurPass] failed to get pipeline state");
// The vertex buffer wrapper and its underlying Metal buffer must both exist.
293 if (!vb || !vb->raw()) {
294 spdlog::warn(
"[executeDepthAwareBlurPass] missing vertex buffer");
// Encoder state: culling disabled; this pass's own depth-stencil state is
// preferred, with the caller-supplied default used when it is absent. The
// vertex buffer is bound at index 0 with offset 0.
298 encoder->setRenderPipelineState(pipelineState);
299 encoder->setCullMode(MTL::CullModeNone);
300 encoder->setDepthStencilState(_depthStencilState ? _depthStencilState : defaultDepthStencilState);
301 encoder->setVertexBuffer(vb->raw(), 0, 0);
// Fragment textures: source at index 0 and depth at index 1, matching the
// texture(0)/texture(1) bindings of blurFragment. A null texture is bound
// when the corresponding hardware texture is unavailable.
306 encoder->setFragmentTexture(sourceHw ? sourceHw->raw() :
nullptr, 0);
307 encoder->setFragmentTexture(depthHw ? depthHw->raw() :
nullptr, 1);
// The sampler is bound at index 0 only when the caller supplied one.
308 if (defaultSampler) {
309 encoder->setFragmentSamplerState(defaultSampler, 0);
// CPU-side uniform block, 16-byte aligned; it is copied to fragment buffer
// index 5, which is where blurFragment declares its BlurUniforms argument.
312 struct alignas(16) BlurUniforms
314 float sourceInvResolution[2];
325 encoder->setFragmentBytes(&uniforms,
sizeof(BlurUniforms), 5);
// Draw three vertices starting at vertex 0 as triangle primitives, then
// record the draw call with the device.
327 encoder->drawPrimitives(MTL::PrimitiveTypeTriangle,
static_cast<NS::UInteger
>(0),
328 static_cast<NS::UInteger
>(3));
329 _device->recordDrawCall();
gpu::HardwareTexture * impl() const
std::shared_ptr< Shader > createShader(GraphicsDevice *graphicsDevice, const ShaderDefinition &definition, const std::string &sourceCode)
float sourceInvResolutionY
float sourceInvResolutionX
Describes how vertex and index data should be interpreted for a draw call.