Skip to content

Instantly share code, notes, and snippets.

@ihrke
Created March 28, 2018 13:01
Show Gist options
  • Select an option

  • Save ihrke/4d7a239d3db48463ccb371d235c3b380 to your computer and use it in GitHub Desktop.

Select an option

Save ihrke/4d7a239d3db48463ccb371d235c3b380 to your computer and use it in GitHub Desktop.
RRM Stan implementation
# Setup: project environment, plotting theme and Stan interface.
library(ProjectTemplate)
load.project()
theme_set(theme_bw())
library(rstan)
# Allow rstan to run chains in parallel on all detected cores.
options(mc.cores = parallel::detectCores())
# Base name of this script (no extension); it names the Stan file
# (src/<bname>.stan), the cache entry and the diagnostics plot.
bname <- tools::file_path_sans_ext(basename(this.file.name()))

### MCMC parameters iterations
n.iter <- 1000
n.warmup <- 500
n.chains <- 4
n.cores <- 4
##================================================================================================
# data preparation
# Subject 1's "go" and "stop" trials only.
d <- sart_exp1 %>% filter(subj == 1, type %in% c("go", "stop"))
# Data list for the Stan model. go_nogo must be 1 for go trials and 0 for
# nogo ("stop") trials, stored as integers for the Stan `int` array.
# Mapping through explicit factor levels (instead of droplevels() + relevel())
# keeps this coding correct even if one trial type is absent for the subject
# (relevel() errors when "stop" is not a level) and accepts `type` as either
# a factor or a character vector.
data.stan <- list(
  n = nrow(d),
  RT = d$RT,
  go_nogo = as.integer(factor(as.character(d$type), levels = c("stop", "go"))) - 1L
)
## only go for testing
#d <- sart_exp1 %>% filter(subj == 1, type %in% c("go"))
#data.stan <- list(
#  n = nrow(d),
#  RT = d$RT,
#  go_nogo = as.integer(factor(as.character(d$type), levels = c("stop", "go"))) - 1L
#)
##================================================================================================
# model fitting
# quick.test = TRUE runs a short single-chain smoke test (1 core, 1 chain,
# 100 iterations, 10 warmup) -- the settings that used to be hard-coded here
# and silently overwrote n.cores/n.chains/n.iter/n.warmup. Smoke-test fits are
# never cached. Set quick.test to FALSE for a full run with the MCMC settings
# defined at the top of the script; that fit is cached under `bname`.
quick.test <- TRUE
if (!is.cached.var("mod", base = bname)) {
  if (quick.test) {
    fit.cores <- 1
    fit.chains <- 1
    fit.iter <- 100
    fit.warmup <- 10
  } else {
    fit.cores <- n.cores
    fit.chains <- n.chains
    fit.iter <- n.iter
    fit.warmup <- n.warmup
  }
  mod <- stan(
    file = sprintf("src/%s.stan", bname), data = data.stan,
    cores = fit.cores, chains = fit.chains, iter = fit.iter, warmup = fit.warmup
  )
  if (!quick.test) {
    cache.var("mod", bname)
  }
} else {
  mod <- load.cache.var("mod", bname)
}
##================================================================================================
## diagnostics
# Posterior draws as a matrix (one column per sampled quantity); the column
# names select parameters for the trace plots below.
npages <- 8
modm <- as.matrix(mod)
vnames <- colnames(modm)
nvar <- ncol(modm)
pdf(file = plot.filename("diagnostics.pdf", base = bname), onefile = TRUE)
# finally = dev.off() closes the PDF device even if a plot call fails, so a
# failed run does not leave the device open and capture later plots.
tryCatch({
  stan_rhat(mod) %>% print()
  stan_ess(mod) %>% print()
  # Split the columns into about `npages` contiguous groups, one trace-plot
  # page per group; seq_len() yields no pages (not a bogus one) if nvar == 0.
  pages <- split(seq_len(nvar), ceiling(seq_len(nvar) / (nvar / npages)))
  for (p in pages) {
    stan_trace(mod, pars = vnames[p]) %>% print()
    # bayesplot::mcmc_trace(modm[, p]) %>% print()
  }
}, finally = dev.off())
##================================================================================================
# parameters
functions{
  // Density of the finishing time t of one LBA accumulator.
  //   t: time, b: threshold, A: start-point range, v_pdf: mean drift rate,
  //   s: drift-rate standard deviation.
  // The standard-normal density phi(z) is evaluated as exp(normal_lpdf(z | 0, 1));
  // phi is symmetric, so taking |z| first is unnecessary.
  real lba_pdf(real t, real b, real A, real v_pdf, real s) {
    real ts = t * s;
    real z_top = (b - A - t * v_pdf) / ts;   // standardized gap, start point at A
    real z_bottom = (b - t * v_pdf) / ts;    // standardized gap, start point at 0
    return (1/A) * (-(v_pdf * Phi(z_top))
                    + s * exp(normal_lpdf(z_top | 0, 1))
                    + v_pdf * Phi(z_bottom)
                    - s * exp(normal_lpdf(z_bottom | 0, 1)));
  }

  // Cumulative distribution of the finishing time t of one LBA accumulator
  // (same arguments as lba_pdf, with v_cdf as the mean drift rate).
  real lba_cdf(real t, real b, real A, real v_cdf, real s) {
    real ts = t * s;
    real gap_top = b - A - t * v_cdf;      // b - A - t*v
    real gap_bottom = b - t * v_cdf;       // b - t*v
    return 1
      + gap_top/A * Phi(gap_top/ts)
      - gap_bottom/A * Phi(gap_bottom/ts)
      + ts/A * exp(normal_lpdf(gap_top/ts | 0, 1))
      - ts/A * exp(normal_lpdf(gap_bottom/ts | 0, 1));
  }
}
data {
  int<lower=1> n;                    // number of trials
  vector[n] RT;                      // response time per trial; RT < 0 codes a nogo response
  int<lower=0,upper=1> go_nogo[n];   // trial type: 1 = go trial, 0 = nogo trial
}
transformed data {
  // Fixed constants. Transformed data is the place to define them, per Bob Carpenter:
  // http://discourse.mc-stan.org/t/is-there-something-like-define-in-stan/869/3
  int<lower=0> n_int = 2000;        // number of numerical-integration steps
  real<lower=0> upper_int = 2.0;    // upper integration limit, in sec
  real<lower=0> d_nogo_go = 1.0;    // drift of the nogo accumulator on go trials (fixed, not estimated)
}
parameters {
// drift-rates, named d_<accumulator>_<trial type>
// (e.g. d_go_nogo = go accumulator on nogo trials)
real<lower=0> d_go_go;
real<lower=0> d_go_nogo;
real<lower=0> d_nogo_nogo;
// d_nogo_go is not estimated because there are almost no trials to estimate it;
// it is a constant defined in transformed data (currently 1.0, not 0)
// b=1 (scaling parameter): the threshold is passed as the literal 1 to lba_pdf/lba_cdf
real<lower=0, upper=1> A; // starting point: start-point range A of the LBA accumulators
real<lower=0> s; // drift-rate variability
real<lower=0> t0; // non-decision time. NOTE(review): t0 only gets a prior in the model block and is never used in the likelihood -- confirm whether RT should be shifted by t0
real<lower=0> k; // Weibull shape
real<lower=0> lambda; // Weibull scale (Stan's weibull_* take shape and scale, so this is not a rate)
}
model {
  // Set to 0 to silence the debug print() calls below; they run on every
  // log-density evaluation and produce a lot of output.
  int verbose = 1;

  // Priors. d_nogo_go is a fixed constant from transformed data, not a
  // parameter, so the fourth drift-rate prior belongs to d_nogo_nogo, which
  // previously had no prior (an improper flat prior on (0, inf)).
  d_go_go ~ normal(0,1);
  d_go_nogo ~ normal(0,1);
  d_nogo_nogo ~ normal(0,1);
  s ~ normal(0,.2);
  // NOTE(review): t0 appears only in this prior; the likelihood below never
  // uses it, so its posterior equals its prior.
  t0 ~ normal(0,.1);
  k ~ normal(0,1);
  A ~ normal(0,.5);
  lambda ~ normal(0,1);
  if (verbose == 1)
    print("dgg=",d_go_go, " dgn=", d_go_nogo, " dng=", d_nogo_go, " s=", s, " t0=", t0, " k=", k, " A=", A, " lambda=", lambda);

  // Likelihood. On each trial a go and a nogo LBA accumulator (threshold 1)
  // race alongside a Weibull process. A go response at time RT arises either
  // from the go accumulator finishing first, or from the Weibull process
  // firing while neither accumulator has finished. The drift rates depend on
  // the trial type.
  {
    real mytarget = 0.0;   // running log-likelihood, only used for printing
    real pnogo[2];         // P(nogo response): [1] nogo trials, [2] go trials; -1 = not computed yet
    real v_go;             // go-accumulator drift on the current trial
    real v_nogo;           // nogo-accumulator drift on the current trial
    real tmp;              // log-likelihood contribution of the current trial
    int tt;                // index into pnogo for the current trial type
    pnogo[1] = -1.0;
    pnogo[2] = -1.0;
    for (i in 1:n) {
      if (go_nogo[i] == 1) {   // GO trial
        v_go = d_go_go;
        v_nogo = d_nogo_go;
      } else {                 // NOGO trial
        v_go = d_go_nogo;
        v_nogo = d_nogo_nogo;
      }
      tt = go_nogo[i] + 1;
      if (RT[i] >= 0) {        // go-response given
        tmp = log(
          lba_pdf(RT[i], 1, A, v_go, s) *
          (1 - lba_cdf(RT[i], 1, A, v_nogo, s)) *
          (1 - weibull_cdf(RT[i], k, lambda)) +
          exp(weibull_lpdf(RT[i] | k, lambda)) *
          (1 - lba_cdf(RT[i], 1, A, v_go, s)) *
          (1 - lba_cdf(RT[i], 1, A, v_nogo, s))
        );
      } else {                 // nogo-response given
        // P(nogo accumulator finishes first) on [0, upper_int]. It depends only
        // on the parameters and the trial type, so compute it once per type.
        if (pnogo[tt] < 0) {
          real h = upper_int / n_int;   // trapezoid step width
          real x;
          real fx;
          real acc = 0.0;
          // Composite trapezoid rule: every node is evaluated exactly once.
          for (j in 0:n_int) {
            x = j * h;
            fx = lba_pdf(x, 1, A, v_nogo, s) *
                 (1 - lba_cdf(x, 1, A, v_go, s)) *
                 (1 - weibull_cdf(x, k, lambda));
            if (j == 0 || j == n_int)
              acc += 0.5 * fx;   // end points carry half weight
            else
              acc += fx;
          }
          pnogo[tt] = acc * h;
        }
        tmp = log(pnogo[tt]);
      }
      target += tmp;
      mytarget += tmp;
    }
    if (verbose == 1)
      print("target=",mytarget, " pnogo_go=", pnogo[2], " pnogo_nogo=", pnogo[1]);
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment