Skip to content

Instantly share code, notes, and snippets.

@miguelmartin75
Created August 7, 2025 20:08
Show Gist options
  • Select an option

  • Save miguelmartin75/6de2a6aed228e456410cadb510f4845f to your computer and use it in GitHub Desktop.

Select an option

Save miguelmartin75/6de2a6aed228e456410cadb510f4845f to your computer and use it in GitHub Desktop.
import std/[strformat, atomics, terminal, options]
export options
type
  ExceptionInfo* = object ## Snapshot of an exception raised for one element.
    name*, msg*, stacktrace*: string ## exception type name, message, trace

  ParResult*[T] = object ## Outcome of mapping a single input element.
    value*: T ## mapped value; left default-initialized if the function raised
    exception*: Option[ExceptionInfo] ## `some` when the function raised

  ParMap*[I, O] = object ## State shared (via a raw pointer) by worker threads.
    idx: seq[Atomic[int]] ## per-thread next index counters
    input: ptr UncheckedArray[I] ## borrowed view of the input elements
    n, nthreads: int ## number of elements / number of worker threads
    output: ptr UncheckedArray[O] ## borrowed view of the output buffer
    stole: seq[int] ## per-thread count of work items taken by stealing
proc stealWork(ctx: ptr ParMap): int =
  ## Claims one not-yet-processed input index from the thread whose counter
  ## is least advanced, i.e. whose strided slice has the most work left.
  ## Returns that index, or -1 once every thread's counter has reached
  ## `ctx.n`.
  ##
  ## The claim itself is an atomic `fetchAdd`, so each index is still handed
  ## out exactly once. If another thread drains the chosen counter between
  ## the scan and the claim, the scan is retried. Counters never decrease,
  ## so every failed claim permanently removes one candidate. The loop
  ## therefore ends after at most `ctx.nthreads + 1` passes.
  while true:
    var
      minI = 0
      # Must start above any real counter value; starting at 0 made the
      # scan pick thread 0 almost unconditionally.
      minV = high(int)
    for i in 0 ..< ctx.nthreads:
      let v = ctx.idx[i].load
      if v < minV:
        minV = v
        minI = i
    if minV >= ctx.n:
      # Every slice is exhausted (also covers nthreads == 0).
      return -1
    let k = ctx.idx[minI].fetchAdd(ctx.nthreads)
    if k < ctx.n:
      return k
proc drawProgress[I, O](ctx: ptr ParMap[I, O], redraw: bool) =
  ## Writes one progress line per worker thread to stderr. When `redraw` is
  ## true, the cursor is first moved up over the previously drawn table so
  ## that the table is overwritten in place.
  # NOTE(review): `stole` is a plain seq[int] that other threads increment
  # while this reads it, so the displayed counts are approximate (racy).
  let
    nthreads = ctx.nthreads
    N = ctx.n div nthreads
  if redraw:
    stderr.cursorUp nthreads
  for j in 0 ..< nthreads:
    let
      ix = ctx.idx[j].load(moRelaxed)
      ii = (ix - j) div nthreads
    stderr.eraseLine()
    stderr.write &"{j}: {100*ii/N:.1f}% - {ii} / {N} - idx={ix}, stole: {ctx.stole[j]}\n"

proc parMapWorker[I, O](
    args: tuple[ctx: ptr ParMap[I, O], id: int, fn: proc(x: I): O {.gcsafe.}]
) {.gcsafe.} =
  ## Thread entry point used by `parMap`.
  ##
  ## Phase 1 drains this thread's own strided slice (indices `id`,
  ## `id + nthreads`, ...) by atomically advancing `ctx.idx[id]`.
  ## Phase 2 keeps claiming leftover indices from other threads via
  ## `stealWork` until it returns -1.
  ## Every claimed index `i` gets `output[i] = fn(input[i])`.
  ## Thread 0 additionally keeps a progress table on stderr up to date.
  let
    (ctx, id, fn) = args
    nthreads = ctx.nthreads
    n = ctx.n
    own = ctx.idx[id].addr
  # Whether thread 0 has drawn a progress table yet. Its first claimed index
  # is not necessarily 0 (another thread may have stolen it), so `i > 0` is
  # not a reliable signal. Moving the cursor up before anything was drawn
  # would overwrite unrelated terminal output.
  var drawn = false

  while true:
    let i = own[].fetchAdd(nthreads)
    if i >= n:
      break
    ctx.output[i] = fn(ctx.input[i])
    # Refresh on every third item of thread 0's own slice.
    if id == 0 and (i div nthreads) mod 3 == 0:
      drawProgress(ctx, drawn)
      drawn = true

  while true:
    let i = ctx.stealWork()
    if i < 0:
      break
    ctx.output[i] = fn(ctx.input[i])
    ctx.stole[id] += 1

  if id == 0:
    drawProgress(ctx, drawn)
