Created March 28, 2018 13:01
Save ihrke/4d7a239d3db48463ccb371d235c3b380 to your computer and use it in GitHub Desktop.
RRM Stan implementation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Project setup: load the ProjectTemplate project, set the ggplot theme and
# Stan options, and define the MCMC settings used for a full model fit.
library(ProjectTemplate)
load.project()
theme_set(theme_bw())
library(rstan)
options(mc.cores = parallel::detectCores())

# Name of this script without directory and extension; it is used below to
# locate the Stan file (src/<bname>.stan) and to name cached fits and plots.
bname <- tools::file_path_sans_ext(basename(this.file.name()))

### MCMC settings (chains, cores, total iterations, warmup iterations)
n.chains <- 4
n.cores <- 4
n.iter <- 1000
n.warmup <- 500
##================================================================================================
# data preparation
# Trials of subject 1, restricted to the "go" and "stop" trial types.
# (sart_exp1 is presumably loaded by load.project() -- confirm.)
d <- filter(sart_exp1, subj == 1, type %in% c("go", "stop"))

# Inputs for the Stan model. go_nogo is 1 for go trials and 0 for stop trials:
# after droplevels(), "stop" is made the first factor level, so as.integer()
# maps stop -> 1 and go -> 2, and subtracting 1 gives 0/1.
# Assumes d$type is a factor (droplevels()/relevel() need one).
data.stan <- list(
  n = nrow(d),
  RT = d$RT,
  go_nogo = as.integer(relevel(droplevels(d$type), ref = "stop")) - 1
)

## only go for testing (note: this variant does not call droplevels())
# d <- filter(sart_exp1, subj == 1, type %in% c("go"))
# data.stan <- list(
#   n = nrow(d),
#   RT = d$RT,
#   go_nogo = as.integer(relevel(d$type, ref = "stop")) - 1
# )
##================================================================================================
# model fitting
# Fit src/<bname>.stan to data.stan, or reuse a cached fit when one exists.
#
# `quick.test` (TRUE unless it is already defined before this point) selects a
# short single-chain debugging run. Its settings are kept in `fit.settings`
# instead of overwriting n.cores/n.chains/n.iter/n.warmup, and a quick fit is
# NOT cached, so it can never be loaded later in place of a full fit.
# Set quick.test <- FALSE to use the MCMC settings defined at the top of the
# script; that fit is cached with cache.var().
if (!exists("quick.test")) {
  quick.test <- TRUE
}
if (!is.cached.var("mod", base = bname)) {
  if (quick.test) {
    fit.settings <- list(cores = 1, chains = 1, iter = 100, warmup = 10)
  } else {
    fit.settings <- list(cores = n.cores, chains = n.chains,
                         iter = n.iter, warmup = n.warmup)
  }
  mod <- stan(file = sprintf("src/%s.stan", bname), data = data.stan,
              cores = fit.settings$cores, chains = fit.settings$chains,
              iter = fit.settings$iter, warmup = fit.settings$warmup)
  if (!quick.test) {
    cache.var("mod", bname)
  }
} else {
  mod <- load.cache.var("mod", bname)
}
##================================================================================================
## diagnostics
# Write convergence diagnostics for `mod` into one multi-page PDF: the Rhat and
# effective-sample-size overviews, then trace plots for every column of the
# posterior draw matrix, split across roughly `npages` pages.
# The PDF device is closed in `finally`, so an error while plotting does not
# leave it open (the error itself is still raised).
pdf(file = plot.filename("diagnostics.pdf", base = bname), onefile = TRUE)
invisible(tryCatch({
  stan_rhat(mod) %>% print
  stan_ess(mod) %>% print
  npages <- 8
  modm <- as.matrix(mod)
  vnames <- colnames(modm)
  nvar <- dim(modm)[2]
  # seq_len() (unlike 1:nvar) produces no pages when there are no columns.
  for (p in split(seq_len(nvar), ceiling(seq_len(nvar) / (nvar / npages)))) {
    stan_trace(mod, pars = vnames[p]) %>% print
    # bayesplot::mcmc_trace(modm[,p]) %>% print
  }
}, finally = dev.off()))
| ##================================================================================================ | |
| # parameters |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
functions {
  // Density at time t of the finishing time of a single Linear Ballistic
  // Accumulator: threshold b, start point uniform on [0, A], drift rate drawn
  // from normal(v_pdf, s). No correction for negative drift rates is applied.
  real lba_pdf(real t, real b, real A, real v_pdf, real s) {
    real ts = t * s;
    real z_top = (b - A - t * v_pdf) / ts;   // (b - A - t*v) / (t*s)
    real z_bottom = (b - t * v_pdf) / ts;    // (b - t*v) / (t*s)
    real dens_top = exp(normal_lpdf(fabs(z_top) | 0, 1));
    real dens_bottom = exp(normal_lpdf(fabs(z_bottom) | 0, 1));
    return (1/A) * (-(v_pdf * Phi(z_top)) + s * dens_top
                    + v_pdf * Phi(z_bottom) - s * dens_bottom);
  }

  // Cumulative distribution function of the same single-accumulator LBA
  // finishing time: P(finishing time <= t).
  real lba_cdf(real t, real b, real A, real v_cdf, real s) {
    real ts = t * s;
    real gap_top = b - A - t * v_cdf;        // b - A - t*v
    real gap_bottom = b - t * v_cdf;         // b - t*v
    return 1 + gap_top/A * Phi(gap_top/ts) - gap_bottom/A * Phi(gap_bottom/ts)
           + ts/A * exp(normal_lpdf(fabs(gap_top/ts) | 0, 1))
           - ts/A * exp(normal_lpdf(fabs(gap_bottom/ts) | 0, 1));
  }
}
data {
  int<lower=1> n;                   // number of trials
  vector[n] RT;                     // response time per trial (presumably seconds -- confirm); RT < 0 codes a nogo response
  int<lower=0,upper=1> go_nogo[n];  // trial type: 1 = go trial, 0 = nogo trial
}
transformed data {
  // Constants used by the model block; declared here because Stan has no
  // #define (per Bob Carpenter: http://discourse.mc-stan.org/t/is-there-something-like-define-in-stan/869/3).
  real<lower=0> upper_int = 2.0;   // upper limit of the numerical integration of P(nogo), in sec
  int<lower=0> n_int = 2000;       // number of trapezoid steps on [0, upper_int]
  // Drift rate of the nogo accumulator on go trials, fixed instead of estimated.
  // NOTE(review): the parameters block comment says this drift is 0, but it is
  // set to 1.0 here -- confirm which value is intended.
  real<lower=0> d_nogo_go = 1.0;
}
parameters {
  // Drift rates, named d_<accumulator>_<trial type>; e.g. d_go_nogo is the
  // drift of the go accumulator on nogo trials.
  real<lower=0> d_go_go;
  real<lower=0> d_go_nogo;
  real<lower=0> d_nogo_nogo;
  // d_nogo_go (nogo accumulator on go trials) is not estimated because there
  // are almost no trials to estimate it; it is fixed in transformed data.
  // NOTE(review): it is set to 1.0 there, not 0 -- confirm the intended value.
  // The LBA threshold b is fixed at 1 (scaling parameter).
  real<lower=0, upper=1> A;  // upper end of the uniform start-point range
  real<lower=0> s;           // SD of the drift-rate distribution
  real<lower=0> t0;          // non-decision time; NOTE(review): not referenced in the model block's likelihood, so only its prior informs it
  real<lower=0> k;           // Weibull shape (Stan's alpha)
  real<lower=0> lambda;      // Weibull scale (Stan's sigma) -- a scale, not a rate
}
model {
  // Priors. All parameters are declared with lower=0 (A also with upper=1),
  // so these normals act as truncated (half-)normal priors.
  // d_nogo_go is a transformed-data constant and needs no prior; the
  // estimated drift d_nogo_nogo previously had none (improper flat prior).
  d_go_go ~ normal(0,1);
  d_go_nogo ~ normal(0,1);
  d_nogo_nogo ~ normal(0,1);
  s ~ normal(0,.2);
  // NOTE(review): t0 is not used in the likelihood below (RTs enter the LBA
  // and Weibull terms unshifted), so its posterior is just this prior.
  t0 ~ normal(0,.1);
  k ~ normal(0,1);
  A ~ normal(0,.5);
  lambda ~ normal(0,1);
  // Debug output: parameter values at every log-density evaluation.
  print("dgg=",d_go_go, " dgn=", d_go_nogo, " dng=", d_nogo_go, " s=", s, " t0=", t0, " k=", k, " A=", A, " lambda=", lambda);
  {
    real mytarget = 0.0;          // log-likelihood accumulated next to target, for the debug print
    real tmp;                     // log-likelihood of the current go response
    real pnogo_go = -1.0;         // P(nogo response | go trial); < 0 means "not computed yet"
    real pnogo_nogo = -1.0;       // P(nogo response | nogo trial); < 0 means "not computed yet"
    real h = upper_int / n_int;   // trapezoid step width
    for (i in 1:n) {
      if (go_nogo[i] == 1) {  // GO trial
        if (RT[i] >= 0) {
          // Go response at RT[i]: either the go accumulator finishes first
          // (nogo accumulator and Weibull process not yet finished), or the
          // Weibull process fires first (neither accumulator finished).
          tmp = log(
            lba_pdf(RT[i], 1, A, d_go_go, s)
              * (1 - lba_cdf(RT[i], 1, A, d_nogo_go, s))
              * (1 - weibull_cdf(RT[i], k, lambda))
            + exp(weibull_lpdf(RT[i] | k, lambda))
              * (1 - lba_cdf(RT[i], 1, A, d_go_go, s))
              * (1 - lba_cdf(RT[i], 1, A, d_nogo_go, s))
          );
          target += tmp;
          mytarget += tmp;
        } else {
          // Nogo response: P(nogo accumulator finishes first within
          // [0, upper_int]). It does not depend on the trial, so it is
          // computed once and reused.
          if (pnogo_go < 0) {
            real x;  // integration node
            real f;  // integrand at x
            // Trapezoid rule with n_int steps. The integrand is 0 at t = 0
            // (lba_pdf tends to 0 there), so the left end point is dropped
            // instead of evaluating the LBA terms at t = 0, where
            // (b - A - t*v)/(t*s) divides by zero. Each interior node is
            // evaluated once.
            pnogo_go = 0.0;
            for (j in 1:n_int) {
              x = j * h;
              f = lba_pdf(x, 1, A, d_nogo_go, s)
                  * (1 - lba_cdf(x, 1, A, d_go_go, s))
                  * (1 - weibull_cdf(x, k, lambda));
              if (j == n_int) {
                f = f / 2.0;  // right end point has trapezoid weight 1/2
              }
              pnogo_go += f;
            }
            pnogo_go = pnogo_go * h;
          }
          target += log(pnogo_go);
          mytarget += log(pnogo_go);
        }
      } else {  // NOGO trial
        if (RT[i] >= 0) {
          // Go response at RT[i] (same structure as on go trials, with the
          // nogo-trial drift rates).
          tmp = log(
            lba_pdf(RT[i], 1, A, d_go_nogo, s)
              * (1 - lba_cdf(RT[i], 1, A, d_nogo_nogo, s))
              * (1 - weibull_cdf(RT[i], k, lambda))
            + exp(weibull_lpdf(RT[i] | k, lambda))
              * (1 - lba_cdf(RT[i], 1, A, d_go_nogo, s))
              * (1 - lba_cdf(RT[i], 1, A, d_nogo_nogo, s))
          );
          target += tmp;
          mytarget += tmp;
        } else {
          // Nogo response: computed once, same trapezoid scheme as above.
          if (pnogo_nogo < 0) {
            real x;  // integration node
            real f;  // integrand at x
            pnogo_nogo = 0.0;
            for (j in 1:n_int) {
              x = j * h;
              f = lba_pdf(x, 1, A, d_nogo_nogo, s)
                  * (1 - lba_cdf(x, 1, A, d_go_nogo, s))
                  * (1 - weibull_cdf(x, k, lambda));
              if (j == n_int) {
                f = f / 2.0;  // right end point has trapezoid weight 1/2
              }
              pnogo_nogo += f;
            }
            pnogo_nogo = pnogo_nogo * h;
          }
          target += log(pnogo_nogo);
          mytarget += log(pnogo_nogo);
        }
      }
    }
    print("target=",mytarget, " pnogo_go=", pnogo_go, " pnogo_nogo=", pnogo_nogo);
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.