Skip to content

Instantly share code, notes, and snippets.

@Magicalbat
Created August 28, 2024 21:22
Show Gist options
  • Select an option

  • Save Magicalbat/55c787f70aeb59e5382e8753ab48649c to your computer and use it in GitHub Desktop.

Select an option

Save Magicalbat/55c787f70aeb59e5382e8753ab48649c to your computer and use it in GitHub Desktop.
Reverse Mode Automatic Differentiation Example
/*
To run, copy main.c into a directory, and copy
mg_arena.h from https://github.com/Magicalbat/mg-libraries
into the same directory.
This is more for me to look back at. I would
not use this as reference material as a lot of
the code is pretty sloppy.
*/
#include <stdio.h>
#include <stdint.h>
#include <math.h>
#define MG_ARENA_IMPL
#include "mg_arena.h"
/* Short fixed-width aliases used throughout this file. */
typedef uint32_t u32;   /* unsigned 32-bit integer */
typedef float    f32;   /* single-precision float */
typedef double   f64;   /* double-precision float */
/*
 * Operation that produced an ad_var node.
 * AD_OPER_NONE marks an input (leaf) variable with no operands.
 */
typedef enum {
    AD_OPER_NONE = 0,

    /* Binary: result = children[0] (op) children[1]. */
    AD_OPER_ADD = 1,
    AD_OPER_SUB = 2,
    AD_OPER_MUL = 3,
    AD_OPER_DIV = 4,
    AD_OPER_POW = 5,

    /* Unary: result = op(children[0]). */
    AD_OPER_EXP = 6,
    AD_OPER_LOG = 7,
    AD_OPER_SIN = 8,
    AD_OPER_COS = 9
} ad_oper;
/* Capacity of ad_var.children; the operations in this file use at most two. */
#define MAX_CHILDREN 8

/*
 * One node of the computation graph.
 * Nodes store raw pointers to their operands, so operands must stay alive
 * (and must not move) for as long as the node is used.
 */
typedef struct ad_var {
    f32 num;            /* value computed in the forward pass */
    f32 deriv;          /* derivative of the differentiated root w.r.t. this node (set by ad_var_deriv) */
    ad_oper oper;       /* operation that produced this node */
    u32 num_children;   /* number of valid entries in children */
    struct ad_var *children[MAX_CHILDREN];
} ad_var;
/* A counted array of node pointers, e.g. the ordering returned by ad_var_topo_sort. */
typedef struct {
    u32 size;         /* number of entries in nodes */
    ad_var **nodes;   /* pointer to size node pointers */
} ad_var_arr;
/* Creates an input (leaf) variable holding num. */
ad_var ad_var_create(f32 num);

/*
 * Binary operations. Each returns a new node that stores pointers to both
 * operands, so a and b must outlive the returned node.
 */
ad_var ad_var_add(ad_var *a, ad_var *b);
ad_var ad_var_sub(ad_var *a, ad_var *b);
ad_var ad_var_mul(ad_var *a, ad_var *b);
ad_var ad_var_div(ad_var *a, ad_var *b);
ad_var ad_var_pow(ad_var *a, ad_var *b);

/* Unary operations. The returned node stores a pointer to x. */
ad_var ad_var_exp(ad_var *x);
ad_var ad_var_log(ad_var *x);
ad_var ad_var_sin(ad_var *x);
ad_var ad_var_cos(ad_var *x);

/* Graph utilities. */
u32 ad_var_num_nodes(ad_var *root);
ad_var_arr ad_var_topo_sort(mg_arena *arena, ad_var *root);
void ad_var_deriv(ad_var *root);
/*
 * Error hook installed through mga_desc.error_callback in main.
 * Prints the arena error code and message to stderr.
 */
void mga_on_error(mga_error err)
{
    fprintf(stderr,
            "MGA Error %d: %s\n",
            err.code,
            err.msg);
}
/*
 * Demo: differentiates f(x1, x2) = ln(x1) + x1*x2 - sin(x2) at (2, 5).
 * Analytically df/dx1 = 1/x1 + x2 = 5.5 and df/dx2 = x1 - cos(x2) ~= 1.7163.
 * Prints each graph node (root first) as "value derivative".
 */
int main(void) {
    mga_desc desc = {
        .desired_max_size = MGA_MiB(64),
        .desired_block_size = MGA_KiB(256),
        .error_callback = mga_on_error
    };

    mg_arena* perm_arena = mga_create(&desc);
    if (perm_arena == NULL) {
        // Presumably mga_create returns NULL on failure; never use it unchecked.
        fprintf(stderr, "Failed to create arena\n");
        return 1;
    }
    // The AD routines use scratch arenas internally; give them the same settings.
    mga_scratch_set_desc(&desc);

    ad_var x1 = ad_var_create(2);
    ad_var x2 = ad_var_create(5);

    ad_var a = ad_var_log(&x1);
    ad_var b = ad_var_mul(&x1, &x2);
    ad_var c = ad_var_sin(&x2);
    ad_var d = ad_var_add(&a, &b);
    ad_var e = ad_var_sub(&d, &c);

    ad_var_deriv(&e);

    ad_var_arr arr = ad_var_topo_sort(perm_arena, &e);
    for (u32 i = 0; i < arr.size; i++) {
        printf("%f %f\n", arr.nodes[i]->num, arr.nodes[i]->deriv);
    }

    mga_destroy(perm_arena);
    return 0;
}
/* Creates a leaf variable: value num, zero derivative, no operands (AD_OPER_NONE). */
ad_var ad_var_create(f32 num) {
    ad_var leaf = { 0 };
    leaf.num = num;
    return leaf;
}
/* Returns the node a + b, keeping pointers to both operands. */
ad_var ad_var_add(ad_var* a, ad_var* b) {
    return (ad_var){
        .num = a->num + b->num,
        .oper = AD_OPER_ADD,
        .num_children = 2,
        .children = { a, b },
    };
}
/* Returns the node a - b, keeping pointers to both operands. */
ad_var ad_var_sub(ad_var* a, ad_var* b) {
    return (ad_var){
        .num = a->num - b->num,
        .oper = AD_OPER_SUB,
        .num_children = 2,
        .children = { a, b },
    };
}
/* Returns the node a * b, keeping pointers to both operands. */
ad_var ad_var_mul(ad_var* a, ad_var* b) {
    return (ad_var){
        .num = a->num * b->num,
        .oper = AD_OPER_MUL,
        .num_children = 2,
        .children = { a, b },
    };
}
/* Returns the node a / b, keeping pointers to both operands. */
ad_var ad_var_div(ad_var* a, ad_var* b) {
    return (ad_var){
        .num = a->num / b->num,
        .oper = AD_OPER_DIV,
        .num_children = 2,
        .children = { a, b },
    };
}
/* Returns the node a ^ b (powf), keeping pointers to base a and exponent b. */
ad_var ad_var_pow(ad_var* a, ad_var* b) {
    return (ad_var){
        .num = powf(a->num, b->num),
        .oper = AD_OPER_POW,
        .num_children = 2,
        .children = { a, b },
    };
}
/* Returns the node exp(x), keeping a pointer to x. */
ad_var ad_var_exp(ad_var* x) {
    return (ad_var){
        .num = expf(x->num),
        .oper = AD_OPER_EXP,
        .num_children = 1,
        .children = { x },
    };
}
/* Returns the node ln(x) (natural log via logf), keeping a pointer to x. */
ad_var ad_var_log(ad_var* x) {
    return (ad_var){
        .num = logf(x->num),
        .oper = AD_OPER_LOG,
        .num_children = 1,
        .children = { x },
    };
}
/* Returns the node sin(x), keeping a pointer to x. */
ad_var ad_var_sin(ad_var* x) {
    return (ad_var){
        .num = sinf(x->num),
        .oper = AD_OPER_SIN,
        .num_children = 1,
        .children = { x },
    };
}
/* Returns the node cos(x), keeping a pointer to x. */
ad_var ad_var_cos(ad_var* x) {
    return (ad_var){
        .num = cosf(x->num),
        .oper = AD_OPER_COS,
        .num_children = 1,
        .children = { x },
    };
}
/*
 * Adds to *count the number of child edges below node, counting an edge once
 * for every path that reaches it. A node shared by several parents is
 * therefore counted repeatedly. 1 + the result is an upper bound on the
 * number of distinct nodes, but it can grow exponentially on heavily shared
 * graphs and wrap the u32.
 */
void _ad_var_count_util_0(ad_var* node, u32* count) {
    u32 fan_out = node->num_children;
    *count += fan_out;
    for (u32 c = 0; c < fan_out; ++c)
        _ad_var_count_util_0(node->children[c], count);
}
/*
 * Depth-first walk that appends each distinct node reachable from node to
 * visited, in first-visit order. *unique_nodes is the number of entries
 * recorded so far. visited must have room for every distinct node; this
 * function does not check capacity.
 */
void _ad_var_count_util_1(ad_var* node, ad_var** visited, u32* unique_nodes) {
    u32 seen = 0;
    while (seen < *unique_nodes && visited[seen] != node)
        ++seen;
    if (seen != *unique_nodes)
        return; // already recorded

    visited[(*unique_nodes)++] = node;

    for (u32 c = 0; c < node->num_children; ++c)
        _ad_var_count_util_1(node->children[c], visited, unique_nodes);
}
u32 ad_var_num_nodes(ad_var* root) {
u32 max_nodes = 1;
_ad_var_count_util_0(root, &max_nodes);
mga_temp scratch = mga_scratch_get(NULL, 0);
u32 unique_nodes = 0;
ad_var** visited = MGA_PUSH_ZERO_ARRAY(scratch.arena, ad_var*, max_nodes);
_ad_var_count_util_1(root, visited, &unique_nodes);
mga_scratch_release(scratch);
return unique_nodes;
}
/*
 * Depth-first post-order walk from node.
 * Each node is recorded in visited the first time it is reached. It is
 * pushed onto stack only after all of its children have been pushed, so
 * children always precede their parents in stack. Both arrays must have
 * room for every distinct reachable node; capacity is not checked here.
 */
void _ad_var_topo_util(ad_var* node, ad_var** stack, u32* stack_size, ad_var** visited, u32* num_visited) {
    u32 seen = 0;
    while (seen < *num_visited && visited[seen] != node)
        ++seen;
    if (seen != *num_visited)
        return; // already handled via another parent

    visited[*num_visited] = node;
    *num_visited += 1;

    for (u32 c = 0; c < node->num_children; ++c)
        _ad_var_topo_util(node->children[c], stack, stack_size, visited, num_visited);

    stack[*stack_size] = node;
    *stack_size += 1;
}
/*
 * Returns every distinct node reachable from root in topological order:
 * nodes[0] is root, and each node comes before all of the nodes it was
 * computed from. This is the order reverse-mode differentiation needs.
 *
 * The returned array is allocated on arena. Temporary bookkeeping uses a
 * scratch arena that is distinct from arena and is released before
 * returning.
 */
ad_var_arr ad_var_topo_sort(mg_arena* arena, ad_var* root) {
    u32 num_nodes = ad_var_num_nodes(root);

    // Output storage: lives in the caller's arena and is returned.
    u32 stack_size = 0;
    ad_var** stack = MGA_PUSH_ZERO_ARRAY(arena, ad_var*, num_nodes);

    // Working storage is only needed during the DFS, so it belongs in scratch.
    // Passing arena as a conflict keeps the scratch arena separate from the
    // output arena, so releasing the scratch can never discard stack.
    mga_temp scratch = mga_scratch_get(&arena, 1);
    u32 num_visited = 0;
    ad_var** visited = MGA_PUSH_ZERO_ARRAY(scratch.arena, ad_var*, num_nodes);
    _ad_var_topo_util(root, stack, &stack_size, visited, &num_visited);
    mga_scratch_release(scratch);

    // Renaming for clarity
    ad_var** out = stack;

    // The DFS emits post-order (children before parents); reverse it so the
    // root comes first and every node precedes its operands.
    for (u32 i = 0; i < stack_size / 2; i++) {
        ad_var* tmp = out[i];
        out[i] = out[stack_size - i - 1];
        out[stack_size - i - 1] = tmp;
    }

    return (ad_var_arr){
        .size = stack_size,
        .nodes = out
    };
}
/*
 * Reverse-mode differentiation. Sets deriv of y and of every node reachable
 * from it to d(y)/d(node).
 *
 * All derivatives in the graph are reset to zero first. Calling this again,
 * on the same graph or one that shares nodes, therefore recomputes the
 * derivatives instead of adding stale values left over from an earlier call.
 */
void ad_var_deriv(ad_var* y) {
    mga_temp scratch = mga_scratch_get(NULL, 0);

    // nodes[0] is y, and every node precedes its operands. A node's
    // derivative is therefore complete before it is pushed to its children.
    ad_var_arr node_arr = ad_var_topo_sort(scratch.arena, y);

    for (u32 i = 0; i < node_arr.size; i++)
        node_arr.nodes[i]->deriv = 0.0f;
    y->deriv = 1.0f;

    for (u32 i = 0; i < node_arr.size; i++) {
        ad_var* node = node_arr.nodes[i];
        // Unused slots are NULL (nodes are zero-initialized); only the ones
        // required by node->oper are dereferenced below.
        ad_var* a = node->children[0];
        ad_var* b = node->children[1];

        switch (node->oper) {
            case AD_OPER_NONE: break;
            case AD_OPER_ADD: {
                a->deriv += node->deriv;
                b->deriv += node->deriv;
            } break;
            case AD_OPER_SUB: {
                a->deriv += node->deriv;
                b->deriv -= node->deriv;
            } break;
            case AD_OPER_MUL: {
                a->deriv += b->num * node->deriv;
                b->deriv += a->num * node->deriv;
            } break;
            case AD_OPER_DIV: {
                a->deriv += node->deriv / b->num;
                b->deriv += node->deriv * (-a->num / (b->num * b->num));
            } break;
            case AD_OPER_POW: {
                // d/da a^b = b * a^(b-1);  d/db a^b = ln(a) * a^b
                a->deriv += node->deriv * b->num * powf(a->num, b->num - 1.0f);
                b->deriv += node->deriv * logf(a->num) * powf(a->num, b->num);
            } break;
            case AD_OPER_EXP: {
                // d/dx exp(x) = exp(x), which is this node's own value.
                a->deriv += node->deriv * node->num;
            } break;
            case AD_OPER_LOG: {
                a->deriv += node->deriv / a->num;
            } break;
            case AD_OPER_SIN: {
                a->deriv += node->deriv * cosf(a->num);
            } break;
            case AD_OPER_COS: {
                a->deriv += node->deriv * -sinf(a->num);
            } break;
            default: break;
        }
    }

    mga_scratch_release(scratch);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment