diff --git a/Project.toml b/Project.toml index 25df1be..8d0023d 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ActionModels" uuid = "320cf53b-cc3b-4b34-9a10-0ecb113566a3" authors = ["Peter Thestrup Waade ptw@cas.au.dk", "Anna Hedvig Møller hedvig.2808@gmail.com", "Jacopo Comoglio jacopo.comoglio@gmail.com", "Christoph Mathys chmathys@cas.au.dk"] -version = "0.6.2" +version = "0.6.3" [deps] DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" diff --git a/src/ActionModels.jl b/src/ActionModels.jl index 137481c..2531577 100644 --- a/src/ActionModels.jl +++ b/src/ActionModels.jl @@ -10,7 +10,7 @@ using ProgressMeter, Distributed #TODO: get rid of this (only needed for paramet #Export functions export Agent, RejectParameters, InitialStateParameter, ParameterGroup export init_agent, premade_agent, warn_premade_defaults, multiple_actions, check_agent -export simple_statistical_model, +export independent_agents_population_model, create_model, fit_model, parameter_recovery, single_recovery export plot_parameter_distribution, plot_predictive_simulation, plot_trajectory, plot_trajectory! @@ -44,13 +44,14 @@ include("create_agent/create_premade_agent.jl") include("create_agent/multiple_actions.jl") include("create_agent/check_agent.jl") #Functions for fitting agents to data -include("fitting/helper_functions.jl") -include("fitting/extract_quantities.jl") include("fitting/create_model.jl") -include("fitting/simple_statistical_model.jl") -include("fitting/single_agent_statistical_model.jl") +include("fitting/agent_model.jl") +include("fitting/population_models/independent_agents_population_model.jl") +include("fitting/population_models/single_agent_population_model.jl") include("fitting/fit_model.jl") include("fitting/parameter_recovery.jl") +include("fitting/helper_functions.jl") +include("fitting/extract_quantities.jl") #include("fitting/prefit_checks.jl") #Plotting functions for agents diff --git a/src/fitting/agent_model.jl b/src/fitting/agent_model.jl new file mode 100644 index 0000000..96745e0 --- /dev/null +++ b/src/fitting/agent_model.jl @@ -0,0 +1,220 @@ +############################################### +### WITH SINGLE ACTION / NO MISSING ACTIONS ### +############################################### +@model function agent_models(agent::Agent, agent_ids::Vector{Symbol}, parameters_per_agent::Vector{D}, inputs_per_agent::Vector{I}, actions_per_agent::Vector{Vector{R}}, actions_flattened::Vector{R}, missing_actions::Nothing) where {D<:Dict, I<:Vector, R<:Real} + + #TODO: Could use a list comprehension here to make it more efficient + #Initialize a vector for storing the action probability distributions + action_distributions = Vector(undef, length(actions_flattened)) + + #Initialize action index + action_idx = 0 + + #Go through each agent + for (agent_parameters, agent_inputs, agent_actions) in zip(parameters_per_agent, inputs_per_agent, actions_per_agent) + + #Set the agent parameters + set_parameters!(agent, agent_parameters) + reset!(agent) + + #Go through each timestep + for (input, action) in zip(agent_inputs, agent_actions) + + #Increment one action index + action_idx += 1 + + #Get the action probability distributions from the action model + @inbounds action_distributions[action_idx] = agent.action_model(agent, input) + + #Store the agent's action in the agent + update_states!(agent, "action", action) + end + end + + #Make sure the action distributions are stored as a concrete type (by constructing a new vector) + action_distributions = [dist for dist in action_distributions] + + #Sample the actions from the probability distributions + actions_flattened ~ arraydist(action_distributions) +end + + +################################################## +### WITH MULTIPLE ACTIONS / NO MISSING ACTIONS ### +################################################## +@model function agent_models(agent::Agent, agent_ids::Vector{Symbol}, parameters_per_agent::Vector{D}, inputs_per_agent::Vector{I}, actions_per_agent::Vector{Matrix{R}}, actions_flattened::Matrix{R}, missing_actions::Nothing) where {D<:Dict, I<:Vector, R<:Real} + + #Initialize a vector for storing the action probability distributions + action_distributions = Matrix(undef, size(actions_flattened)...) + + #Initialize action index + action_idx = 0 + + #Go through each agent + for (agent_parameters, agent_inputs, agent_actions) in zip(parameters_per_agent, inputs_per_agent, actions_per_agent) + + #Set the agent parameters + set_parameters!(agent, agent_parameters) + reset!(agent) + + #Go through each timestep + for (input, action) in zip(agent_inputs, Tuple.(eachrow(agent_actions))) + + #Increment one action index + action_idx += 1 + + #Get the action probability distributions from the action model + @inbounds action_distributions[action_idx, :] = agent.action_model(agent, input) + + #Store the agent's action in the agent + update_states!(agent, "action", action) + end + end + + #Make sure the action distributions are stored as a concrete type (by constructing a new vector) + action_distributions = [dist for dist in action_distributions] + + #Sample the actions from the probability distributions + actions_flattened ~ arraydist(action_distributions) +end + + + + + + + + + +############################################ +### WITH MISSING ACTIONS - SUPERFUNCTION ### +############################################ +@model function agent_models(agent::Agent, agent_ids::Vector{Symbol}, parameters_per_agent::Vector{D}, inputs_per_agent::Vector{I}, actions_per_agent::Vector{A}, actions_flattened::A, missing_actions::MissingActions) where {D<:Dict, I<:Vector, A<:Array} + + #For each agent + for (agent_id, agent_parameters, agent_inputs, agent_actions) in zip(agent_ids, parameters_per_agent, inputs_per_agent, actions_per_agent) + + #Fit it to the data + @submodel prefix = "$agent_id" agent_model(agent, agent_parameters, agent_inputs, agent_actions) + end +end + +################################################# +### WITH SINGLE ACTION / WITH MISSING ACTIONS ### +################################################# +@model function agent_model(agent::Agent, parameters::D, inputs::I, actions::Vector{Union{Missing, R}}) where {D<:Dict, I<:Vector, R<:Real} + + #Set the agent parameters + set_parameters!(agent, parameters) + reset!(agent) + + #Go through each timestep + for (timestep, input) in enumerate(inputs) + + #Get the action probability distributions from the action model + action_distribution = agent.action_model(agent, input) + + #Sample the action from the probability distribution + @inbounds actions[timestep] ~ action_distribution + + #Save the action to the agent in case it needs it in the future + @inbounds update_states!( + agent, + "action", + ad_val(actions[timestep]), + ) + end +end + +#################################################### +### WITH MULTIPLE ACTIONS / WITH MISSING ACTIONS ### +#################################################### +@model function agent_model(agent::Agent, parameters::D, inputs::I, actions::Matrix{Union{Missing, R}}) where {D<:Dict, I<:Vector, R<:Real} + + #Set the agent parameters + set_parameters!(agent, parameters) + reset!(agent) + + #Go through each timestep + for (timestep, input) in enumerate(inputs) + + #Get the action probability distributions from the action model + action_distributions = agent.action_model(agent, input) + + #Go through each action + for (action_idx, single_distribution) in enumerate(action_distributions) + + #Sample the action from the probability distribution + actions[timestep, action_idx] ~ + single_distribution + #TODO: can use @inbounds here when there's a check for whether the right amount of actions are produced + end + + #Add the actions to the agent in case it needs it in the future + update_states!( + agent, + "action", + ad_val.(actions[timestep, :]), + ) + #TODO: can use @inbounds here when there's a check for whether the right amount of actions are produced + end +end + +############################################### +### WITH SINGLE ACTION / NO MISSING ACTIONS ### +############################################### +@model function agent_model(agent::Agent, parameters::D, inputs::I, actions::Vector{R}) where {D<:Dict, I<:Vector, R<:Real} + + #Set the agent parameters + set_parameters!(agent, parameters) + reset!(agent) + + #Initialize a vector for storing the action probability distributions + action_distributions = Vector(undef, length(inputs)) + + #Go through each timestep + for (timestep, (input, action)) in enumerate(zip(inputs, actions)) + + #Get the action probability distributions from the action model + @inbounds action_distributions[timestep] = agent.action_model(agent, input) + + #Store the agent's action in the agent + update_states!(agent, "action", action) + end + + #Make sure the action distributions are stored as a concrete type (by constructing a new vector) + action_distributions = [dist for dist in action_distributions] + + #Sample the actions from the probability distributions + actions ~ arraydist(action_distributions) +end + +################################################## +### WITH MULTIPLE ACTIONS / NO MISSING ACTIONS ### +################################################## + +@model function agent_model(agent::Agent, parameters::D, inputs::I, actions::Matrix{R}) where {D<:Dict, I<:Vector, R<:Real} + + #Set the agent parameters + set_parameters!(agent, parameters) + reset!(agent) + + #Initialize a matrix for storing the action probability distributions + action_distributions = Matrix(undef, size(actions)...) + + #Go through each timestep + for (timestep, (input, action)) in enumerate(zip(inputs, Tuple.(eachrow(actions)))) + + #Get the action probability distributions from the action model + action_distributions[timestep, :] = agent.action_model(agent, input) #TODO: can use @inbounds here when there's a check for whether the right amount of actions are used + + #Store the agent's action in the agent + update_states!(agent, "action", action) + end + + #Make sure the action distributions are stored as a concrete type (by constructing a new matrix) + action_distributions = [dist for dist in action_distributions] + + #Sample the actions from the probability distributions + actions ~ arraydist(action_distributions) +end \ No newline at end of file diff --git a/src/fitting/create_model.jl b/src/fitting/create_model.jl index 2837ef8..162bc32 100644 --- a/src/fitting/create_model.jl +++ b/src/fitting/create_model.jl @@ -3,12 +3,13 @@ ########################################################################################################### function create_model( agent::Agent, - statistical_model::DynamicPPL.Model, + population_model::DynamicPPL.Model, data::DataFrame; input_cols::Union{Vector{T1},T1}, action_cols::Union{Vector{T2},T3}, - grouping_cols::Union{Vector{T3},T3} = Vector{String}(), - track_states::Bool = false, + grouping_cols::Union{Vector{T3},T3}, + check_parameter_rejections::Union{Nothing, CheckRejections} = nothing, + id_separator::String = "__", verbose::Bool = true, ) where {T1<:Union{String,Symbol},T2<:Union{String,Symbol},T3<:Union{String,Symbol}} @@ -16,172 +17,94 @@ function create_model( #Create a copy of the agent to avoid changing the original agent_model = deepcopy(agent) - #If states are to be tracked - if track_states - #Make sure the agent saves the history - set_save_history!(agent_model, true) - else - #Otherwise not - set_save_history!(agent_model, false) - end + #Turn off saving the history of states + set_save_history!(agent_model, false) ## Make sure columns are vectors of symbols ## if !(input_cols isa Vector) input_cols = [input_cols] end + input_cols = Symbol.(input_cols) + if !(action_cols isa Vector) action_cols = [action_cols] end + action_cols = Symbol.(action_cols) + if !(grouping_cols isa Vector) grouping_cols = [grouping_cols] end - input_cols = Symbol.(input_cols) - action_cols = Symbol.(action_cols) grouping_cols = Symbol.(grouping_cols) #Run checks for the model specifications check_model( agent, - statistical_model, + population_model, data; input_cols = input_cols, action_cols = action_cols, grouping_cols = grouping_cols, - track_states = track_states, verbose = verbose, ) - + + ## Extract data ## - #One matrix per agent, for inputs and actions separately - inputs = - [Array(agent_data[:, input_cols]) for agent_data in groupby(data, grouping_cols)] - actions = - [Array(agent_data[:, action_cols]) for agent_data in groupby(data, grouping_cols)] + #If there is only one input column + if length(input_cols) == 1 + #Inputs are a vector of vectors of <:reals + inputs = [Vector(agent_data[!,first(input_cols)]) for agent_data in groupby(data, grouping_cols)] + else + #Otherwise, they are a vector of vectors of tuples + inputs = [Tuple.(eachrow(agent_data[!,input_cols])) for agent_data in groupby(data, grouping_cols)] + end + + #If there is only one action column + if length(action_cols) == 1 + #Actions are a vector of arrays (vectors if there is only one action, matrices if there are multiple) + actions = + [Vector(agent_data[!, first(action_cols)]) for agent_data in groupby(data, grouping_cols)] + else + #Actions are a vector of arrays (vectors if there is only one action, matrices if there are multiple) + actions = + [Array(agent_data[!, action_cols]) for agent_data in groupby(data, grouping_cols)] + end + + #Extract agent id's as combined symbols in a vector + agent_ids = [Symbol(join(string.(Tuple(row)), id_separator)) for row in eachrow(unique(data[!, grouping_cols]))] + + ## Determine whether any actions are missing ## + if actions isa Vector{A} where {R<:Real, A<:Array{Union{Missing, R}}} + #If there are missing actions + missing_actions = MissingActions() + elseif actions isa Vector{A} where {R<:Real, A<:Array{R}} + #If there are no missing actions + missing_actions = nothing + end #Create a full model combining the agent model and the statistical model - return full_model(agent_model, statistical_model, inputs, actions, track_states) + return full_model(agent_model, population_model, inputs, actions, agent_ids, missing_actions = missing_actions, check_parameter_rejections = check_parameter_rejections) end -################################################################### -### FUNCTION FOR DOING FULL AGENT AND STATISTICAL MODEL COMBINE ### -################################################################### +#################################################################### +### FUNCTION FOR DOING FULL AGENT AND STATISTICAL MODEL COMBINED ### +#################################################################### @model function full_model( agent::Agent, - statistical_model::DynamicPPL.Model, - inputs::Array{IA}, - actions::Array{AA}, - track_states::Bool = false, - multiple_inputs::Bool = size(first(inputs), 2) > 1, - multiple_actions::Bool = size(first(actions), 2) > 1, -) where {IAR<:Union{Real,Missing},AAR<:Union{Real,Missing},IA<:Array{IAR},AA<:Array{AAR}} - - #Check whether errors occur - try - - #Generate the agent parameters from the statistical model - @submodel statistical_model_return = statistical_model - - #Extract the agent parameters - agents_parameters = statistical_model_return.agent_parameters - - #If states are tracked - if track_states - #Initialize a vector for storing the states of the agents - agents_states = Vector{Dict}(undef, length(agents_parameters)) - parameters_per_agent = Vector{Dict}(undef, length(agents_parameters)) - else - agents_states = nothing - parameters_per_agent = nothing - end - - ## For each agent ## - for (agent_idx, agent_parameters) in enumerate(agents_parameters) - - #Set the agent parameters - set_parameters!(agent, agent_parameters) - reset!(agent) - - ## Construct input iterator ## - #If there is only one input - if !multiple_inputs - #Iterate over inputs one at a time - input_iterator = enumerate(inputs[agent_idx]) - else - #Iterate over rows of inputs - input_iterator = enumerate(Tuple.(eachrow(inputs[agent_idx]))) - end - - #Go through each timestep - for (timestep, input) in input_iterator - - ## Sample actions ## - - #Get the action probability distributions from the action model - action_distribution = agent.action_model(agent, input) - - #If there is only one action - if !multiple_actions - - #Sample the action from the probability distribution - @inbounds actions[agent_idx][timestep] ~ action_distribution - - #Save the action to the agent in case it needs it in the future - @inbounds update_states!( - agent, - "action", - ad_val.(actions[agent_idx][timestep]), - ) - - #If there are multiple actions - else - #Go through each separate action - for (action_idx, single_distribution) in enumerate(action_distribution) - - #Sample the action from the probability distribution - @inbounds actions[agent_idx][timestep, action_idx] ~ - single_distribution - end - - #Add the actions to the agent in case it needs it in the future - @inbounds update_states!( - agent, - "action", - ad_val.(actions[agent_idx][timestep, :]), - ) - end - end - - #If states are tracked - if track_states - #Save the parameters of the agent - parameters_per_agent[agent_idx] = get_parameters(agent) - #Save the history of tracked states for the agent - agents_states[agent_idx] = get_history(agent) - end - end - - #if states are tracked - if track_states - #Return agents' parameters and tracked states - return GeneratedQuantitites( - parameters_per_agent, - agents_states, - statistical_model_return.statistical_values, - ) - else - #Otherwise, return nothing - return nothing - end - - #If an error occurs - catch error - #If it is of the custom errortype RejectParameters - if error isa RejectParameters - #Make Turing reject the sample - Turing.@addlogprob!(-Inf) - else - #Otherwise, just throw the error - rethrow(error) - end - end -end + population_model::DynamicPPL.Model, + inputs_per_agent::Vector{I}, + actions_per_agent::Vector{A}, + agent_ids::Vector{Symbol}; + missing_actions::Union{Nothing, MissingActions} = MissingActions(), + check_parameter_rejections::Nothing = nothing, + actions_flattened::A = vcat(actions_per_agent...) +) where {I<:Vector, R<:Real, A1 <:Union{R,Union{Missing,R}}, A<:Array{A1}} + + #Generate the agent parameters from the statistical model + @submodel population_values = population_model + + #Generate the agent's behavior + @submodel agent_models(agent, agent_ids, population_values.agent_parameters, inputs_per_agent, actions_per_agent, actions_flattened, missing_actions) + + #Return values fron the population model (agent parameters and oher values) + return population_values +end \ No newline at end of file diff --git a/src/fitting/extract_quantities.jl b/src/fitting/extract_quantities.jl index a9d2816..8a4496b 100644 --- a/src/fitting/extract_quantities.jl +++ b/src/fitting/extract_quantities.jl @@ -5,11 +5,6 @@ function extract_quantities(fitted_model::Chains, model::DynamicPPL.Model) - #Check whether track_states = true - model.args.track_states || error( - "The passed model does not have track_changes = true. This is required for extracting agent states. Recreate the model with track_changes = true and repeat.", - ) - #Extract the generated quantities from the fitted model quantities = generated_quantities(model, fitted_model) @@ -17,17 +12,12 @@ function extract_quantities(fitted_model::Chains, model::DynamicPPL.Model) _quantities = first(quantities) n_agents = length(_quantities.agents_parameters) parameter_keys = keys(first(_quantities.agents_parameters)) - state_keys = keys(first(_quantities.agents_states)) #Create containers for the restructured values agent_parameters = [ Dict(parameter_key => Vector{Real}() for parameter_key in parameter_keys) for _ = 1:n_agents ] - agent_states = [ - Dict{Any,Array}(state_key => Vector() for state_key in state_keys) for - _ = 1:n_agents - ] statistical_values = Vector() #For each sample @@ -35,7 +25,6 @@ function extract_quantities(fitted_model::Chains, model::DynamicPPL.Model) #Unpack the sample sample_agent_parameters = sample.agents_parameters - sample_agent_states = sample.agents_states sample_statistical_values = sample.statistical_values #For each agent @@ -49,32 +38,11 @@ function extract_quantities(fitted_model::Chains, model::DynamicPPL.Model) sample_agent_parameters[agent_idx][parameter_key], ) end - - #For each state - for state_key in state_keys - #save the sampled state value - push!( - agent_states[agent_idx][state_key], - sample_agent_states[agent_idx][state_key], - ) - end end #Store the statistical value for the sample push!(statistical_values, sample_statistical_values) end - #For each agent - for agent_idx = 1:n_agents - #For each state - for state_key in state_keys - #Make the vector of vectors into a matrix - agent_states[agent_idx][state_key] = - transpose(reduce(hcat, agent_states[agent_idx][state_key])) - end - end - - #Give option for returning whole chains, CI, etc - - return (agent_parameters, agent_states, statistical_values) + return (agent_parameters, statistical_values) end diff --git a/src/fitting/fit_model.jl b/src/fitting/fit_model.jl index 2c98638..642cdb7 100644 --- a/src/fitting/fit_model.jl +++ b/src/fitting/fit_model.jl @@ -46,7 +46,7 @@ end #################################################################### # function fit_model( # agent::Agent, -# statistical_model::Union{M,P}, +# population_model::Union{M,P}, # data::DataFrame; # parallelization::Union{Nothing,AbstractMCMC.AbstractMCMCEnsemble} = nothing, # extract_quantities::Bool = true, @@ -54,14 +54,14 @@ end # ) where {M<:DynamicPPL.Model,T<:Union{String,Tuple,Any},D<:Distribution,P<:Dict{T,D}} # #Create a full model combining the agent model and the statistical model -# model = create_model(agent, statistical_model, data) +# model = create_model(agent, population_model, data) # #Fit the model # results = fit_model(model; parallelization = parallelization, sampler_kwargs...) # #Add tracked model # results.tracked_model = -# create_model(agent, statistical_model, data, track_states = true) +# create_model(agent, population_model, data, track_states = true) # #Extract tracked states # results.agent_parameters, results.agent_states, results.statistical_values = diff --git a/src/fitting/helper_functions.jl b/src/fitting/helper_functions.jl index 2c4f405..6a7f730 100644 --- a/src/fitting/helper_functions.jl +++ b/src/fitting/helper_functions.jl @@ -18,9 +18,9 @@ function ad_val(x::Real) end -############################################### -#### FUNCTION FOR CHECKING A CREATED MODEL #### -############################################### +############################################################ +#### FUNCTION FOR RENAMING THE CHAINS OF A FITTED MODEL #### +############################################################ function rename_chains( chains::Chains, model::DynamicPPL.Model, @@ -28,7 +28,7 @@ function rename_chains( grouping_cols::Union{Vector{C},C}, ) where {C<:Union{String,Symbol}} #This will multiple dispatch on the type of statistical model - rename_chains(chains, data, grouping_cols, model.args.statistical_model.args...) + rename_chains(chains, data, grouping_cols, model.args.population_model.args...) end @@ -37,16 +37,18 @@ end ############################################### function check_model( agent::Agent, - statistical_model::DynamicPPL.Model, + population_model::DynamicPPL.Model, data::DataFrame; input_cols::Union{Vector{T1},T1}, action_cols::Union{Vector{T2},T3}, grouping_cols::Union{Vector{T3},T3}, - track_states::Bool, verbose::Bool = true, ) where {T1<:Union{String,Symbol},T2<:Union{String,Symbol},T3<:Union{String,Symbol}} - #Run the check of the statistical model check_statistical_model(statistical_model.args...; verbose = verbose, agent = agent) + #TODO: Make check for whether the agent model outputs the right amount of actions / accepts the right amoiunts of inputs + + #Run the check of the statistical model + check_population_model(population_model.args...; verbose = verbose, agent = agent) #Check that user-specified columns exist in the dataset if any(grouping_cols .∉ Ref(Symbol.(names(data)))) diff --git a/src/fitting/parameter_recovery.jl b/src/fitting/parameter_recovery.jl index 4897c47..740ba8f 100644 --- a/src/fitting/parameter_recovery.jl +++ b/src/fitting/parameter_recovery.jl @@ -30,8 +30,6 @@ function single_recovery( #Fit the model to the simulated data result = fit_model(model; sampler_settings..., progress = false) - string(describe(result.chains)[2].nt.parameters[1]) - #Extract the posterior medians posterior_medians = get_posteriors(result.chains) diff --git a/src/fitting/simple_statistical_model.jl b/src/fitting/population_models/independent_agents_population_model.jl similarity index 95% rename from src/fitting/simple_statistical_model.jl rename to src/fitting/population_models/independent_agents_population_model.jl index 5ff6939..7d24953 100644 --- a/src/fitting/simple_statistical_model.jl +++ b/src/fitting/population_models/independent_agents_population_model.jl @@ -2,7 +2,7 @@ ####################################################################################################### ### SIMPLE STATISTICAL MODEL WHERE AGENTS ARE INDEPENDENT AND THEIR PARAMETERS HAVE THE SAME PRIORS ### ####################################################################################################### -@model function simple_statistical_model( +@model function independent_agents_population_model( prior::Dict{T,D}, n_agents::I, agent_parameters::Vector{Dict{Any,Real}} = [Dict{Any,Real}() for _ = 1:n_agents], @@ -40,8 +40,7 @@ function create_model( input_cols::Union{Vector{T1},T1}, action_cols::Union{Vector{T2},T2}, grouping_cols::Union{Vector{T3},T3} = Vector{String}(), - track_states::Bool = false, - verbose::Bool = true, + kwargs..., ) where { T<:Union{String,Tuple,Any}, D<:Distribution, @@ -54,18 +53,17 @@ function create_model( n_agents = length(groupby(data, grouping_cols)) #Create a statistical model where the agents are independent and sampled from the same prior - statistical_model = simple_statistical_model(prior, n_agents) + population_model = independent_agents_population_model(prior, n_agents) #Create a full model combining the agent model and the statistical model return create_model( agent, - statistical_model, + population_model, data; input_cols = input_cols, action_cols = action_cols, grouping_cols = grouping_cols, - track_states = track_states, - verbose = verbose, + kwargs..., ) end @@ -83,6 +81,7 @@ function rename_chains( agent_parameters::Vector{Dict{Any,Real}}, ) where {T<:Union{String,Tuple,Any},D<:Distribution,C<:Union{String,Symbol},I<:Int} + #Make sure grouping columns are a vector if !(grouping_cols isa Vector{C}) grouping_cols = C[grouping_cols] end @@ -143,7 +142,7 @@ end ################################################################# ####### CHECKS TO BE MADE FOR THE SIMPLE STATISTICAL MODEL ###### ################################################################# -function check_statistical_model( +function check_population_model( #Arguments from statistical model prior::Dict{T,D}, n_agents::I, diff --git a/src/fitting/single_agent_statistical_model.jl b/src/fitting/population_models/single_agent_population_model.jl similarity index 93% rename from src/fitting/single_agent_statistical_model.jl rename to src/fitting/population_models/single_agent_population_model.jl index 922b39b..58cc91f 100644 --- a/src/fitting/single_agent_statistical_model.jl +++ b/src/fitting/population_models/single_agent_population_model.jl @@ -1,7 +1,7 @@ ########################################## ### STATISTICAL MODEL FOR SINGLE AGENT ### ########################################## -@model function single_statistical_model( +@model function single_agent_population_model( prior::Dict{T,D}, ) where {T<:Union{String,Tuple,Any},D<:Distribution} @@ -25,15 +25,13 @@ function create_model( prior::Dict{T,D}, inputs::Array{T1}, actions::Array{T2}; - track_states::Bool = false, - verbose::Bool = true, + kwargs..., ) where { T<:Union{String,Tuple,Any}, D<:Distribution, T1<:Union{Real,Missing}, T2<:Union{Real,Missing}, } - #Create column names input_cols = map(x -> "input$x", 1:size(inputs, 2)) action_cols = map(x -> "action$x", 1:size(actions, 2)) @@ -43,18 +41,17 @@ function create_model( data = DataFrame(hcat(inputs, actions), vcat(input_cols, action_cols)) #Create the single-agent statistical model - statistical_model = single_statistical_model(prior) + population_model = single_agent_population_model(prior) #Create a full model combining the agent model and the statistical model return create_model( agent, - statistical_model, + population_model, data; input_cols = input_cols, action_cols = action_cols, grouping_cols = grouping_cols, - track_states = track_states, - verbose = verbose, + kwargs..., ) end @@ -112,7 +109,7 @@ end ############################################################ ### CHECKS TO MAKE FOR THE SINGLE-AGENT STAISTICAL MODEL ### ############################################################ -function check_statistical_model( +function check_population_model( #Arguments from statistical model prior::Dict{T,D}; #Arguments from the agent diff --git a/src/structs.jl b/src/structs.jl index 2cf49c0..1cf3166 100644 --- a/src/structs.jl +++ b/src/structs.jl @@ -12,7 +12,7 @@ Base.@kwdef mutable struct Agent save_history::Bool = true end -#TYPE FOR RETURNING THINGS FROM STATISTICAL MDOEL ### +#TYPE FOR RETURNING OUTCOMES OF STATISTICAL MDOEL ### struct StatisticalModelReturn agent_parameters::Vector{Dict} statistical_values::Any @@ -21,12 +21,6 @@ end StatisticalModelReturn(agent_parameters::Vector{D}) where {D<:Dict} = StatisticalModelReturn(agent_parameters, nothing) -#TYPE FOR RETURNING GENERATED QUANTITIES -struct GeneratedQuantitites - agents_parameters::Vector{Dict} - agents_states::Vector{Dict} - statistical_values::Union{Some{Any},Nothing} -end ### FOR THE GREATER FITMODEL FUNCTION mutable struct FitModelResults @@ -42,8 +36,11 @@ struct RejectParameters <: Exception errortext::Any end +struct CheckRejections +end - +struct MissingActions +end """ diff --git a/test/testsuite/create_model_tests.jl b/test/testsuite/create_model_tests.jl index 7971911..dfa7b69 100644 --- a/test/testsuite/create_model_tests.jl +++ b/test/testsuite/create_model_tests.jl @@ -65,27 +65,25 @@ using ActionModels, DataFrames renamed_model = rename_chains(fitted_model, model, data, :ID) #Create model with tracking states - model_tracked = create_model( - agent, - prior, - data, - input_cols = :inputs, - action_cols = :actions, - grouping_cols = :ID, - track_states = true, - ) + # model_tracked = create_model( + # agent, + # prior, + # data, + # input_cols = :inputs, + # action_cols = :actions, + # grouping_cols = :ID, + # ) #Extract quantities - agent_parameters, agent_states, statistical_values = - extract_quantities(fitted_model, model_tracked) + # agent_parameters, agent_states, statistical_values = + # extract_quantities(fitted_model, model_tracked) end - @testset "no grouping cols" begin + @testset "custom statistical model" begin end - @testset "multiple grouping cols" begin - + @testset "no grouping cols" begin #Create model model = create_model( agent, @@ -93,30 +91,46 @@ using ActionModels, DataFrames data, input_cols = :inputs, action_cols = :actions, - grouping_cols = [:ID, :category], + grouping_cols = Symbol[] ) #Fit model fitted_model = sample(model, sampler, n_iterations; n_chains = n_chains, sampling_kwargs...) + end - #Rename chains - renamed_model = rename_chains(fitted_model, model, data, [:ID, :category]) + @testset "multiple grouping cols" begin - #Create model with tracking states - model_tracked = create_model( + #Create model + model = create_model( agent, prior, data, input_cols = :inputs, action_cols = :actions, grouping_cols = [:ID, :category], - track_states = true, ) - #Extract quantities - agent_parameters, agent_states, statistical_values = - extract_quantities(fitted_model, model_tracked) + #Fit model + fitted_model = + sample(model, sampler, n_iterations; n_chains = n_chains, sampling_kwargs...) + + #Rename chains + renamed_model = rename_chains(fitted_model, model, data, [:ID, :category]) + + # #Create model with tracking states + # model_tracked = create_model( + # agent, + # prior, + # data, + # input_cols = :inputs, + # action_cols = :actions, + # grouping_cols = [:ID, :category], + # ) + + # #Extract quantities + # agent_parameters, agent_states, statistical_values = + # extract_quantities(fitted_model, model_tracked) end @testset "missing actions" begin @@ -142,24 +156,19 @@ using ActionModels, DataFrames #Rename chains renamed_model = rename_chains(fitted_model, model, data, :ID) - #Create model with tracking states - model_tracked = create_model( - agent, - prior, - data, - input_cols = :inputs, - action_cols = :actions, - grouping_cols = :ID, - track_states = true, - ) - - #Extract quantities - agent_parameters, agent_states, statistical_values = - extract_quantities(fitted_model, model_tracked) - end - - @testset "custom statistical model" begin - + # #Create model with tracking states + # model_tracked = create_model( + # agent, + # prior, + # data, + # input_cols = :inputs, + # action_cols = :actions, + # grouping_cols = :ID, + # ) + + # #Extract quantities + # agent_parameters, agent_states, statistical_values = + # extract_quantities(fitted_model, model_tracked) end @testset "multiple actions" begin @@ -193,6 +202,11 @@ using ActionModels, DataFrames fitted_model = sample(model, sampler, n_iterations; n_chains = n_chains, sampling_kwargs...) + end + + @testset "multiple actions, missing actions" begin + + end @testset "multiple inputs" begin