VisuTwin Canvas
C++ 3D Engine — Metal Backend
Loading...
Searching...
No Matches
metalTaaPass.cpp
Go to the documentation of this file.
1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2025-2026 Arnis Lektauers
3//
4// TAA (Temporal Anti-Aliasing) resolve pass implementation.
5// Extracted from MetalGraphicsDevice.
6//
7#include "metalTaaPass.h"
8
9#include "metalComposePass.h"
10#include "metalGraphicsDevice.h"
11#include "metalRenderPipeline.h"
12#include "metalTexture.h"
13#include "metalUtils.h"
14#include "metalVertexBuffer.h"
15#include "core/math/matrix4.h"
23#include "spdlog/spdlog.h"
24
25namespace visutwin::canvas
26{
27 namespace
28 {
// Metal Shading Language source for the TAA resolve pass. It is handed to
// createShader() in MetalTaaPass::ensureResources() with the entry points
// "taaVertex" and "taaFragment".
//
// Binding contract used by taaFragment:
//   texture(0) = current frame colour, texture(1) = history colour,
//   texture(2) = depth, sampler(0) = sampler used for every fetch,
//   buffer(5)  = TaaUniforms.
// MetalTaaPass::execute() uploads a C++ struct with the same member order to
// fragment buffer index 5, so any change to TaaUniforms here has to be mirrored
// there.
//
// ComposeVertexIn declares five vertex attributes, but taaVertex only reads
// position and uv0. linearizeDepth()/delinearizeDepth() are defined but are not
// called by taaFragment (see the DEVIATION note inside the fragment function).
//
// Everything between R"( and )" is runtime data: editing it, including the
// comments inside, changes the string passed to the shader compiler.
29 constexpr const char* TAA_SOURCE = R"(
30#include <metal_stdlib>
31using namespace metal;
32
33struct ComposeVertexIn {
34 float3 position [[attribute(0)]];
35 float3 normal [[attribute(1)]];
36 float2 uv0 [[attribute(2)]];
37 float4 tangent [[attribute(3)]];
38 float2 uv1 [[attribute(4)]];
39};
40
41struct TaaVarying {
42 float4 position [[position]];
43 float2 uv;
44};
45
46struct TaaUniforms {
47 float4x4 viewProjectionPrevious;
48 float4x4 viewProjectionInverse;
49 float4 jitters;
50 float2 textureSize;
51 float4 cameraParams;
52 uint highQuality;
53 uint historyValid;
54};
55
56vertex TaaVarying taaVertex(ComposeVertexIn in [[stage_in]])
57{
58 TaaVarying out;
59 out.position = float4(in.position, 1.0);
60 out.uv = in.uv0;
61 return out;
62}
63
64static inline float linearizeDepth(float z, float4 cameraParams)
65{
66 if (cameraParams.w == 0.0) {
67 return (cameraParams.z * cameraParams.y) / (cameraParams.y + z * (cameraParams.z - cameraParams.y));
68 }
69 return cameraParams.z + z * (cameraParams.y - cameraParams.z);
70}
71
72static inline float delinearizeDepth(float linearDepth, float4 cameraParams)
73{
74 if (cameraParams.w == 0.0) {
75 return (cameraParams.y * (cameraParams.z - linearDepth)) /
76 (linearDepth * (cameraParams.z - cameraParams.y));
77 }
78 return (linearDepth - cameraParams.z) / (cameraParams.y - cameraParams.z);
79}
80
81static inline float2 reproject(float2 uv, float depth, constant TaaUniforms& uniforms)
82{
83 // DEVIATION: Metal depth buffer stores (ndcZ_gl + 1) / 2, undo to get OpenGL NDC Z
84 depth = depth * 2.0 - 1.0;
85
86 // DEVIATION: UV has Metal convention (V=0 at top), but the projection matrix uses
87 // OpenGL convention (NDC Y=+1 at top). Convert: ndcX = uv.x*2-1, ndcY = (1-uv.y)*2-1 = 1-2*uv.y.
88 float4 ndc = float4(uv.x * 2.0 - 1.0, 1.0 - 2.0 * uv.y, depth, 1.0);
89
90 // Remove jitter from the current frame
91 ndc.xy -= uniforms.jitters.xy;
92
93 float4 worldPosition = uniforms.viewProjectionInverse * ndc;
94 worldPosition /= worldPosition.w;
95
96 float4 screenPrevious = uniforms.viewProjectionPrevious * worldPosition;
97 // Convert back from NDC to Metal UV convention (flip Y back)
98 float2 prevNdc = screenPrevious.xy / screenPrevious.w;
99 return float2(prevNdc.x * 0.5 + 0.5, 0.5 - prevNdc.y * 0.5);
100}
101
102static inline float4 SampleTextureCatmullRom(
103 texture2d<float> tex, sampler linearSampler, float2 uv, float2 texSize)
104{
105 float2 samplePos = uv * texSize;
106 float2 texPos1 = floor(samplePos - 0.5) + 0.5;
107 float2 f = samplePos - texPos1;
108
109 float2 w0 = f * (-0.5 + f * (1.0 - 0.5 * f));
110 float2 w1 = 1.0 + f * f * (-2.5 + 1.5 * f);
111 float2 w2 = f * (0.5 + f * (2.0 - 1.5 * f));
112 float2 w3 = f * f * (-0.5 + 0.5 * f);
113
114 float2 w12 = w1 + w2;
115 float2 offset12 = w2 / (w1 + w2);
116
117 float2 texPos0 = (texPos1 - 1.0) / texSize;
118 float2 texPos3 = (texPos1 + 2.0) / texSize;
119 float2 texPos12 = (texPos1 + offset12) / texSize;
120
121 float4 result = float4(0.0);
122 result += tex.sample(linearSampler, float2(texPos0.x, texPos0.y), level(0.0)) * w0.x * w0.y;
123 result += tex.sample(linearSampler, float2(texPos12.x, texPos0.y), level(0.0)) * w12.x * w0.y;
124 result += tex.sample(linearSampler, float2(texPos3.x, texPos0.y), level(0.0)) * w3.x * w0.y;
125
126 result += tex.sample(linearSampler, float2(texPos0.x, texPos12.y), level(0.0)) * w0.x * w12.y;
127 result += tex.sample(linearSampler, float2(texPos12.x, texPos12.y), level(0.0)) * w12.x * w12.y;
128 result += tex.sample(linearSampler, float2(texPos3.x, texPos12.y), level(0.0)) * w3.x * w12.y;
129
130 result += tex.sample(linearSampler, float2(texPos0.x, texPos3.y), level(0.0)) * w0.x * w3.y;
131 result += tex.sample(linearSampler, float2(texPos12.x, texPos3.y), level(0.0)) * w12.x * w3.y;
132 result += tex.sample(linearSampler, float2(texPos3.x, texPos3.y), level(0.0)) * w3.x * w3.y;
133 return result;
134}
135
136static inline float4 colorClamp(texture2d<float> sourceTexture, sampler linearSampler, float2 uv, float4 historyColor, float2 textureSize)
137{
138 float3 minColor = float3(9999.0);
139 float3 maxColor = float3(-9999.0);
140 for (float x = -1.0; x <= 1.0; ++x) {
141 for (float y = -1.0; y <= 1.0; ++y) {
142 float3 color = sourceTexture.sample(linearSampler, uv + float2(x, y) / textureSize).rgb;
143 minColor = min(minColor, color);
144 maxColor = max(maxColor, color);
145 }
146 }
147
148 float3 clamped = clamp(historyColor.rgb, minColor, maxColor);
149 return float4(clamped, historyColor.a);
150}
151
152fragment float4 taaFragment(
153 TaaVarying in [[stage_in]],
154 texture2d<float> sourceTexture [[texture(0)]],
155 texture2d<float> historyTexture [[texture(1)]],
156 depth2d<float> depthTexture [[texture(2)]],
157 sampler linearSampler [[sampler(0)]],
158 constant TaaUniforms& uniforms [[buffer(5)]])
159{
160 // TAA resolve (GLSL to Metal).
161 const float2 uv = clamp(in.uv, float2(0.0), float2(1.0));
162
163 // Current frame color
164 const float4 srcColor = sourceTexture.sample(linearSampler, uv);
165
166 // If no valid history, just pass through current frame
167 if (uniforms.historyValid == 0u) {
168 return srcColor;
169 }
170
171 // DEVIATION: upstream uses getLinearScreenDepth()/delinearizeDepth() from
172 // screenDepthPS for the round-trip; the linearize->delinearize is an identity
173 // on the raw hardware depth. We skip the round-trip and use rawDepth directly
174 // since reproject() only needs the original viewport [0,1] depth.
175 float depth = depthTexture.sample(linearSampler, uv);
176
177 // Reproject: find where this pixel was in the previous frame
178 float2 historyUv = reproject(uv, depth, uniforms);
179
180 // Sample history: Catmull-Rom (high quality) or bilinear
181 float4 historyColor;
182 if (uniforms.highQuality != 0u) {
183 historyColor = SampleTextureCatmullRom(historyTexture, linearSampler, historyUv, uniforms.textureSize);
184 } else {
185 historyColor = historyTexture.sample(linearSampler, historyUv);
186 }
187
188 // Color clamping to handle disocclusion
189 float4 historyColorClamped = colorClamp(sourceTexture, linearSampler, uv, historyColor, uniforms.textureSize);
190
191 // Reject history samples that project outside the frame
192 float mixFactor = (historyUv.x < 0.0 || historyUv.x > 1.0 ||
193 historyUv.y < 0.0 || historyUv.y > 1.0) ? 1.0 : 0.05;
194
195 return mix(historyColorClamped, srcColor, mixFactor);
196}
197)";
198 }
199
201 : _device(device), _composePass(composePass)
202 {
203 }
204
206 {
207 if (_depthStencilState) {
208 _depthStencilState->release();
209 _depthStencilState = nullptr;
210 }
211 }
212
213 void MetalTaaPass::ensureResources()
214 {
215 if (_shader && _composePass->vertexBuffer() && _composePass->vertexFormat() &&
216 _blendState && _depthState && _depthStencilState) {
217 return;
218 }
219
220 if (!_shader) {
221 ShaderDefinition definition;
222 definition.name = "TaaResolvePass";
223 definition.vshader = "taaVertex";
224 definition.fshader = "taaFragment";
225 _shader = createShader(_device, definition, TAA_SOURCE);
226 }
227
228 if (!_blendState) {
229 _blendState = std::make_shared<BlendState>();
230 }
231 if (!_depthState) {
232 _depthState = std::make_shared<DepthState>();
233 }
234 if (!_depthStencilState && _device->raw()) {
235 auto* depthDesc = MTL::DepthStencilDescriptor::alloc()->init();
236 depthDesc->setDepthCompareFunction(MTL::CompareFunctionAlways);
237 depthDesc->setDepthWriteEnabled(false);
238 _depthStencilState = _device->raw()->newDepthStencilState(depthDesc);
239 depthDesc->release();
240 }
241 }
242
243 void MetalTaaPass::execute(MTL::RenderCommandEncoder* encoder,
244 Texture* sourceTexture, Texture* historyTexture, Texture* depthTexture,
245 const Matrix4& viewProjectionPrevious, const Matrix4& viewProjectionInverse,
246 const std::array<float, 4>& jitters, const std::array<float, 4>& cameraParams,
247 const bool highQuality, const bool historyValid,
248 MetalRenderPipeline* pipeline, const std::shared_ptr<RenderTarget>& renderTarget,
249 const std::vector<std::shared_ptr<MetalBindGroupFormat>>& bindGroupFormats,
250 MTL::SamplerState* defaultSampler, MTL::DepthStencilState* defaultDepthStencilState)
251 {
252 if (!encoder || !sourceTexture || !historyTexture || !depthTexture) {
253 return;
254 }
255
256 ensureResources();
257 if (!_shader || !_composePass->vertexBuffer() || !_composePass->vertexFormat() || !_blendState || !_depthState) {
258 spdlog::warn("[executeTAAPass] missing TAA resources");
259 return;
260 }
261
262 Primitive primitive;
263 primitive.type = PRIMITIVE_TRIANGLES;
264 primitive.base = 0;
265 primitive.count = 3;
266 primitive.indexed = false;
267
268 auto pipelineState = pipeline->get(primitive, _composePass->vertexFormat(), nullptr, -1, _shader, renderTarget,
269 bindGroupFormats, _blendState, _depthState, CullMode::CULLFACE_NONE, false, nullptr, nullptr);
270 if (!pipelineState) {
271 spdlog::warn("[executeTAAPass] failed to get pipeline state");
272 return;
273 }
274
275 auto* vb = dynamic_cast<MetalVertexBuffer*>(_composePass->vertexBuffer().get());
276 if (!vb || !vb->raw()) {
277 spdlog::warn("[executeTAAPass] missing vertex buffer");
278 return;
279 }
280
281 encoder->setRenderPipelineState(pipelineState);
282 encoder->setCullMode(MTL::CullModeNone);
283 encoder->setDepthStencilState(_depthStencilState ? _depthStencilState : defaultDepthStencilState);
284 encoder->setVertexBuffer(vb->raw(), 0, 0);
285
286 auto* sourceHw = dynamic_cast<gpu::MetalTexture*>(sourceTexture->impl());
287 auto* historyHw = dynamic_cast<gpu::MetalTexture*>(historyTexture->impl());
288 auto* depthHw = dynamic_cast<gpu::MetalTexture*>(depthTexture->impl());
289
290 encoder->setFragmentTexture(sourceHw ? sourceHw->raw() : nullptr, 0);
291 encoder->setFragmentTexture(historyHw ? historyHw->raw() : nullptr, 1);
292 encoder->setFragmentTexture(depthHw ? depthHw->raw() : nullptr, 2);
293 if (defaultSampler) {
294 encoder->setFragmentSamplerState(defaultSampler, 0);
295 }
296
297 struct alignas(16) TaaUniforms
298 {
299 simd::float4x4 viewProjectionPrevious;
300 simd::float4x4 viewProjectionInverse;
301 simd::float4 jitters;
302 simd::float2 textureSize;
303 simd::float4 cameraParams;
304 uint32_t highQuality;
305 uint32_t historyValid;
306 } uniforms{};
307 uniforms.viewProjectionPrevious = metal::toSimdMatrix(viewProjectionPrevious);
308 uniforms.viewProjectionInverse = metal::toSimdMatrix(viewProjectionInverse);
309
310 uniforms.jitters = simd::float4{jitters[0], jitters[1], jitters[2], jitters[3]};
311 uniforms.textureSize = simd::float2{
312 static_cast<float>(std::max(sourceTexture->width(), 1u)),
313 static_cast<float>(std::max(sourceTexture->height(), 1u))
314 };
315 uniforms.cameraParams = simd::float4{cameraParams[0], cameraParams[1], cameraParams[2], cameraParams[3]};
316 uniforms.highQuality = highQuality ? 1u : 0u;
317 uniforms.historyValid = historyValid ? 1u : 0u;
318 encoder->setFragmentBytes(&uniforms, sizeof(TaaUniforms), 5);
319
320 encoder->drawPrimitives(MTL::PrimitiveTypeTriangle, static_cast<NS::UInteger>(0),
321 static_cast<NS::UInteger>(3));
322 _device->recordDrawCall();
323 }
324}
std::shared_ptr< VertexFormat > vertexFormat() const
Shared vertex format (full-screen triangle, 14 floats per vertex).
std::shared_ptr< VertexBuffer > vertexBuffer() const
Shared vertex buffer (3-vertex full-screen triangle).
MTL::RenderPipelineState * get(const Primitive &primitive, const std::shared_ptr< VertexFormat > &vertexFormat0, const std::shared_ptr< VertexFormat > &vertexFormat1, int ibFormat, const std::shared_ptr< Shader > &shader, const std::shared_ptr< RenderTarget > &renderTarget, const std::vector< std::shared_ptr< MetalBindGroupFormat > > &bindGroupFormats, const std::shared_ptr< BlendState > &blendState, const std::shared_ptr< DepthState > &depthState, CullMode cullMode, bool stencilEnabled, const std::shared_ptr< StencilParameters > &stencilFront, const std::shared_ptr< StencilParameters > &stencilBack, const std::shared_ptr< VertexFormat > &instancingFormat=nullptr)
MetalTaaPass(MetalGraphicsDevice *device, MetalComposePass *composePass)
void execute(MTL::RenderCommandEncoder *encoder, Texture *sourceTexture, Texture *historyTexture, Texture *depthTexture, const Matrix4 &viewProjectionPrevious, const Matrix4 &viewProjectionInverse, const std::array< float, 4 > &jitters, const std::array< float, 4 > &cameraParams, bool highQuality, bool historyValid, MetalRenderPipeline *pipeline, const std::shared_ptr< RenderTarget > &renderTarget, const std::vector< std::shared_ptr< MetalBindGroupFormat > > &bindGroupFormats, MTL::SamplerState *defaultSampler, MTL::DepthStencilState *defaultDepthStencilState)
Execute the TAA resolve pass on the active render command encoder.
GPU texture resource supporting 2D, cubemap, volume, and array formats with mipmap management.
Definition texture.h:57
uint32_t width() const
Definition texture.h:63
uint32_t height() const
Definition texture.h:65
gpu::HardwareTexture * impl() const
Definition texture.h:101
simd::float4x4 toSimdMatrix(const Matrix4 &matrix)
Convert a column-major Matrix4 to a SIMD float4x4.
Definition metalUtils.h:20
std::shared_ptr< Shader > createShader(GraphicsDevice *graphicsDevice, const ShaderDefinition &definition, const std::string &sourceCode)
Definition shader.cpp:39
@ PRIMITIVE_TRIANGLES
Definition mesh.h:23
4x4 column-major transformation matrix with SIMD acceleration.
Definition matrix4.h:31
Describes how vertex and index data should be interpreted for a draw call.
Definition mesh.h:33
PrimitiveType type
Definition mesh.h:34