Created
November 18, 2020 12:36
-
-
Save sheaf/f07d68f8bc5fe54f3e7230e7e5ad35bf to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| {-# LANGUAGE BlockArguments #-} | |
| {-# LANGUAGE DataKinds #-} | |
| {-# LANGUAGE FlexibleContexts #-} | |
| {-# LANGUAGE GADTs #-} | |
| {-# LANGUAGE NamedWildCards #-} | |
| {-# LANGUAGE PartialTypeSignatures #-} | |
| {-# LANGUAGE PolyKinds #-} | |
| {-# LANGUAGE RebindableSyntax #-} | |
| {-# LANGUAGE ScopedTypeVariables #-} | |
| {-# LANGUAGE TypeApplications #-} | |
| {-# LANGUAGE TypeFamilies #-} | |
| {-# LANGUAGE TypeOperators #-} | |
| {-# LANGUAGE UndecidableInstances #-} | |
| module FIR.Examples.Ising.Shaders where | |
| -- base | |
| import qualified Prelude | |
| import Control.Monad | |
| ( sequence_ ) | |
| import Data.Proxy | |
| ( Proxy(..) ) | |
| import Data.Type.Bool | |
| ( If ) | |
| import Data.Type.Equality | |
| ( type (==) ) | |
| import Data.Maybe | |
| ( fromJust ) | |
| import GHC.TypeLits | |
| ( TypeError, ErrorMessage(..) ) | |
| import GHC.TypeNats | |
| ( KnownNat, type (*), Mod | |
| , natVal | |
| ) | |
| -- filepath | |
| import System.FilePath | |
| ( (</>) ) | |
| -- vector-sized | |
| import qualified Data.Vector.Sized as Vector | |
| ( fromList ) | |
| -- text-short | |
| import Data.Text.Short | |
| ( ShortText ) | |
| -- fir | |
| import FIR | |
| import Math.Linear | |
| -- fir-examples | |
| import FIR.Examples.Paths | |
| ( shaderDir ) | |
------------------------------------------------
-- Work sizes.

-- | Width of the output image, in pixels.
--
-- Checked at compile time to be divisible by 2 and by 'LocalSizeX'.
type Width = 1280 `WithDivisor` 2 `WithDivisor` LocalSizeX

-- | Height of the output image, in pixels; checked at compile time to be
-- divisible by 'LocalSizeY'.
type Height = 720 `WithDivisor` LocalSizeY

-- | Supersampling factor: each pixel covers an @SS * SS@ block of lattice sites.
type SS = 2

-- | Workgroup width (number of invocations along x).
type LocalSizeX = 32

-- | Workgroup height (number of invocations along y).
type LocalSizeY = 18

-- | Reflect a type-level natural number as a value of any 'Semiring'.
sizeFromNat :: ( KnownNat n, Semiring a ) => Proxy n -> a
sizeFromNat proxy = fromInteger ( Prelude.toInteger ( natVal proxy ) )

-- | Value-level counterparts of the type-level sizes above.
width, height, ss, localSizeX, localSizeY :: Semiring a => a
width      = sizeFromNat ( Proxy @Width      )
height     = sizeFromNat ( Proxy @Height     )
ss         = sizeFromNat ( Proxy @SS         )
localSizeX = sizeFromNat ( Proxy @LocalSizeX )
localSizeY = sizeFromNat ( Proxy @LocalSizeY )
| -- Step algorithm sizes: | |
| -- Underlying lattice has size ( SS * Width ) * ( SS * Height ), | |
| -- but we only handle half at a time (checkerboard pattern). | |
| -- globalSizeX * localSizeX = SS * Width / 2 | |
| -- globalSizeY * localSizeY = SS * Height | |
| -- Resolve algorithm size: | |
| -- globalSizeX * localSizeX = Width | |
| -- globalSizeY * localSizeY = Height | |
infixl 7 `WithDivisor`

-- | @WithDivisor n d@ reduces to @n@ when @d@ divides @n@, and is a
-- compile-time type error otherwise. Because the operator is
-- left-associative, several divisors can be checked in a row.
type family n `WithDivisor` d where
  n `WithDivisor` d = CheckDivides ( n `Mod` d == 0 ) n d

-- | Implementation detail of 'WithDivisor': given the outcome of the
-- divisibility test, either returns @n@ or reports the failed requirement.
type family CheckDivides ok n d where
  CheckDivides 'True  n _ = n
  CheckDivides 'False n d =
    TypeError ( ShowType d :<>: Text " does not divide " :<>: ShowType n )
------------------------------------------------
-- Ising model simulation.

-- | The two colours of the checkerboard decomposition of the lattice.
data Parity
  = Even
  | Odd

-- | Singleton for 'Parity': a value-level witness of a type-level parity,
-- so that shader code can branch on it.
data SParity ( parity :: Parity ) where
  SEven :: SParity 'Even
  SOdd  :: SParity 'Odd
-- | Simulation parameters, as stored in the uniform buffer @ubo@ of the step shader.
--
-- NOTE(review): the field order determines the buffer layout and presumably
-- has to match the host-side struct; do not reorder.
type IsingParameters =
  Struct
    '[ "temperature"   ':-> Float
     , "interaction"   ':-> Float
     , "magneticField" ':-> Float
     , "time"          ':-> Float
     ]
-- | Interface of the step shader for the given checkerboard parity.
--
-- Resources, as (descriptor set, binding):
--
--   * @ubo@ (uniform 'IsingParameters'): (0, 0).
--   * @evenSpins@ (read-only): (0, 0) for the even step, (1, 0) for the odd step.
--   * @oddSpins@ (read-only): (0, 1).
--   * @outSpins@ (write-only): (1, 0) for the even step, (1, 1) for the odd step.
--   * @localSpins@: workgroup array with one 'Float' per invocation.
--
-- The odd step thus reads the even spins from the image the even step wrote to,
-- which assumes the even checkerboard step is computed first.
--
-- NOTE(review): in the even step, @ubo@ and @evenSpins@ are both declared at
-- descriptor set 0, binding 0. This looks like a binding collision; confirm
-- against the host-side descriptor set layout.
type StepDefs ( parity :: Parity ) =
  '[ "ubo"        ':-> Uniform '[ DescriptorSet 0, Binding 0 ] IsingParameters
   , "evenSpins"  ':-> Image2D
                         '[ DescriptorSet ( If ( parity == 'Even ) 0 1 )
                          , Binding 0
                          , NonWritable
                          ]
                         ( R32 F )
   , "oddSpins"   ':-> Image2D '[ DescriptorSet 0, Binding 1, NonWritable ] ( R32 F )
   , "outSpins"   ':-> Image2D
                         '[ DescriptorSet 1
                          , Binding ( If ( parity == 'Even ) 0 1 )
                          , NonReadable
                          ]
                         ( R32 F )
   , "localSpins" ':-> Workgroup '[] ( Array ( LocalSizeX * LocalSizeY ) Float )
   , "main"       ':-> EntryPoint '[ LocalSize LocalSizeX LocalSizeY 1 ] Compute
   ]
-- | Update all spins of lattice sites with the given checkerboard parity.
--
-- Each invocation handles one lattice site of the given parity, addressed by
-- its global invocation index in the half-width (checkerboard) layout:
--
--   1. it reads the spins stored at that index in both parity images;
--   2. it publishes the opposite-parity spin to the workgroup array
--      @localSpins@, then waits until the whole workgroup has done the same;
--   3. it gathers its four neighbouring spins ('lookupNeighbourSpins');
--   4. it applies a Metropolis–Hastings update ('updateSpin') and writes the
--      result to @outSpins@.
stepShader :: forall parity. _ => SParity parity -> Module ( StepDefs parity )
stepShader sParity = Module $ entryPoint @"main" @Compute do
  isingParameters <- get @"ubo"
  localIndex@( ~( Vec2 i_local_x i_local_y ) )
    <- use @( Name "gl_LocalInvocationID" :.: Swizzle "xy" )
  globalIndex
    <- use @( Name "gl_GlobalInvocationID" :.: Swizzle "xy" )
  -- Read the spins at the current index, and write the spin of the opposite
  -- checkerboard colouring into local storage.
  let
    localArrayIndex :: Code Word32
    localArrayIndex = i_local_x + localSizeX * i_local_y
  evenSpin <- imageRead @"evenSpins" globalIndex
  oddSpin  <- imageRead @"oddSpins"  globalIndex
  let
    currentSpin, otherSpin :: Code Float
    ( currentSpin, otherSpin ) = case sParity of
      SEven -> ( evenSpin, oddSpin  )
      SOdd  -> ( oddSpin , evenSpin )
  assign @( Name "localSpins" :.: AnIndex Word32 ) localArrayIndex otherSpin
  -- Other invocations of the workgroup read from "localSpins" below, so every
  -- invocation must have completed its write before any of them reads.
  -- A memory barrier alone only orders this invocation's own memory accesses
  -- and does not wait for the others. An execution (control) barrier is
  -- needed, with acquire-release semantics on workgroup memory (the
  -- equivalent of GLSL's barrier()).
  controlBarrier Workgroup
    ( Just ( MemorySemantics ( Lock [ Acquire, Release ] ) [ WorkgroupMemory ] ) )
  -- Simulation step: update spins of the given parity by reading spins of the
  -- opposite parity. Spins are read from local storage, except at the boundary
  -- of the workgroup.
  neighbourSpins <- lookupNeighbourSpins sParity otherSpin localIndex globalIndex
  newCurrentSpin <- updateSpin isingParameters currentSpin neighbourSpins
  imageWrite @"outSpins" globalIndex newCurrentSpin
-- | Metropolis–Hastings update of a single spin.
--
-- Takes the simulation parameters, the current spin and its four neighbours
-- (@Vec4 up left right down@), and returns the new spin value:
--
--   * the spin is flipped (negated) when this does not increase the energy;
--   * otherwise it is flipped only if a random number in @[0,1)@ is at most
--     @exp ( - delta / temperature )@.
--
-- NOTE(review): the random number only depends on the "time" parameter, so
-- every lattice site updated in the same dispatch gets the same draw.
-- Independent acceptance tests would presumably need the lattice position
-- mixed into the entropy; confirm the intended behaviour.
updateSpin :: Code IsingParameters -> Code Float -> Code ( V 4 Float ) -> Program _s _s ( Code Float )
updateSpin params spin ( Vec4 spinU spinL spinR spinD ) = locally do
  temperature   <- def @"temperature"   @R $ view @( Name "temperature"   ) params
  interaction   <- def @"interaction"   @R $ view @( Name "interaction"   ) params
  magneticField <- def @"magneticField" @R $ view @( Name "magneticField" ) params
  time          <- def @"time"          @R $ view @( Name "time"          ) params
  -- Change in the Hamiltonian resulting from flipping the current spin.
  delta <- def @"delta" @R $ 2 * spin * ( magneticField - interaction * ( spinU + spinL + spinR + spinD ) )
  let
    flipped :: Code Float
    flipped = negate spin
  -- Flipping does not increase the total energy: always flip.
  if delta <= 0
  then pure flipped
  else do
    -- Flipping increases the total energy: flip only within the Monte–Carlo
    -- acceptance threshold, i.e. when the temperature is high enough relative
    -- to the energy increase. The time parameter only serves as an entropy source.
    draw <- randomProbability time
    if draw <= exp ( - delta / temperature )
    then pure flipped
    else pure spin
-- | Gather the four spins neighbouring the current lattice site, returned as
-- @Vec4 up left right down@.
--
-- Spins are stored in a half-width checkerboard layout: row @y@ of each parity
-- image holds every other site of the full lattice row. Neighbours inside the
-- current workgroup are read from the "localSpins" workgroup array; the others
-- are read from the opposite-parity image. Lattice edges wrap around.
--
-- NOTE(review): this assumes every invocation of the workgroup has written its
-- entry of "localSpins", and has been synchronised, before this is called.
lookupNeighbourSpins
  :: SParity parity
  -> Code Float          -- ^ opposite-parity spin at the current (compressed) index
  -> Code ( V 2 Word32 ) -- ^ local invocation index
  -> Code ( V 2 Word32 ) -- ^ global invocation index (compressed lattice coordinates)
  -> Program _s _s ( Code ( V 4 Float ) )
lookupNeighbourSpins sParity otherSpin ( Vec2 i_local_x i_local_y ) ( Vec2 i_global_x i_global_y ) = do
  -- Lattice site above: decrement y-index by 1.
  spinU <-
    if i_local_y > 0
    then use @( Name "localSpins" :.: AnIndex Word32 ) ( i_local_x + localSizeX * ( i_local_y - 1 ) )
    else do
      let
        globalIndex :: Code ( V 2 Word32 )
        globalIndex =
          if i_global_y > 0
          then Vec2 i_global_x ( i_global_y - 1 )
          else Vec2 i_global_x lastRow
      useGlobalSpin globalIndex
  -- Side lattice sites.
  -- A site of the given parity lies at full-lattice column 2 * i_global_x when
  -- the parity of its ROW equals its own parity, and at 2 * i_global_x + 1
  -- otherwise. Hence:
  --   - if the parity of the row equals the parity of the current lattice site:
  --       * decrement x-index by 1 for the left spin,
  --       * the right spin is at the current index (which we already know);
  --   - otherwise:
  --       * the left spin is at the current index (which we already know),
  --       * increment x-index by 1 for the right spin.
  -- The test must therefore be on the row index i_global_y, not on i_global_x.
  ~( Vec2 spinL spinR ) <-
    if ( case sParity of { SEven -> i_global_y `mod` 2 == 0 ; SOdd -> i_global_y `mod` 2 == 1 } :: Code Bool )
    then do
      spinL <-
        if i_local_x > 0
        then use @( Name "localSpins" :.: AnIndex Word32 ) ( i_local_x - 1 + localSizeX * i_local_y )
        else do
          let
            globalIndex :: Code ( V 2 Word32 )
            globalIndex =
              if i_global_x > 0
              then Vec2 ( i_global_x - 1 ) i_global_y
              else Vec2 lastColumn i_global_y
          useGlobalSpin globalIndex
      pure $ Vec2 spinL otherSpin
    else do
      spinR <-
        if i_local_x < localSizeX - 1
        then use @( Name "localSpins" :.: AnIndex Word32 ) ( i_local_x + 1 + localSizeX * i_local_y )
        else do
          let
            globalIndex :: Code ( V 2 Word32 )
            globalIndex =
              if i_global_x < lastColumn
              then Vec2 ( i_global_x + 1 ) i_global_y
              else Vec2 0 i_global_y
          useGlobalSpin globalIndex
      pure $ Vec2 otherSpin spinR
  -- Lattice site below: increment y-index by 1.
  spinD <-
    if i_local_y < localSizeY - 1
    then use @( Name "localSpins" :.: AnIndex Word32 ) ( i_local_x + localSizeX * ( i_local_y + 1 ) )
    else do
      let
        globalIndex :: Code ( V 2 Word32 )
        globalIndex =
          if i_global_y < lastRow
          then Vec2 i_global_x ( i_global_y + 1 )
          else Vec2 i_global_x 0
      useGlobalSpin globalIndex
  pure ( Vec4 spinU spinL spinR spinD )
  where
    -- Last column and last row of the half-width checkerboard layout.
    lastColumn, lastRow :: Code Word32
    lastColumn = Lit . fromIntegral $ ( ( width * ss ) `div` 2 - 1 :: Int32 )
    lastRow    = Lit . fromIntegral $ ( height * ss - 1 :: Int32 )
    -- Read a neighbour (which has the opposite parity) from its image.
    useGlobalSpin :: Code ( V 2 Word32 ) -> Program _s _s ( Code Float )
    useGlobalSpin = case sParity of
      SEven -> imageRead @"oddSpins"
      SOdd  -> imageRead @"evenSpins"
-- | Pseudo-random value in @[0,1)@, computed deterministically from @entropy@.
--
-- Uses the common shader hash @fract ( 43758.5453 * sin x )@. Scaling the
-- output of @sin@ (which lies in @[-1,1]@) by a large constant before taking
-- the fractional part spreads the result over @[0,1)@. Taking the fractional
-- part of @sin ( 43758.5453 * x )@ instead would only see values in @[-1,1]@,
-- whose fractional parts cluster near 0 and 1 and bias acceptance tests.
--
-- Equal inputs give equal outputs: callers must vary @entropy@ to obtain
-- different values.
randomProbability :: Code Float -> Program _s _s ( Code Float )
randomProbability entropy = locally do
  s <- def @"s" @R $ 43758.5453123 * sin entropy
  pure ( s - floor s )
------------------------------------------------
-- Supersampling.

-- | Interface of the resolve shader: both parity images are read from
-- descriptor set 1 (bindings 0 and 1), and colours are written to the output
-- image at binding 2 of the same set.
type ResolveDefs =
  '[ "evenSpins"   ':-> Image2D '[ DescriptorSet 1, Binding 0, NonWritable ] ( R32 F )
   , "oddSpins"    ':-> Image2D '[ DescriptorSet 1, Binding 1, NonWritable ] ( R32 F )
   , "outputImage" ':-> Image2D '[ DescriptorSet 1, Binding 2, NonReadable ] ( RGBA8 UNorm )
   , "main"        ':-> EntryPoint '[ LocalSize LocalSizeX LocalSizeY 1 ] Compute
   ]
-- | Resolve the lattice into the output image.
--
-- Each invocation computes one pixel. It sums the spins of the @SS * SS@
-- lattice sites covered by that pixel and divides by @SS * SS@. The result is
-- mapped to a colour with 'gradient' (which expects values in @[-1,1]@) and
-- written to @outputImage@.
resolveShader :: Module ResolveDefs
resolveShader = Module $ entryPoint @"main" @Compute do
  pixel@( ~( Vec2 px py ) )
    <- use @( Name "gl_GlobalInvocationID" :.: Swizzle "xy" )
  -- Accumulate the spins of all lattice sites covered by this pixel (supersampling).
  _ <- def @"totalSpin" @RW @Float 0
  supersamplingLoop \ dx dy -> locally do
    -- The sample sits at full-lattice coordinates ( dx + SS * px, dy + SS * py ).
    -- Each parity image stores every other column, hence the division by 2.
    latticeIndex <-
      def @"checkerboardIndex" @R @( V 2 Word32 ) $
        Vec2
          ( ( dx + ss * px ) `div` 2 )
          ( dy + ss * py )
    siteSpin <-
      if ( dx + dy + ss * ( px + py ) ) `mod` 2 == 0 -- is the lattice site even?
      then imageRead @"evenSpins" latticeIndex
      else imageRead @"oddSpins"  latticeIndex
    modify @"totalSpin" ( + siteSpin )
    pure ()
  totalSpin <- get @"totalSpin"
  -- Colour mapping of the mean spin, written to the output image.
  let
    meanSpin :: Code Float
    meanSpin = totalSpin / ( ss * ss )
  imageWrite @"outputImage" pixel ( gradient meanSpin ( Lit sunsetColours ) )
-- | Run the given program once for every supersampling offset @( x, y )@ with
-- @0 <= x < SS@ and @0 <= y < SS@: the outer loop runs over @x@, the inner
-- loop over @y@.
supersamplingLoop
  :: ( Code Word32 -> Code Word32 -> Program _s _s () )
  -> Program _s _s ()
supersamplingLoop body = locally do
  -- Loop counters, declared inside 'locally'.
  _ <- def @"ssX" @RW @Word32 0
  _ <- def @"ssY" @RW @Word32 0
  while ( ( < ss ) <<$>> get @"ssX" ) do
    x <- get @"ssX"
    put @"ssY" 0
    while ( ( < ss ) <<$>> get @"ssY" ) do
      y <- get @"ssY"
      embed ( body x y )
      put @"ssY" ( y + 1 )
    put @"ssX" ( x + 1 )
  pure ()
------------------------------------------------
-- Colour mapping.

-- | Gradient for input values between -1 and 1.
--
-- The @n@ colour stops of @colors@ are spread evenly over @[-1,1]@: @-1@ maps
-- to the first stop, @1@ to the last, and values in between are linearly
-- interpolated between the two surrounding stops. Inputs outside @[-1,1]@ are
-- clamped to that range.
--
-- Requires @n >= 2@.
gradient :: forall n. KnownNat n
         => Code Float
         -> Code (Array n (V 4 Float))
         -> Code (V 4 Float)
gradient t colors
  = ( (1-s) *^ ( view @(AnIndex _) i colors ) )
  ^+^ ( s *^ ( view @(AnIndex _) (i+1) colors ) )
  where
    -- Number of colour stops.
    n :: Code Float
    n = Lit . fromIntegral $ knownValue @n
    -- Input rescaled from [-1,1] to [0,1], then clamped (a negative value would
    -- otherwise be converted to an unsigned index below).
    t' :: Code Float
    t' = clamp01 ( 0.5 * (t+1) )
    clamp01 :: Code Float -> Code Float
    clamp01 u = if u < 0 then 0 else if u > 1 then 1 else u
    -- Position along the gradient, in [0, n-1].
    x :: Code Float
    x = (n-1) * t'
    -- Index of the lower of the two stops to interpolate between.
    -- It is capped at n-2 so that i+1 stays in bounds when x = n-1, i.e. t = 1.
    -- Without the cap, t = 1 would read the non-existent stop number n.
    i :: Code Word32
    i = floor ( if x > n-2 then n-2 else x )
    -- Interpolation weight of stop i+1, in [0,1].
    s :: Code Float
    s = x - fromIntegral i
-- | Colour stops used by 'gradient' to colour the resolved lattice, listed
-- from the first stop (input -1) to the last (input 1).
--
-- NOTE(review): the first stop has alpha 0 while every other stop has alpha 1;
-- confirm this is intended.
sunsetColours :: Array 9 (V 4 Float)
sunsetColours = MkArray ( fromJust ( Vector.fromList stops ) )
  where
    -- Exactly 9 entries, matching the array size, so 'Vector.fromList' succeeds.
    stops :: [ V 4 Float ]
    stops =
      [ V4 0    0    0    0
      , V4 0.28 0.1  0.38 1
      , V4 0.58 0.2  0.38 1
      , V4 0.83 0.3  0.22 1
      , V4 0.98 0.45 0.05 1
      , V4 0.99 0.62 0.2  1
      , V4 1    0.78 0.31 1
      , V4 1    0.91 0.6  1
      , V4 1    1    1    1
      ]
------------------------------------------------
-- Compiling.

-- | Output path of the compiled even-parity step shader.
evenStepPath :: FilePath
evenStepPath = shaderDir </> "ising_even_comp.spv"

-- | Output path of the compiled odd-parity step shader.
oddStepPath :: FilePath
oddStepPath = shaderDir </> "ising_odd_comp.spv"

-- | Output path of the compiled resolve shader.
resolvePath :: FilePath
resolvePath = shaderDir </> "ising_resolve_comp.spv"

-- | Compile the even-parity step shader to 'evenStepPath' (empty option list).
compileEvenStepShader :: IO ( Either ShortText ModuleRequirements )
compileEvenStepShader = compileTo evenStepPath [] ( stepShader SEven )

-- | Compile the odd-parity step shader to 'oddStepPath' (empty option list).
compileOddStepShader :: IO ( Either ShortText ModuleRequirements )
compileOddStepShader = compileTo oddStepPath [] ( stepShader SOdd )

-- | Compile the resolve shader to 'resolvePath' (empty option list).
compileResolveShader :: IO ( Either ShortText ModuleRequirements )
compileResolveShader = compileTo resolvePath [] resolveShader
-- | Compile all shaders of this example.
--
-- Each failure is printed on standard output together with the output path of
-- the shader concerned, instead of being silently discarded. A failure does
-- not prevent the remaining shaders from being compiled.
compileAllShaders :: IO ()
compileAllShaders = sequence_
  [ reportFailure evenStepPath compileEvenStepShader
  , reportFailure oddStepPath  compileOddStepShader
  , reportFailure resolvePath  compileResolveShader
  ]
  where
    -- Run one compilation action, printing its error message if it fails.
    reportFailure :: FilePath -> IO ( Either ShortText ModuleRequirements ) -> IO ()
    reportFailure path action =
      action Prelude.>>= Prelude.either
        ( \ err -> Prelude.putStrLn
            ( "Failed to compile shader " Prelude.++ path Prelude.++ ": " Prelude.++ Prelude.show err )
        )
        ( \ _ -> Prelude.pure () )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment