Skip to content

Instantly share code, notes, and snippets.

@Qirias
Created February 12, 2026 21:07
Show Gist options
  • Select an option

  • Save Qirias/c66904fc276a2044d44c843468842c1d to your computer and use it in GitHub Desktop.

Select an option

Save Qirias/c66904fc276a2044d44c843468842c1d to your computer and use it in GitHub Desktop.
Depth Pyramid - Single Pass Downsampler in Metal
// Encodes the single-pass depth-pyramid downsample (Metal kernel `depth_pyramid_spd`)
// into `commandBuffer`.
//
// Bindings (must match the kernel signature):
//   texture(0) = depthTexture            - source depth, read-only in the kernel
//   texture(1) = _depthPyramidTexture    - destination pyramid, read_write, all mips
//   buffer(0)  = _spdAtomicBuffer        - cross-threadgroup completion counter (uint32)
//   buffer(1)  = SPDConstants            - passed inline with setBytes
//
// Each 16x16 threadgroup reduces one 64x64 tile of the source (4x4 texels per thread)
// down to mip 6; the last threadgroup to finish builds mips 7 and above.
void depth_pyramid(MTL::CommandBuffer *commandBuffer) {
    // Tile geometry shared with the shader: a thread loads a 4x4 block, so a
    // 16x16 threadgroup covers a 64x64 tile of mip 0.
    constexpr uint32_t kThreadgroupDim = 16;
    constexpr uint32_t kTexelsPerThread = 4;
    constexpr uint32_t kTileSize = kThreadgroupDim * kTexelsPerThread; // 64

    // Host mirror of the shader's `struct SPDConstants { uint2; uint; uint; }`.
    struct SPDConstants {
        uint32_t srcSize[2];
        uint32_t mipCount;
        uint32_t numWorkgroups;
    };
    static_assert(sizeof(SPDConstants) == 16,
                  "SPDConstants must match the 16-byte layout of the MSL struct");

    // Reset the completion counter on the GPU timeline instead of with a CPU memcpy:
    // writing the shared buffer from the CPU can race a previous frame's dispatch that
    // is still incrementing it, which would make that dispatch skip mips 7+.
    // NOTE(review): ordering between this blit and the compute pass relies on Metal's
    // default (tracked) hazard tracking; if _spdAtomicBuffer is created untracked,
    // add an MTL::Fence between the two encoders.
    MTL::BlitCommandEncoder* blit = commandBuffer->blitCommandEncoder();
    blit->setLabel(NS::String::string("Depth Pyramid SPD Counter Reset", NS::ASCIIStringEncoding));
    blit->fillBuffer(_spdAtomicBuffer, NS::Range::Make(0, sizeof(uint32_t)), 0);
    blit->endEncoding();

    const uint32_t workgroupsX = (_depthWidth + kTileSize - 1) / kTileSize;
    const uint32_t workgroupsY = (_depthHeight + kTileSize - 1) / kTileSize;

    SPDConstants constants{};
    constants.srcSize[0] = _depthWidth;
    constants.srcSize[1] = _depthHeight;
    constants.mipCount = _mipCount;
    // The kernel compares its atomic completion count against this value to find
    // the last threadgroup.
    constants.numWorkgroups = workgroupsX * workgroupsY;

    MTL::ComputeCommandEncoder* encoder = commandBuffer->computeCommandEncoder();
    encoder->setLabel(NS::String::string("Depth Pyramid SPD", NS::ASCIIStringEncoding));
    encoder->setComputePipelineState(_depthPyramidPSO);
    encoder->setTexture(depthTexture, 0);
    encoder->setTexture(_depthPyramidTexture, 1);
    encoder->setBuffer(_spdAtomicBuffer, 0, 0);
    encoder->setBytes(&constants, sizeof(constants), 1);
    encoder->dispatchThreadgroups(MTL::Size::Make(workgroupsX, workgroupsY, 1),
                                  MTL::Size::Make(kThreadgroupDim, kThreadgroupDim, 1));
    encoder->endEncoding();
}
/*
MIT License
Copyright (c) 2026 Kyriakos Gavras
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is furnished
to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
#include <metal_stdlib>
using namespace metal;
// Collapses four depth samples (one 2x2 quad) into a single value by keeping the
// largest. Call sites pass the quad's first row as (a, b) and its second row as (c, d).
// NOTE(review): max keeps the farthest depth only for a conventional depth range
// (near = 0, far = 1); a reversed-Z renderer would need min here - confirm.
float reduce2x2(float a, float b, float c, float d) {
    const float firstPairMax = max(a, b);
    const float secondPairMax = max(c, d);
    return max(firstPairMax, secondPairMax);
}
// Per-dispatch constants, bound at [[buffer(1)]] of depth_pyramid_spd.
// The host fills these via setBytes from its own SPDConstants
// (uint32_t srcSize[2]; uint32_t mipCount; uint32_t numWorkgroups;), so field
// order and types here must stay byte-compatible with that struct
// (16 bytes: srcSize at offset 0, mipCount at 8, numWorkgroups at 12).
struct SPDConstants {
// Width and height of the source depth texture (pyramid mip 0), in texels.
uint2 srcSize;
// Exclusive upper bound on the mip levels the kernel generates.
// NOTE(review): presumably the pyramid texture's mipmap count - confirm on the host.
uint mipCount;
// Total threadgroups dispatched; the threadgroup whose atomic increment reaches
// this value builds the remaining (tail) mips.
uint numWorkgroups;
};
// Extent of mip `level` of a texture whose mip 0 is `baseSize`: floor(size >> level),
// clamped to at least 1 texel per axis (the texture mip-chain rule, also used by the
// tail loop below).
inline uint2 spd_mip_size(uint2 baseSize, uint level) {
    return max(baseSize >> level, uint2(1, 1));
}

// Writes one reduced depth value into `pyramid` at `coord` of mip `level`. Skips
// levels the pyramid does not have (level >= mipCount) and texels outside that mip's
// extent, so partial edge tiles and small sources never write out of bounds.
inline void spd_store(texture2d<float, access::read_write> pyramid, float value,
                      uint2 coord, uint level, uint2 baseSize, uint mipCount) {
    if (level < mipCount && all(coord < spd_mip_size(baseSize, level))) {
        pyramid.write(float4(value, 0, 0, 1), coord, level);
    }
}

// Single-pass depth pyramid downsample.
//
// Dispatch: one 16x16 threadgroup per 64x64 tile of the source depth. Each thread
// loads a 4x4 block and copies it into pyramid mip 0. The threadgroup then reduces its
// tile through mips 1..6 in threadgroup memory with reduce2x2. The threadgroup whose
// increment of `globalCounter` reaches `numWorkgroups` builds mips 7.. on a single
// thread and resets the counter to 0.
//
// NOTE(review): mip extents are floor(size >> level) clamped to 1, so for odd
// dimensions the last source row/column is not folded into the next level; the
// pyramid is not strictly conservative along those edges.
// NOTE(review): the tail reads mip-6 texels written by *other* threadgroups after a
// relaxed atomic. threadgroup_barrier only orders memory within one threadgroup, so
// visibility of those texture writes across threadgroups is not guaranteed by this
// code. Confirm on the target GPUs, or move the tail into a second dispatch.
kernel void depth_pyramid_spd(texture2d<float, access::read> srcDepth [[texture(0)]],
                              texture2d<float, access::read_write> pyramid [[texture(1)]],
                              device atomic_uint& globalCounter [[buffer(0)]],
                              constant SPDConstants& constants [[buffer(1)]],
                              uint2 lid [[thread_position_in_threadgroup]],
                              uint2 tgid [[threadgroup_position_in_grid]],
                              uint lid_flat [[thread_index_in_threadgroup]]) {
    // Two scratch arrays used alternately (A -> B -> A -> ...). Each level reads one
    // array and writes the other, and a barrier separates levels. Reading 2x2 quads
    // from an array while other threads overwrite it in the same level is a data race,
    // because the threads of a threadgroup do not run in lockstep.
    threadgroup float ldsA[32][32];
    threadgroup float ldsB[16][16];

    const uint2 srcSize = constants.srcSize;
    const uint mipCount = constants.mipCount;

    // Mip 0: load this thread's 4x4 block and copy it verbatim. Reads are clamped to
    // the last texel, so partial tiles replicate the border.
    const uint2 baseCoord = tgid * 64 + lid * 4;
    float s[4][4];
    for (uint y = 0; y < 4; y++) {
        for (uint x = 0; x < 4; x++) {
            const uint2 coord = baseCoord + uint2(x, y);
            s[y][x] = srcDepth.read(min(coord, srcSize - 1)).r;
            spd_store(pyramid, s[y][x], coord, 0, srcSize, mipCount);
        }
    }

    // Mip 1 (64x64 -> 32x32 per tile): reduce the 4x4 block to 2x2 in registers.
    const uint2 ldsBase = lid * 2;
    for (uint y = 0; y < 2; y++) {
        for (uint x = 0; x < 2; x++) {
            const float r = reduce2x2(s[2 * y][2 * x],     s[2 * y][2 * x + 1],
                                      s[2 * y + 1][2 * x], s[2 * y + 1][2 * x + 1]);
            ldsA[ldsBase.y + y][ldsBase.x + x] = r;
            spd_store(pyramid, r, tgid * 32 + ldsBase + uint2(x, y), 1, srcSize, mipCount);
        }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Mip 2 (32x32 -> 16x16): every thread reduces one quad of A into B.
    {
        const float r = reduce2x2(ldsA[ldsBase.y][ldsBase.x],     ldsA[ldsBase.y][ldsBase.x + 1],
                                  ldsA[ldsBase.y + 1][ldsBase.x], ldsA[ldsBase.y + 1][ldsBase.x + 1]);
        ldsB[lid.y][lid.x] = r;
        spd_store(pyramid, r, tgid * 16 + lid, 2, srcSize, mipCount);
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Mip 3 (16x16 -> 8x8): B -> A.
    if (all(lid < 8)) {
        const float r = reduce2x2(ldsB[ldsBase.y][ldsBase.x],     ldsB[ldsBase.y][ldsBase.x + 1],
                                  ldsB[ldsBase.y + 1][ldsBase.x], ldsB[ldsBase.y + 1][ldsBase.x + 1]);
        ldsA[lid.y][lid.x] = r;
        spd_store(pyramid, r, tgid * 8 + lid, 3, srcSize, mipCount);
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Mip 4 (8x8 -> 4x4): A -> B.
    if (all(lid < 4)) {
        const float r = reduce2x2(ldsA[ldsBase.y][ldsBase.x],     ldsA[ldsBase.y][ldsBase.x + 1],
                                  ldsA[ldsBase.y + 1][ldsBase.x], ldsA[ldsBase.y + 1][ldsBase.x + 1]);
        ldsB[lid.y][lid.x] = r;
        spd_store(pyramid, r, tgid * 4 + lid, 4, srcSize, mipCount);
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Mip 5 (4x4 -> 2x2): B -> A.
    if (all(lid < 2)) {
        const float r = reduce2x2(ldsB[ldsBase.y][ldsBase.x],     ldsB[ldsBase.y][ldsBase.x + 1],
                                  ldsB[ldsBase.y + 1][ldsBase.x], ldsB[ldsBase.y + 1][ldsBase.x + 1]);
        ldsA[lid.y][lid.x] = r;
        spd_store(pyramid, r, tgid * 2 + lid, 5, srcSize, mipCount);
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Mip 6 (2x2 -> 1x1): one texel per tile.
    if (lid_flat == 0) {
        const float r = reduce2x2(ldsA[0][0], ldsA[0][1], ldsA[1][0], ldsA[1][1]);
        spd_store(pyramid, r, tgid, 6, srcSize, mipCount);
    }

    // Order this threadgroup's texture writes (mem_texture) before signalling
    // completion through the device atomic.
    threadgroup_barrier(mem_flags::mem_device | mem_flags::mem_threadgroup | mem_flags::mem_texture);

    // Last threadgroup to arrive computes the remaining mips.
    if (lid_flat == 0) {
        const uint completed = atomic_fetch_add_explicit(&globalCounter, 1, memory_order_relaxed) + 1;
        if (completed == constants.numWorkgroups) {
            // Leave the counter at 0 for the next dispatch.
            atomic_store_explicit(&globalCounter, 0, memory_order_relaxed);

            // On a read_write texture, a thread is only guaranteed to read back its own
            // earlier writes after fence(). This covers this thread's mip-6 texel and
            // each level the loop writes and then reads as the next level's source.
            pyramid.fence();
            for (uint mip = 7; mip < mipCount; mip++) {
                const uint2 dstSize = spd_mip_size(srcSize, mip);
                const uint2 srcMipSize = spd_mip_size(srcSize, mip - 1);
                for (uint y = 0; y < dstSize.y; y++) {
                    for (uint x = 0; x < dstSize.x; x++) {
                        const uint2 readBase = uint2(x, y) * 2;
                        const float r0 = pyramid.read(min(readBase + uint2(0, 0), srcMipSize - 1), mip - 1).r;
                        const float r1 = pyramid.read(min(readBase + uint2(1, 0), srcMipSize - 1), mip - 1).r;
                        const float r2 = pyramid.read(min(readBase + uint2(0, 1), srcMipSize - 1), mip - 1).r;
                        const float r3 = pyramid.read(min(readBase + uint2(1, 1), srcMipSize - 1), mip - 1).r;
                        pyramid.write(float4(reduce2x2(r0, r1, r2, r3), 0, 0, 1), uint2(x, y), mip);
                    }
                }
                pyramid.fence();
            }
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment