Skip to content

Instantly share code, notes, and snippets.

@lnalborczyk
Last active January 2, 2023 13:05
Show Gist options
  • Select an option

  • Save lnalborczyk/0c099d960ba23480bb437d4acb2c93ec to your computer and use it in GitHub Desktop.

Select an option

Save lnalborczyk/0c099d960ba23480bb437d4acb2c93ec to your computer and use it in GitHub Desktop.
Illustrating Nelder & Mead's (1965) optimisation method using gganimate
############################################################################
# estimation via the base R optim function
# code adapted from Farrell & Lewandowsky (2018, ch.3)
# https://github.com/psy-farrell/computational-modelling
#################################################################
library(hrbrthemes)
library(tidyverse)
library(gganimate)
library(magick)
# settings of the data-generating process
nDataPts <- 20
rho <- .8
intercept <- 0
# simulate the data set: column 2 is the predictor, column 1 the response
# built from the predictor plus independent normal noise
data <- local({
  predictor <- rnorm(nDataPts)
  response <- rnorm(nDataPts) * sqrt(1.0 - rho^2) + predictor * rho + intercept
  matrix(c(response, predictor), nrow = nDataPts, ncol = 2)
})
# starting values handed to the optimiser (slope b1, intercept b0)
startParms <- c(b1 = -1., b0 = .2)
# empty data.frame that will receive one row per evaluation of the
# objective function, plus the index of the next row to fill
res <- data.frame(b1 = numeric(0), b0 = numeric(0), rmsd = numeric(0))
iter <- 1
# Objective function for optim(): root-mean-square deviation (RMSD) between
# the predictions of a straight line and the observed responses.
#
# parms: named numeric vector with elements "b0" (intercept) and "b1" (slope).
# data1: two-column matrix; column 1 holds the observed response and
#        column 2 the predictor.
#
# Side effect: writes the evaluated parameters and their RMSD into row `iter`
# of the global data.frame `res`, then increments the global counter `iter`,
# so that every evaluation can be plotted afterwards.
#
# Returns the RMSD as a single number.
rmsd <- function(parms, data1) {
  # retrieve fitted values from the data passed in (not from the global `data`)
  preds <- parms["b0"] + parms["b1"] * data1[, 2]
  # compute RMSD
  deviation <- sqrt(sum((preds - data1[, 1])^2) / length(preds))
  # saving output of each iteration; `<<-` is deliberate because optim() only
  # receives the returned value
  res[iter, ] <<- c(parms, deviation)
  iter <<- iter + 1
  # return rmsd
  deviation
}
# estimate the parameters with base R's optim(); Nelder & Mead (1965) is the
# default method and is spelled out here for clarity
optim(
  par = startParms,
  fn = rmsd,
  data1 = data,
  method = "Nelder-Mead",
  control = list(trace = 0)
)
# plotting it: the simulated data with the regression lines stored in `res`
# (one per evaluation of the objective function), animated along `iteration`
plot1 <-
  res %>%
  rowid_to_column(var = "iteration") %>%
  ggplot() +
  geom_point(
    # the model predicts column 1 of `data` from column 2, so the predictor
    # (column 2) goes on the x-axis and the response (column 1) on the y-axis;
    # this keeps the points consistent with the lines drawn by geom_abline()
    data = data.frame(x = data[, 2], y = data[, 1]),
    aes(x = x, y = y),
    shape = 21, size = 3, fill = "grey60"
  ) +
  geom_abline(
    aes(intercept = b0, slope = b1)
  ) +
  theme_ipsum_rc(base_size = 12) +
  transition_reveal(along = iteration) +
  labs(
    title = "Optimisation via Nelder & Mead (1965)'s method",
    subtitle = "Animated version of Figure 3.3 from Farrell & Lewandowsky (2018, p. 56)",
    x = "x", y = "y"
  )
# save animation as a .gif file
anim_save(filename = "optimisation1.gif", animation = plot1)
# RMSD as a function of two coordinates, used to draw the cost surface.
# Note that this replaces the earlier two-argument-vector version of `rmsd`.
#
# x: intercept of the candidate line.
# y: slope of the candidate line.
#
# Reads the global matrix `data` (column 1 = observed values,
# column 2 = predictor) and returns the RMSD of the line's predictions.
rmsd <- function(x, y) {
  # deviation of each fitted value from the corresponding observation
  errors <- (x + y * data[, 2]) - data[, 1]
  sqrt(sum(errors^2) / length(errors))
}
# evaluate the RMSD at every (x, y) pair of a regular grid; rmsd() handles one
# pair at a time, so it is applied pair by pair
rmsd_grid <-
  crossing(x = seq(-2, 2, 0.1), y = seq(-2, 2, 0.1)) %>%
  as.data.frame() %>%
  rowid_to_column(var = "id") %>%
  mutate(z = map2_dbl(x, y, rmsd))
# the evaluated parameter values drawn on top of the RMSD surface
# (raster plus contour lines), animated along `iteration`
plot2 <-
  ggplot(data = rowid_to_column(res, var = "iteration")) +
  # cost surface computed on the grid
  geom_raster(
    data = rmsd_grid,
    mapping = aes(x = x, y = y, z = z, colour = 1 - z, fill = 1 - z),
    show.legend = FALSE
  ) +
  geom_contour(
    data = rmsd_grid,
    mapping = aes(x = x, y = y, z = z, colour = 1 - z, fill = 1 - z),
    show.legend = FALSE
  ) +
  # one point per stored evaluation: intercept on x, slope on y
  geom_point(mapping = aes(x = b0, y = b1), size = 3) +
  # no padding around the grid, so the surface fills the panel
  scale_x_continuous(expand = c(0, 0)) +
  scale_y_continuous(expand = c(0, 0)) +
  theme_ipsum_rc(base_size = 12) +
  transition_reveal(along = iteration) +
  labs(
    title = "Contour plot of the cost function (RMSD)",
    subtitle = "Example from Farrell & Lewandowsky (2018)",
    x = "intercept", y = "slope"
  )
# write the animation to disk as a .gif file
anim_save(filename = "optimisation2.gif", animation = plot2)
############################################################################
# combine plots together
# https://github.com/thomasp85/gganimate/wiki/Animation-Composition
######################################################################
plot1_mgif <- image_read("optimisation1.gif")
plot2_mgif <- image_read("optimisation2.gif")
# pair frames up to the length of the shorter animation instead of assuming
# that both files contain exactly 100 frames
n_frames <- min(length(plot1_mgif), length(plot2_mgif))
# put each pair of frames side by side, then join them into one animation
# (avoids growing the image object inside a loop)
new_gif <- image_join(
  lapply(
    seq_len(n_frames),
    function(i) image_append(c(plot1_mgif[i], plot2_mgif[i]))
  )
)
# save animation as a .gif file
image_write(new_gif, "optimisation.gif")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment