Skip to content

Instantly share code, notes, and snippets.

@sheaf
Created November 18, 2020 12:36
Show Gist options
  • Select an option

  • Save sheaf/f07d68f8bc5fe54f3e7230e7e5ad35bf to your computer and use it in GitHub Desktop.

Select an option

Save sheaf/f07d68f8bc5fe54f3e7230e7e5ad35bf to your computer and use it in GitHub Desktop.
{-# LANGUAGE BlockArguments #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE NamedWildCards #-}
{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RebindableSyntax #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module FIR.Examples.Ising.Shaders where
-- base
import qualified Prelude
import Control.Monad
( sequence_ )
import Data.Proxy
( Proxy(..) )
import Data.Type.Bool
( If )
import Data.Type.Equality
( type (==) )
import Data.Maybe
( fromJust )
import GHC.TypeLits
( TypeError, ErrorMessage(..) )
import GHC.TypeNats
( KnownNat, type (*), Mod
, natVal
)
-- filepath
import System.FilePath
( (</>) )
-- vector-sized
import qualified Data.Vector.Sized as Vector
( fromList )
-- text-short
import Data.Text.Short
( ShortText )
-- fir
import FIR
import Math.Linear
-- fir-examples
import FIR.Examples.Paths
( shaderDir )
------------------------------------------------
-- Work sizes.
-- Workgroup (local) size shared by every compute shader in this module.
type LocalSizeX = 32
type LocalSizeY = 18
-- Supersampling factor: each output pixel covers an SS × SS block of lattice sites.
type SS = 2
-- Output image size in pixels. 'WithDivisor' checks divisibility at compile
-- time, presumably so that the compute dispatches tile the image exactly.
type Width = 1280 `WithDivisor` 2 `WithDivisor` LocalSizeX
type Height = 720 `WithDivisor` LocalSizeY

-- | Term-level values of the size synonyms above, available at any 'Semiring'
-- type (they are used both as plain 'Int32' values and as shader 'Code').
width, height, ss, localSizeX, localSizeY :: Semiring a => a
width      = fromInteger . Prelude.toInteger . natVal $ Proxy @Width
height     = fromInteger . Prelude.toInteger . natVal $ Proxy @Height
ss         = fromInteger . Prelude.toInteger . natVal $ Proxy @SS
localSizeX = fromInteger . Prelude.toInteger . natVal $ Proxy @LocalSizeX
localSizeY = fromInteger . Prelude.toInteger . natVal $ Proxy @LocalSizeY
-- Step algorithm sizes:
-- Underlying lattice has size ( SS * Width ) * ( SS * Height ),
-- but we only handle half at a time (checkerboard pattern).
-- globalSizeX * localSizeX = SS * Width / 2
-- globalSizeY * localSizeY = SS * Height
-- Resolve algorithm size:
-- globalSizeX * localSizeX = Width
-- globalSizeY * localSizeY = Height
-- | @num `WithDivisor` divisor@ evaluates to @num@ itself, after checking at
-- compile time that @divisor@ divides @num@; otherwise it is a custom type
-- error. Being left-associative, @n `WithDivisor` a `WithDivisor` b@ checks
-- both @a@ and @b@ against @n@.
infixl 7 `WithDivisor`
type family WithDivisor num divisor where
  WithDivisor num divisor =
    If ( Mod num divisor == 0 )
       num
       ( TypeError
           (    ShowType divisor
           :<>: Text " does not divide "
           :<>: ShowType num
           )
       )
------------------------------------------------
-- Ising model simulation.
-- | Colour of a lattice site in the checkerboard decomposition of the lattice.
-- In 'resolveShader', site @(X, Y)@ is treated as 'Even' when @X + Y@ is even.
data Parity = Even | Odd
-- | Singleton for 'Parity': lets term-level code (e.g. 'stepShader') branch on
-- the parity that is also tracked at the type level (e.g. in 'StepDefs').
data SParity ( parity :: Parity ) where
  SEven :: SParity Even
  SOdd :: SParity Odd
-- | Simulation parameters, read by 'stepShader' from the "ubo" uniform buffer.
--
-- * "temperature": divides the energy change in the Metropolis acceptance
--   probability @exp ( - delta / temperature )@.
-- * "interaction", "magneticField": coefficients in the energy change
--   computed by 'updateSpin'.
-- * "time": only used as the entropy source passed to 'randomProbability'.
--
-- NOTE(review): the field order presumably fixes the buffer's memory layout,
-- so it must match the struct uploaded by the host; confirm before reordering.
type IsingParameters =
  Struct
    '[ "temperature" ':-> Float
     , "interaction" ':-> Float
     , "magneticField" ':-> Float
     , "time" ':-> Float
     ]
-- | Interface of the step shader for the given checkerboard parity.
--
-- * "ubo": simulation parameters ('IsingParameters').
-- * "evenSpins", "oddSpins": read-only single-channel float images holding
--   the spins of the even and odd checkerboard sites. For the odd step,
--   "evenSpins" comes from descriptor set 1, binding 0, which is the slot the
--   even step writes to, so the odd step sees the freshly updated even spins.
--   This is what the "even step first" comment refers to.
-- * "outSpins": write-only image receiving the updated spins (set 1,
--   binding 0 for the even step; binding 1 for the odd step).
-- * "localSpins": workgroup-shared array with one entry per invocation.
-- * "main": compute entry point with local size LocalSizeX × LocalSizeY × 1.
--
-- NOTE(review): for @parity ~ Even@, "evenSpins" resolves to descriptor set 0,
-- binding 0, which is the same slot as "ubo". Assuming a Vulkan host, a single
-- descriptor binding has one descriptor type, so one of these presumably
-- needs a different binding. Confirm against the host-side set layouts.
type StepDefs ( parity :: Parity ) =
  '[ "ubo" ':-> Uniform '[ DescriptorSet 0, Binding 0 ] IsingParameters
   , "evenSpins" ':-> Image2D
       '[ DescriptorSet ( If ( parity == Even ) 0 1 ) -- Assumes we are computing the "even" checkerboard step first.
        , Binding 0
        , NonWritable
        ]
       ( R32 F )
   , "oddSpins" ':-> Image2D
       '[ DescriptorSet 0
        , Binding 1
        , NonWritable
        ]
       ( R32 F )
   , "outSpins" ':-> Image2D
       '[ DescriptorSet 1
        , Binding ( If ( parity == Even ) 0 1 )
        , NonReadable
        ]
       ( R32 F )
   , "localSpins" ':-> Workgroup '[] ( Array ( LocalSizeX * LocalSizeY ) Float )
   , "main" ':-> EntryPoint '[ LocalSize LocalSizeX LocalSizeY 1 ] Compute
   ]
-- | Update all spins of lattice sites with the given checkerboard parity.
--
-- Each invocation handles the checkerboard index given by its global
-- invocation ID and does the following:
--
--   1. reads the even and odd spins stored at that index;
--   2. stores the spin of the /opposite/ parity in the workgroup-shared
--      "localSpins" array, so neighbouring invocations can read it;
--   3. gathers the four neighbouring spins ('lookupNeighbourSpins');
--   4. performs a Metropolis update ('updateSpin') and writes the new spin
--      to "outSpins".
stepShader :: forall parity. _ => SParity parity -> Module ( StepDefs parity )
stepShader sParity = Module $ entryPoint @"main" @Compute do
  isingParameters <- get @"ubo"
  localIndex@( ~( Vec2 i_local_x i_local_y ) )
    <- use @( Name "gl_LocalInvocationID" :.: Swizzle "xy" )
  globalIndex
    <- use @( Name "gl_GlobalInvocationID" :.: Swizzle "xy" )
  -- Read spins at current index, and write spins for lattice sites of opposite checkerboard colouring into local storage.
  let
    -- Row-major position of this invocation within the workgroup.
    localArrayIndex :: Code Word32
    localArrayIndex = i_local_x + localSizeX * i_local_y
  evenSpin <- imageRead @"evenSpins" globalIndex
  oddSpin <- imageRead @"oddSpins" globalIndex
  let
    -- The spin being updated, and the opposite-parity spin at the same index.
    currentSpin, otherSpin :: Code Float
    ( currentSpin, otherSpin ) = case sParity of
      SEven -> ( evenSpin, oddSpin )
      SOdd -> ( oddSpin, evenSpin )
  assign @( Name "localSpins" :.: AnIndex Word32 ) localArrayIndex otherSpin
  -- NOTE(review): assuming 'memoryBarrier' emits SPIR-V OpMemoryBarrier, this
  -- only orders this invocation's own writes. It does not wait for the other
  -- invocations of the workgroup, so the "localSpins" reads in
  -- 'lookupNeighbourSpins' may see entries that neighbours have not written
  -- yet. A workgroup control barrier (GLSL @barrier()@, SPIR-V
  -- OpControlBarrier) with acquire-release workgroup-memory semantics is
  -- presumably needed here. Check FIR's synchronisation API for the exact call.
  memoryBarrier Workgroup ( MemorySemantics ( Lock [ Release ] ) [ WorkgroupMemory ] )
  -- Simulation step: update spins of given parity by reading spins of opposite parity.
  -- We can read the spins from local storage, except at the boundary of the workgroup.
  neighbourSpins <- lookupNeighbourSpins sParity otherSpin localIndex globalIndex
  newCurrentSpin <- updateSpin isingParameters currentSpin neighbourSpins
  imageWrite @"outSpins" globalIndex newCurrentSpin
-- | Metropolis–Hastings update of a single spin.
--
-- Arguments: the simulation parameters, the current spin @s@, and its four
-- neighbouring spins (up, left, right, down). Returns the new spin, which is
-- either @s@ or @negate s@.
--
-- The energy change used below, @2 * s * ( h - J * (u + l + r + d) )@ with
-- J = "interaction" and h = "magneticField", is the change caused by flipping
-- @s@ under the Hamiltonian
-- H = J * (sum over neighbouring pairs of s_i * s_j) - h * (sum of s_i).
-- Under that sign convention, a positive "interaction" favours anti-aligned
-- neighbours. NOTE(review): confirm this convention is intended; many
-- references use H = -J * (sum of s_i * s_j) - h * (sum of s_i).
--
-- NOTE(review): the random draw depends only on "time", which comes from the
-- uniform buffer and so presumably has the same value for every invocation.
-- All sites updated in one dispatch then share the same acceptance threshold.
-- Mixing a per-site value (such as the global invocation ID) into the entropy
-- would decorrelate them.
updateSpin :: Code IsingParameters -> Code Float -> Code ( V 4 Float ) -> Program _s _s ( Code Float )
updateSpin isingParameters s ( Vec4 u l r d ) = locally do
  temperature <- def @"temperature" @R $ view @( Name "temperature" ) isingParameters
  interaction <- def @"interaction" @R $ view @( Name "interaction" ) isingParameters
  magneticField <- def @"magneticField" @R $ view @( Name "magneticField" ) isingParameters
  time <- def @"time" @R $ view @( Name "time" ) isingParameters
  -- Change in Hamiltonian resulting from flipping the current spin.
  delta <- def @"delta" @R $ 2 * s * ( magneticField - interaction * ( u + l + r + d ) )
  -- Metropolis–Hastings algorithm.
  if delta <= 0
    then pure ( negate s ) -- Flip if this would not increase the total energy.
    else do
      -- Otherwise: flip if within Monte–Carlo acceptance threshold;
      -- i.e. the temperature is high enough to flip a spin even though that increases total energy.
      prob <- randomProbability time -- (time argument is only used as an entropy source)
      if prob <= exp ( - delta / temperature )
        then pure ( negate s )
        else pure s
-- | Look up the four nearest neighbours (up, left, right, down) of the lattice
-- site handled by the current invocation. All four have the opposite
-- checkerboard parity to the sites being updated.
--
-- Storage layout (as indexed by 'resolveShader'): lattice site @(X, Y)@ lives
-- at checkerboard index @(X `div` 2, Y)@ of the image for parity
-- @(X + Y) `mod` 2@. Hence the vertical neighbours sit at the same
-- checkerboard column. Which horizontal neighbour shares the current
-- checkerboard index depends on the parity of the /row/ @Y@, not the column.
--
-- Neighbours inside the current workgroup are read from "localSpins", which
-- the caller is expected to have filled (and synchronised) beforehand.
-- Neighbours outside it are read directly from the image of opposite parity.
-- Indices wrap around at the lattice edges (periodic boundary conditions).
--
-- Arguments: the parity being updated; the opposite-parity spin at the
-- current checkerboard index; the local invocation index; the global
-- invocation index. Returns @Vec4 up left right down@.
lookupNeighbourSpins
  :: SParity parity
  -> Code Float
  -> Code ( V 2 Word32 ) -> Code ( V 2 Word32 )
  -> Program _s _s ( Code ( V 4 Float ) )
lookupNeighbourSpins sParity otherSpin ( Vec2 i_local_x i_local_y ) ( Vec2 i_global_x i_global_y ) = do
    -- Lattice site above: decrement y-index by 1 (wrapping to the last row).
    spinU <-
      if i_local_y > 0
        then use @( Name "localSpins" :.: AnIndex Word32 ) ( i_local_x + localSizeX * ( i_local_y - 1 ) )
        else do
          let
            globalIndex :: Code ( V 2 Word32 )
            globalIndex =
              if i_global_y > 0
                then Vec2 i_global_x ( i_global_y - 1 )
                else Vec2 i_global_x lastRow
          useGlobalSpin globalIndex
    -- Side lattice sites:
    --   - if parity of row equals parity of current lattice site
    --       * decrement x-index by 1 for left spin
    --       * right spin is at current index (which we already know)
    --   - otherwise
    --       * left spin is at current index (which we already know)
    --       * increment x-index by 1 for right spin
    ~( Vec2 spinL spinR ) <-
      if rowParityMatchesSite
        then do
          spinL <-
            if i_local_x > 0
              then use @( Name "localSpins" :.: AnIndex Word32 ) ( i_local_x - 1 + localSizeX * i_local_y )
              else do
                let
                  globalIndex :: Code ( V 2 Word32 )
                  globalIndex =
                    if i_global_x > 0
                      then Vec2 ( i_global_x - 1 ) i_global_y
                      else Vec2 lastColumn i_global_y
                useGlobalSpin globalIndex
          pure $ Vec2 spinL otherSpin
        else do
          spinR <-
            if i_local_x < localSizeX - 1
              then use @( Name "localSpins" :.: AnIndex Word32 ) ( i_local_x + 1 + localSizeX * i_local_y )
              else do
                let
                  globalIndex :: Code ( V 2 Word32 )
                  globalIndex =
                    if i_global_x < lastColumn
                      then Vec2 ( i_global_x + 1 ) i_global_y
                      else Vec2 0 i_global_y
                useGlobalSpin globalIndex
          pure $ Vec2 otherSpin spinR
    -- Lattice site below: increment y-index by 1 (wrapping to the first row).
    spinD <-
      if i_local_y < localSizeY - 1
        then use @( Name "localSpins" :.: AnIndex Word32 ) ( i_local_x + localSizeX * ( i_local_y + 1 ) )
        else do
          let
            globalIndex :: Code ( V 2 Word32 )
            globalIndex =
              if i_global_y < lastRow
                then Vec2 i_global_x ( i_global_y + 1 )
                else Vec2 i_global_x 0
          useGlobalSpin globalIndex
    pure ( Vec4 spinU spinL spinR spinD )
  where
    -- Last valid column / row of the half-width checkerboard images.
    lastColumn, lastRow :: Code Word32
    lastColumn = Lit . fromIntegral $ ( ( width * ss ) `div` 2 - 1 :: Int32 )
    lastRow = Lit . fromIntegral $ ( height * ss - 1 :: Int32 )
    -- Read an opposite-parity spin directly from its image.
    useGlobalSpin :: Code ( V 2 Word32 ) -> Program _s _s ( Code Float )
    useGlobalSpin = case sParity of
      SEven -> imageRead @"oddSpins"
      SOdd -> imageRead @"evenSpins"
    -- Does the parity of the current ROW (y-index) equal the parity of the
    -- sites being updated? This decides which horizontal neighbour shares the
    -- current checkerboard index. It must test the row: testing the column
    -- index selects the wrong neighbour on every other column.
    rowParityMatchesSite :: Code Bool
    rowParityMatchesSite = case sParity of
      SEven -> i_global_y `mod` 2 == 0
      SOdd -> i_global_y `mod` 2 == 1
-- | Pseudo-random value in [0, 1) derived from @entropy@, using the common
-- shader hash @fract ( sin x * 43758.5453 )@.
--
-- The large multiplier is applied /after/ 'sin'. Scaling the argument of 'sin'
-- instead leaves the result distributed like 'sin' itself, i.e. clustered near
-- ±1 and hence near 0 and 1 after taking the fractional part. That would bias
-- the Metropolis acceptance test in 'updateSpin'.
--
-- NOTE(review): the result is a deterministic function of @entropy@ alone, so
-- invocations passing the same value all receive the same number.
randomProbability :: Code Float -> Program _s _s ( Code Float )
randomProbability entropy = locally do
  s <- def @"s" @R $ sin entropy * 43758.5453123
  -- Fractional part of s, in [0, 1).
  pure ( s - floor s )
------------------------------------------------
-- Supersampling.
-- | Interface of the resolve (supersampling) shader. It reads both
-- checkerboard spin images (descriptor set 1, bindings 0 and 1) and writes an
-- RGBA8 colour image (set 1, binding 2).
--
-- NOTE(review): bindings 0 and 1 of set 1 are the "outSpins" slots of the
-- even and odd step shaders, presumably so that the host can reuse the step
-- outputs here. Confirm against the host-side descriptor setup.
type ResolveDefs =
  '[ "evenSpins" ':-> Image2D '[ DescriptorSet 1, Binding 0, NonWritable ] ( R32 F )
   , "oddSpins" ':-> Image2D '[ DescriptorSet 1, Binding 1, NonWritable ] ( R32 F )
   , "outputImage"':-> Image2D '[ DescriptorSet 1, Binding 2, NonReadable ] ( RGBA8 UNorm )
   , "main" ':-> EntryPoint '[ LocalSize LocalSizeX LocalSizeY 1 ] Compute
   ]
-- | Resolve each output pixel from the SS × SS block of lattice sites that it
-- covers. The shader sums their spins, divides the total by @ss * ss@ (giving
-- the average spin, which lies in [-1, 1] if spins are ±1), and maps the
-- result through 'gradient' using 'sunsetColours'.
resolveShader :: Module ResolveDefs
resolveShader = Module $ entryPoint @"main" @Compute do
  globalIndex@( ~( Vec2 i_global_x i_global_y ) )
    <- use @( Name "gl_GlobalInvocationID" :.: Swizzle "xy" )
  -- Compute total spin over all lattice sites that correspond to the current pixel (supersampling).
  _ <- def @"totalSpin" @RW @Float 0
  supersamplingLoop \ ss_x ss_y -> locally do
    -- Lattice site (X, Y) = (ss_x + ss * i_global_x, ss_y + ss * i_global_y)
    -- is stored at index (X `div` 2, Y) of the image matching its parity.
    checkerboardIndex <-
      def @"checkerboardIndex" @R @( V 2 Word32 ) $
        Vec2
          ( ( ss_x + ss * i_global_x ) `div` 2 )
          ( ss_y + ss * i_global_y )
    spin <-
      if ( ss_x + ss_y + ss * ( i_global_x + i_global_y ) ) `mod` 2 == 0 -- is the lattice site even?
        then imageRead @"evenSpins" checkerboardIndex
        else imageRead @"oddSpins" checkerboardIndex
    modify @"totalSpin" ( + spin )
    pure ()
  totalSpin <- get @"totalSpin"
  -- Colour mapping.
  let
    colour :: Code ( V 4 Float )
    colour = gradient ( totalSpin / ( ss * ss ) ) ( Lit sunsetColours )
  -- Write the result to output image.
  imageWrite @"outputImage" globalIndex colour
-- | Run @prog ssX ssY@ for every ssX and ssY in [0, ss), with ssX in the outer
-- loop and ssY in the inner loop. The two loop counters are shader variables
-- ("ssX", "ssY") defined inside a 'locally' scope.
supersamplingLoop
  :: ( Code Word32 -> Code Word32 -> Program _s _s () )
  -> Program _s _s ()
supersamplingLoop prog = locally do
  _ <- def @"ssX" @RW @Word32 0
  _ <- def @"ssY" @RW @Word32 0
  while ( ( < ss ) <<$>> get @"ssX") do
    ssX <- get @"ssX"
    -- Restart the inner counter for each outer iteration.
    put @"ssY" 0
    while ( ( < ss ) <<$>> get @"ssY" ) do
      ssY <- get @"ssY"
      embed ( prog ssX ssY )
      put @"ssY" ( ssY + 1 )
    put @"ssX" ( ssX + 1 )
  pure ()
------------------------------------------------
-- Colour mapping.
-- | Gradient for input values between -1 and 1.
--
-- The @n@ colour stops in @colors@ are spread evenly over [-1, 1], and the
-- result linearly interpolates between the two stops surrounding @t@.
--
-- The segment index is clamped to @n - 2@ (the last segment). Without the
-- clamp, @t = 1@ gives segment @n - 1@, and reading its upper stop @i + 1 = n@
-- goes past the end of the array. This case occurs whenever all supersampled
-- spins of a pixel agree.
--
-- Assumes @n >= 2@ and @t >= -1@. Smaller @t@ would make the float-to-unsigned
-- conversion in 'floor' negative, which is not handled here.
gradient :: forall n. KnownNat n
         => Code Float                  -- ^ value to map, expected in [-1, 1]
         -> Code (Array n (V 4 Float))  -- ^ colour stops
         -> Code (V 4 Float)
gradient t colors
  = ( (1-s) *^ ( view @(AnIndex _) i colors ) )
  ^+^ ( s *^ ( view @(AnIndex _) (i+1) colors ) )
  where
    -- t rescaled from [-1, 1] to [0, 1].
    t' :: Code Float
    t' = 0.5 * (t+1)
    n :: Code Float
    n = Lit . fromIntegral $ knownValue @n
    -- Index of the last segment (which runs between stops n-2 and n-1).
    lastSegment :: Code Word32
    lastSegment = floor ( n - 2 )
    -- Unclamped segment index; equals n - 1 when t' = 1.
    i0 :: Code Word32
    i0 = floor ( (n-1) * t' )
    i :: Code Word32
    i = if i0 > lastSegment then lastSegment else i0
    -- Interpolation weight within segment i; reaches 1 at t' = 1 after clamping.
    s :: Code Float
    s = (n-1) * t' - fromIntegral i
-- | Colour stops for 'gradient': nine RGBA values running from fully
-- transparent black, through purple, red, orange and yellow, to opaque white.
sunsetColours :: Array 9 (V 4 Float)
sunsetColours = MkArray ( fromJust ( Vector.fromList colourStops ) )
  where
    -- 'Vector.fromList' succeeds because this list has exactly 9 entries,
    -- matching the array length in the type.
    colourStops :: [ V 4 Float ]
    colourStops =
      [ V4 0    0    0    0
      , V4 0.28 0.1  0.38 1
      , V4 0.58 0.2  0.38 1
      , V4 0.83 0.3  0.22 1
      , V4 0.98 0.45 0.05 1
      , V4 0.99 0.62 0.2  1
      , V4 1    0.78 0.31 1
      , V4 1    0.91 0.6  1
      , V4 1    1    1    1
      ]
------------------------------------------------
-- compiling
-- | Output locations of the compiled SPIR-V modules, inside 'shaderDir'.
evenStepPath, oddStepPath, resolvePath :: FilePath
evenStepPath = shaderDir </> "ising_even_comp.spv"
oddStepPath  = shaderDir </> "ising_odd_comp.spv"
resolvePath  = shaderDir </> "ising_resolve_comp.spv"

-- | Compile one shader module to its output path with 'compileTo'.
-- The 'Left' case presumably carries a compilation error message.
compileEvenStepShader, compileOddStepShader, compileResolveShader
  :: IO ( Either ShortText ModuleRequirements )
compileEvenStepShader = compileTo evenStepPath [] $ stepShader SEven
compileOddStepShader  = compileTo oddStepPath  [] $ stepShader SOdd
compileResolveShader  = compileTo resolvePath  [] resolveShader

-- | Compile the even-step, odd-step and resolve shaders, in that order.
-- The individual results, including any errors, are discarded.
compileAllShaders :: IO ()
compileAllShaders = sequence_ compileActions
  where
    compileActions =
      [ compileEvenStepShader
      , compileOddStepShader
      , compileResolveShader
      ]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment