|
/* |
|
MIT License |
|
|
|
Copyright (c) 2026 Kyriakos Gavras |
|
|
|
Permission is hereby granted, free of charge, to any person obtaining a copy |
|
of this software and associated documentation files (the "Software"), to deal |
|
in the Software without restriction, including without limitation the rights |
|
to use, copy, modify, merge, publish, distribute, sublicense, and / or sell |
|
copies of the Software, and to permit persons to whom the Software is furnished |
|
to do so, subject to the following conditions : |
|
|
|
The above copyright notice and this permission notice shall be included in |
|
all copies or substantial portions of the Software. |
|
|
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR |
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS |
|
FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.IN NO EVENT SHALL THE AUTHORS OR |
|
COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER |
|
IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN |
|
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. |
|
*/ |
|
|
|
#include <metal_stdlib> |
|
using namespace metal; |
|
|
|
/// Collapses a 2x2 footprint of depth samples into the single value stored in
/// the next-coarser mip: the largest of the four inputs.
///
/// @param a,b  the two samples of the first row
/// @param c,d  the two samples of the second row
/// @return     max(a, b, c, d)
float reduce2x2(float a, float b, float c, float d) {
    const float firstRow = max(a, b);
    const float secondRow = max(c, d);
    return max(firstRow, secondRow);
}
|
|
|
/// Per-dispatch parameters for depth_pyramid_spd, bound at buffer(1).
///
/// Field order and types are part of the binding contract with the host side;
/// do not reorder. (With MSL layout rules this is uint2 + uint + uint.)
struct SPDConstants {
    /// Width/height in texels of the source depth texture, which is also the
    /// extent the kernel uses for mip 0 of the pyramid.
    uint2 srcSize;

    /// Number of mip levels to produce, counting mip 0. The kernel's tail loop
    /// generates levels 7 .. mipCount - 1.
    uint mipCount;

    /// Total number of threadgroups in the dispatch. The threadgroup whose
    /// atomic increment reaches this value runs the tail loop.
    uint numWorkgroups;
};
|
|
|
/// Extent of mip `level` for a base level of `baseSize`: each axis is halved
/// (rounding down) once per level and never drops below one texel.
static inline uint2 spdMipExtent(uint2 baseSize, uint level) {
    return max(baseSize >> level, uint2(1u));
}

/// Writes `depth` to `coord` of mip `level`, but only when that level is one of
/// the `mipCount` requested levels and `coord` lies inside the level's extent.
static inline void spdStoreTexel(texture2d<float, access::read_write> pyramid,
                                 float depth,
                                 uint2 coord,
                                 uint level,
                                 uint2 baseSize,
                                 uint mipCount) {
    if (level < mipCount && all(coord < spdMipExtent(baseSize, level))) {
        pyramid.write(float4(depth, 0.0f, 0.0f, 1.0f), coord, level);
    }
}

/// Single-dispatch depth pyramid builder.
///
/// Each threadgroup owns a 64x64 tile of the source: every thread loads a 4x4
/// footprint, copies it to mip 0, and the threadgroup then reduces the tile
/// through mips 1..6 using reduce2x2. The threadgroup that finishes last (as
/// tracked by `globalCounter`) produces the remaining mips 7..mipCount-1 from
/// the pyramid texture itself on a single thread.
///
/// The indexing (tile = tgid * 64 + lid * 4) assumes 16x16 threads per
/// threadgroup; NOTE(review): confirm against the host-side dispatch.
///
/// @param srcDepth       source depth, read at its only level
/// @param pyramid        destination mip chain; mip 0 receives a copy of srcDepth
/// @param globalCounter  count of finished threadgroups; reset to 0 by the last one
/// @param constants      source size, mip count and threadgroup count
/// @param lid            thread position inside the threadgroup
/// @param tgid           threadgroup position inside the grid
/// @param lid_flat       linear thread index inside the threadgroup
kernel void depth_pyramid_spd(texture2d<float, access::read> srcDepth [[texture(0)]],
                              texture2d<float, access::read_write> pyramid [[texture(1)]],
                              device atomic_uint& globalCounter [[buffer(0)]],
                              constant SPDConstants& constants [[buffer(1)]],
                              uint2 lid [[thread_position_in_threadgroup]],
                              uint2 tgid [[threadgroup_position_in_grid]],
                              uint lid_flat [[thread_index_in_threadgroup]]) {
    // One value per thread; holds the current intermediate level of the tile.
    threadgroup float lds[16][16];

    const uint2 srcSize = constants.srcSize;
    const uint mipCount = constants.mipCount;
    const uint2 srcMax = srcSize - 1;
    const uint2 baseCoord = tgid * 64u + lid * 4u;

    // Load the 4x4 footprint (edge-clamped) and mirror in-range texels to mip 0.
    float texel[4][4];
    for (uint y = 0; y < 4; ++y) {
        for (uint x = 0; x < 4; ++x) {
            const uint2 coord = baseCoord + uint2(x, y);
            const float depth = srcDepth.read(min(coord, srcMax)).r;
            texel[y][x] = depth;
            if (all(coord < srcSize)) {
                pyramid.write(float4(depth, 0.0f, 0.0f, 1.0f), coord, 0);
            }
        }
    }

    // mip 1: each thread reduces its 4x4 footprint to a 2x2 quad.
    const uint2 mip1Base = tgid * 32u + lid * 2u;
    float quad[2][2];
    for (uint qy = 0; qy < 2; ++qy) {
        for (uint qx = 0; qx < 2; ++qx) {
            const float value = reduce2x2(texel[2 * qy + 0][2 * qx + 0],
                                          texel[2 * qy + 0][2 * qx + 1],
                                          texel[2 * qy + 1][2 * qx + 0],
                                          texel[2 * qy + 1][2 * qx + 1]);
            quad[qy][qx] = value;
            spdStoreTexel(pyramid, value, mip1Base + uint2(qx, qy), 1, srcSize, mipCount);
        }
    }

    // mip 2: the quad is still in registers, so reduce it directly instead of
    // bouncing it through threadgroup memory.
    float reduced = reduce2x2(quad[0][0], quad[0][1], quad[1][0], quad[1][1]);
    lds[lid.y][lid.x] = reduced;
    spdStoreTexel(pyramid, reduced, tgid * 16u + lid, 2, srcSize, mipCount);

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // mips 3..6: the top-left (64 >> level)^2 threads each fold a 2x2 block of
    // the previous level held in lds. Results are written back into the same
    // array, so a barrier separates the read phase from the write phase;
    // otherwise a thread could overwrite a value another thread has not read
    // yet. All barriers sit outside the divergent branches.
    for (uint level = 3; level <= 6; ++level) {
        const uint side = 64u >> level; // 8, 4, 2, 1
        const bool active = all(lid < uint2(side));

        if (active) {
            const uint2 blk = lid * 2u;
            reduced = reduce2x2(lds[blk.y + 0][blk.x + 0],
                                lds[blk.y + 0][blk.x + 1],
                                lds[blk.y + 1][blk.x + 0],
                                lds[blk.y + 1][blk.x + 1]);
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        if (active) {
            lds[lid.y][lid.x] = reduced;
            spdStoreTexel(pyramid, reduced, tgid * side + lid, level, srcSize, mipCount);
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Every thread of this threadgroup must be done with its texture writes
    // before the threadgroup is counted as finished.
    threadgroup_barrier(mem_flags::mem_device | mem_flags::mem_threadgroup | mem_flags::mem_texture);

    if (lid_flat == 0) {
        const uint completed =
            atomic_fetch_add_explicit(&globalCounter, 1, memory_order_relaxed) + 1;

        if (completed == constants.numWorkgroups) {
            // Last threadgroup: re-arm the counter for the next dispatch.
            atomic_store_explicit(&globalCounter, 0, memory_order_relaxed);

            // NOTE(review): the atomic is relaxed and the barriers above are
            // threadgroup-scoped, so nothing in this kernel by itself
            // establishes that mip 6 texels written by *other* threadgroups are
            // visible here - confirm on the target GPUs or move the tail into a
            // second dispatch.
            for (uint mip = 7; mip < mipCount; ++mip) {
                // Make this thread's earlier writes (its mip 6 texel, then each
                // previous tail level) visible to the reads below.
                pyramid.fence();

                const uint2 dstSize = spdMipExtent(srcSize, mip);
                const uint2 prevMax = spdMipExtent(srcSize, mip - 1) - 1;

                for (uint y = 0; y < dstSize.y; ++y) {
                    for (uint x = 0; x < dstSize.x; ++x) {
                        const uint2 readBase = uint2(x, y) * 2u;
                        const float t00 = pyramid.read(min(readBase + uint2(0, 0), prevMax), mip - 1).r;
                        const float t10 = pyramid.read(min(readBase + uint2(1, 0), prevMax), mip - 1).r;
                        const float t01 = pyramid.read(min(readBase + uint2(0, 1), prevMax), mip - 1).r;
                        const float t11 = pyramid.read(min(readBase + uint2(1, 1), prevMax), mip - 1).r;
                        pyramid.write(float4(reduce2x2(t00, t10, t01, t11), 0.0f, 0.0f, 1.0f),
                                      uint2(x, y), mip);
                    }
                }
            }
        }
    }
}