proc parMap*[I, O](
    data: openArray[I], nthreads: int, fn: proc(x: I): O {.gcsafe.}
): seq[ParResult[O]] =
  ## Applies `fn` to every element of `data` using worker threads. Returns
  ## one `ParResult` per element, in input order.
  ##
  ## `nthreads` is clamped to `1 .. data.len`, so values <= 0 run on a
  ## single thread instead of dividing by zero.
  ##
  ## Anything `fn` raises is caught per element. It is recorded in
  ## `ParResult.exception`, with `value` left default-initialized.
  ## Per-thread progress and a final "Done" line are written to stderr.
  ##
  ## `data` and the result buffer are shared with the workers through raw
  ## pointers. This is only valid because every thread is joined before
  ## the proc returns.
  if data.len == 0:
    return

  proc fnResult(x: I): ParResult[O] =
    try:
      result.value = fn(x)
    except Exception as e:
      # Deliberately broad: an exception escaping here would kill the
      # worker thread. Every failure (including Defects, when catchable)
      # is therefore turned into a per-element result.
      result.exception = some(ExceptionInfo(
        msg: e.msg, stacktrace: e.getStackTrace(), name: $e.name))

  result.setLen(data.len)
  # At least one worker (0 would do no work and then divide by zero), and
  # never more workers than elements.
  let nthreads = clamp(nthreads, 1, data.len)
  var
    threads: seq[Thread[tuple[
      ctx: ptr ParMap[I, ParResult[O]],
      id: int,
      fn: proc(x: I): ParResult[O] {.gcsafe.},
    ]]]
    ctx = ParMap[I, ParResult[O]](
      input: cast[ptr UncheckedArray[I]](data[0].addr),
      n: data.len,
      nthreads: nthreads,
      output: cast[ptr UncheckedArray[ParResult[O]]](result[0].addr),
    )
  ctx.idx.setLen(nthreads)
  ctx.stole.setLen(nthreads)
  for i, counter in ctx.idx.mpairs:
    counter.store(i) # thread i starts at index i and strides by nthreads
  threads.setLen(nthreads)
  for i in 0 ..< nthreads:
    createThread(threads[i], parMapWorker, (ctx.addr, i, fnResult))
  joinThreads(threads)

  # Final per-thread summary once all workers have finished.
  let N = ctx.n div ctx.nthreads
  for j in 0 ..< ctx.nthreads:
    let
      ix = ctx.idx[j].load(moRelaxed)
      ii = (ix - j - nthreads) div nthreads
    stderr.eraseLine()
    stderr.write &"{j}: {100*ii/N:.1f}% - {ii} / {N} - idx={ix}, stole: {ctx.stole[j]}\n"
  stderr.write "Done\n"
import std/[sugar, sequtils], unittest, ../src/parmap
test "double with parMap":
  # Every thread count from 1 to 10 (including more threads than inputs)
  # must reproduce the serial result with no captured exceptions.
  let xs = (1 .. 7).toSeq
  # Derive the expectation from `xs` so it can never fall out of sync.
  let expected = collect:
    for x in xs:
      x * 2
  for nthreads in 1 .. 10:
    checkpoint "nthreads = " & $nthreads
    let ys = parMap(xs, nthreads, proc(x: int): int = x * 2)
    check ys.len == xs.len
    for i, y in ys:
      check:
        y.exception.isNone
        y.value == expected[i]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment