From 483da0560b44fa985d0a8bd17086ef2eb7d8c5fc Mon Sep 17 00:00:00 2001 From: Peter Thestrup Waade Date: Tue, 1 Oct 2024 10:56:31 +0200 Subject: [PATCH 1/7] removed trascking of agent states within full_model (will make separate function for this) --- src/fitting/create_model.jl | 53 +++++-------------- src/fitting/extract_quantities.jl | 34 +----------- src/fitting/helper_functions.jl | 1 - src/fitting/simple_statistical_model.jl | 2 - src/fitting/single_agent_statistical_model.jl | 2 - src/structs.jl | 1 - test/testsuite/create_model_tests.jl | 3 -- 7 files changed, 13 insertions(+), 83 deletions(-) diff --git a/src/fitting/create_model.jl b/src/fitting/create_model.jl index 2837ef8..5348b26 100644 --- a/src/fitting/create_model.jl +++ b/src/fitting/create_model.jl @@ -8,7 +8,6 @@ function create_model( input_cols::Union{Vector{T1},T1}, action_cols::Union{Vector{T2},T3}, grouping_cols::Union{Vector{T3},T3} = Vector{String}(), - track_states::Bool = false, verbose::Bool = true, ) where {T1<:Union{String,Symbol},T2<:Union{String,Symbol},T3<:Union{String,Symbol}} @@ -16,14 +15,8 @@ function create_model( #Create a copy of the agent to avoid changing the original agent_model = deepcopy(agent) - #If states are to be tracked - if track_states - #Make sure the agent saves the history - set_save_history!(agent_model, true) - else - #Otherwise not - set_save_history!(agent_model, false) - end + #Turn off saving the history of states + set_save_history!(agent_model, false) ## Make sure columns are vectors of symbols ## if !(input_cols isa Vector) @@ -47,7 +40,6 @@ function create_model( input_cols = input_cols, action_cols = action_cols, grouping_cols = grouping_cols, - track_states = track_states, verbose = verbose, ) @@ -59,7 +51,7 @@ function create_model( [Array(agent_data[:, action_cols]) for agent_data in groupby(data, grouping_cols)] #Create a full model combining the agent model and the statistical model - return full_model(agent_model, statistical_model, inputs, actions, track_states) + return full_model(agent_model, statistical_model, inputs, actions) end ################################################################### @@ -70,7 +62,6 @@ end statistical_model::DynamicPPL.Model, inputs::Array{IA}, actions::Array{AA}, - track_states::Bool = false, multiple_inputs::Bool = size(first(inputs), 2) > 1, multiple_actions::Bool = size(first(actions), 2) > 1, ) where {IAR<:Union{Real,Missing},AAR<:Union{Real,Missing},IA<:Array{IAR},AA<:Array{AAR}} @@ -84,15 +75,8 @@ end #Extract the agent parameters agents_parameters = statistical_model_return.agent_parameters - #If states are tracked - if track_states - #Initialize a vector for storing the states of the agents - agents_states = Vector{Dict}(undef, length(agents_parameters)) - parameters_per_agent = Vector{Dict}(undef, length(agents_parameters)) - else - agents_states = nothing - parameters_per_agent = nothing - end + #Initialize a vector for storing the states of the agents + parameters_per_agent = Vector{Dict}(undef, length(agents_parameters)) ## For each agent ## for (agent_idx, agent_parameters) in enumerate(agents_parameters) @@ -150,28 +134,15 @@ end ) end end - - #If states are tracked - if track_states - #Save the parameters of the agent - parameters_per_agent[agent_idx] = get_parameters(agent) - #Save the history of tracked states for the agent - agents_states[agent_idx] = get_history(agent) - end + #Save the parameters of the agent + parameters_per_agent[agent_idx] = get_parameters(agent) end - #if states are tracked - if track_states - #Return agents' parameters and tracked states - return GeneratedQuantitites( - parameters_per_agent, - agents_states, - statistical_model_return.statistical_values, - ) - else - #Otherwise, return nothing - return nothing - end + #Return agents' parameters and tracked states + return GeneratedQuantitites( + parameters_per_agent, + statistical_model_return.statistical_values, + ) #If an error occurs catch error diff --git a/src/fitting/extract_quantities.jl b/src/fitting/extract_quantities.jl index a9d2816..8a4496b 100644 --- a/src/fitting/extract_quantities.jl +++ b/src/fitting/extract_quantities.jl @@ -5,11 +5,6 @@ function extract_quantities(fitted_model::Chains, model::DynamicPPL.Model) - #Check whether track_states = true - model.args.track_states || error( - "The passed model does not have track_changes = true. This is required for extracting agent states. Recreate the model with track_changes = true and repeat.", - ) - #Extract the generated quantities from the fitted model quantities = generated_quantities(model, fitted_model) @@ -17,17 +12,12 @@ function extract_quantities(fitted_model::Chains, model::DynamicPPL.Model) _quantities = first(quantities) n_agents = length(_quantities.agents_parameters) parameter_keys = keys(first(_quantities.agents_parameters)) - state_keys = keys(first(_quantities.agents_states)) #Create containers for the restructured values agent_parameters = [ Dict(parameter_key => Vector{Real}() for parameter_key in parameter_keys) for _ = 1:n_agents ] - agent_states = [ - Dict{Any,Array}(state_key => Vector() for state_key in state_keys) for - _ = 1:n_agents - ] statistical_values = Vector() #For each sample @@ -35,7 +25,6 @@ function extract_quantities(fitted_model::Chains, model::DynamicPPL.Model) #Unpack the sample sample_agent_parameters = sample.agents_parameters - sample_agent_states = sample.agents_states sample_statistical_values = sample.statistical_values #For each agent @@ -49,32 +38,11 @@ function extract_quantities(fitted_model::Chains, model::DynamicPPL.Model) sample_agent_parameters[agent_idx][parameter_key], ) end - - #For each state - for state_key in state_keys - #save the sampled state value - push!( - agent_states[agent_idx][state_key], - sample_agent_states[agent_idx][state_key], - ) - end end #Store the statistical value for the sample push!(statistical_values, sample_statistical_values) end - #For each agent - for agent_idx = 1:n_agents - #For each state - for state_key in state_keys - #Make the vector of vectors into a matrix - agent_states[agent_idx][state_key] = - transpose(reduce(hcat, agent_states[agent_idx][state_key])) - end - end - - #Give option for returning whole chains, CI, etc - - return (agent_parameters, agent_states, statistical_values) + return (agent_parameters, statistical_values) end diff --git a/src/fitting/helper_functions.jl b/src/fitting/helper_functions.jl index 2c4f405..59a39be 100644 --- a/src/fitting/helper_functions.jl +++ b/src/fitting/helper_functions.jl @@ -42,7 +42,6 @@ function check_model( input_cols::Union{Vector{T1},T1}, action_cols::Union{Vector{T2},T3}, grouping_cols::Union{Vector{T3},T3}, - track_states::Bool, verbose::Bool = true, ) where {T1<:Union{String,Symbol},T2<:Union{String,Symbol},T3<:Union{String,Symbol}} diff --git a/src/fitting/simple_statistical_model.jl b/src/fitting/simple_statistical_model.jl index 5ff6939..b42bbe1 100644 --- a/src/fitting/simple_statistical_model.jl +++ b/src/fitting/simple_statistical_model.jl @@ -40,7 +40,6 @@ function create_model( input_cols::Union{Vector{T1},T1}, action_cols::Union{Vector{T2},T2}, grouping_cols::Union{Vector{T3},T3} = Vector{String}(), - track_states::Bool = false, verbose::Bool = true, ) where { T<:Union{String,Tuple,Any}, @@ -64,7 +63,6 @@ function create_model( input_cols = input_cols, action_cols = action_cols, grouping_cols = grouping_cols, - track_states = track_states, verbose = verbose, ) end diff --git a/src/fitting/single_agent_statistical_model.jl b/src/fitting/single_agent_statistical_model.jl index 922b39b..fdc5240 100644 --- a/src/fitting/single_agent_statistical_model.jl +++ b/src/fitting/single_agent_statistical_model.jl @@ -25,7 +25,6 @@ function create_model( prior::Dict{T,D}, inputs::Array{T1}, actions::Array{T2}; - track_states::Bool = false, verbose::Bool = true, ) where { T<:Union{String,Tuple,Any}, @@ -53,7 +52,6 @@ function create_model( input_cols = input_cols, action_cols = action_cols, grouping_cols = grouping_cols, - track_states = track_states, verbose = verbose, ) end diff --git a/src/structs.jl b/src/structs.jl index 2cf49c0..fdf2d53 100644 --- a/src/structs.jl +++ b/src/structs.jl @@ -24,7 +24,6 @@ StatisticalModelReturn(agent_parameters::Vector{D}) where {D<:Dict} = #TYPE FOR RETURNING GENERATED QUANTITIES struct GeneratedQuantitites agents_parameters::Vector{Dict} - agents_states::Vector{Dict} statistical_values::Union{Some{Any},Nothing} end diff --git a/test/testsuite/create_model_tests.jl b/test/testsuite/create_model_tests.jl index 7971911..cffd3fa 100644 --- a/test/testsuite/create_model_tests.jl +++ b/test/testsuite/create_model_tests.jl @@ -72,7 +72,6 @@ using ActionModels, DataFrames input_cols = :inputs, action_cols = :actions, grouping_cols = :ID, - track_states = true, ) #Extract quantities @@ -111,7 +110,6 @@ using ActionModels, DataFrames input_cols = :inputs, action_cols = :actions, grouping_cols = [:ID, :category], - track_states = true, ) #Extract quantities @@ -150,7 +148,6 @@ using ActionModels, DataFrames input_cols = :inputs, action_cols = :actions, grouping_cols = :ID, - track_states = true, ) #Extract quantities From 2c7c3d496a164f21807d1ea56d05f6fd09c6bc5e Mon Sep 17 00:00:00 2001 From: Peter Thestrup Waade Date: Tue, 1 Oct 2024 11:26:47 +0200 Subject: [PATCH 2/7] minor bugfix --- src/fitting/parameter_recovery.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/fitting/parameter_recovery.jl b/src/fitting/parameter_recovery.jl index 4897c47..740ba8f 100644 --- a/src/fitting/parameter_recovery.jl +++ b/src/fitting/parameter_recovery.jl @@ -30,8 +30,6 @@ function single_recovery( #Fit the model to the simulated data result = fit_model(model; sampler_settings..., progress = false) - string(describe(result.chains)[2].nt.parameters[1]) - #Extract the posterior medians posterior_medians = get_posteriors(result.chains) From 9a5f8ff9912ae118820d761e307b5afc22a93d19 Mon Sep 17 00:00:00 2001 From: Peter Thestrup Waade Date: Tue, 1 Oct 2024 17:25:55 +0200 Subject: [PATCH 3/7] fundamental restructuring now complete - needs to bugfix. Also sorted api a bit to create model --- src/ActionModels.jl | 5 +- src/fitting/agent_model.jl | 123 ++++++++++++ src/fitting/create_model.jl | 176 ++++++++---------- src/fitting/helper_functions.jl | 11 +- src/fitting/simple_statistical_model.jl | 5 +- src/fitting/single_agent_statistical_model.jl | 5 +- src/structs.jl | 10 +- 7 files changed, 215 insertions(+), 120 deletions(-) create mode 100644 src/fitting/agent_model.jl diff --git a/src/ActionModels.jl b/src/ActionModels.jl index 137481c..3930133 100644 --- a/src/ActionModels.jl +++ b/src/ActionModels.jl @@ -44,13 +44,14 @@ include("create_agent/create_premade_agent.jl") include("create_agent/multiple_actions.jl") include("create_agent/check_agent.jl") #Functions for fitting agents to data -include("fitting/helper_functions.jl") -include("fitting/extract_quantities.jl") include("fitting/create_model.jl") +include("fitting/agent_model.jl") include("fitting/simple_statistical_model.jl") include("fitting/single_agent_statistical_model.jl") include("fitting/fit_model.jl") include("fitting/parameter_recovery.jl") +include("fitting/helper_functions.jl") +include("fitting/extract_quantities.jl") #include("fitting/prefit_checks.jl") #Plotting functions for agents diff --git a/src/fitting/agent_model.jl b/src/fitting/agent_model.jl new file mode 100644 index 0000000..3c9c143 --- /dev/null +++ b/src/fitting/agent_model.jl @@ -0,0 +1,123 @@ +############################################### +### WITH SINGLE ACTION / NO MISSING ACTIONS ### +############################################### +@model function agent_model(agent::Agent, parameters::D, inputs::I, actions::Vector{R}) where {D<:Dict, I<:Vector, R<:Real} + + #Set the agent parameters + set_parameters!(agent, parameters) + reset!(agent) + + #Initialize a vector for storing the action probability distributions + action_distributions = Vector(undef, length(inputs)) + + #Go through each timestep + for (timestep, (input, action)) in enumerate(zip(inputs, actions)) + + #Get the action probability distributions from the action model + @inbounds action_distributions[timestep] = agent.action_model(agent, input) + + #Store the agent's action in the agent + update_states!(agent, "action", action) + end + + #Make sure the action distributions are stored as a concrete type (by constructing a new vector) + action_distributions = [dist for dist in action_distributions] + + #Sample the actions from the probability distributions + actions ~ arraydist(action_distributions) +end + +################################################## +### WITH MULTIPLE ACTIONS / NO MISSING ACTIONS ### +################################################## + +@model function agent_model(agent::Agent, parameters::D, inputs::I, actions::Matrix{R}) where {D<:Dict, I<:Vector, R<:Real} + + #Set the agent parameters + set_parameters!(agent, parameters) + reset!(agent) + + #Initialize a matrix for storing the action probability distributions + action_distributions = Matrix(undef, size(actions)) + + #Go through each timestep + for (timestep, (input, action)) in enumerate(zip(inputs, Tuple.(eachrow(actions)))) + + #Get the action probability distributions from the action model + action_distributions[timestep, :] = agent.action_model(agent, input) #TODO: can use @inbounds here when there's a check for whether the right amount of actions are used + + #Store the agent's action in the agent + update_states!(agent, "action", action) + end + + #Make sure the action distributions are stored as a concrete type (by constructing a new matrix) + action_distributions = [dist for dist in action_distributions] + + #Sample the actions from the probability distributions + actions ~ arraydist(action_distributions) +end + +################################################# +### WITH SINGLE ACTION / WITH MISSING ACTIONS ### +################################################# + +@model function agent_model(agent::Agent, parameters::D, inputs::I, actions::Vector{Union{Missing, R}}) where {D<:Dict, I<:Vector, R<:Real} + + @show "yep" + #Set the agent parameters + set_parameters!(agent, parameters) + reset!(agent) + + #Go through each timestep + for (timestep, input) in enumerate(inputs) + + #Get the action probability distributions from the action model + action_distribution = agent.action_model(agent, input) + + #Sample the action from the probability distribution + @inbounds actions[timestep] ~ action_distribution + + #Save the action to the agent in case it needs it in the future + @inbounds update_states!( + agent, + "action", + ad_val(actions[timestep]), + ) + end +end + + +#################################################### +### WITH MULTIPLE ACTIONS / WITH MISSING ACTIONS ### +#################################################### +@model function agent_model(agent::Agent, parameters::D, inputs::I, actions::Matrix{Union{Missing, R}}) where {D<:Dict, I<:Vector, R<:Real} + + #Set the agent parameters + set_parameters!(agent, parameters) + reset!(agent) + + #Go through each timestep + for (timestep, input) in enumerate(inputs) + + #Get the action probability distributions from the action model + action_distributions = agent.action_model(agent, input) + + #Go through each action + for (action_idx, single_distribution) in enumerate(action_distributions) + + #Sample the action from the probability distribution + actions[timestep, action_idx] ~ + single_distribution + #TODO: can use @inbounds here when there's a check for whether the right amount of actions are produced + #TODO: Could use arraydist here if this was formatted as a vector of vectors (probably not!) + end + + #Add the actions to the agent in case it needs it in the future + update_states!( + agent, + "action", + ad_val.(actions[timestep, :]), + ) + #TODO: can use @inbounds here when there's a check for whether the right amount of actions are produced + end +end \ No newline at end of file diff --git a/src/fitting/create_model.jl b/src/fitting/create_model.jl index 5348b26..461d825 100644 --- a/src/fitting/create_model.jl +++ b/src/fitting/create_model.jl @@ -7,7 +7,9 @@ function create_model( data::DataFrame; input_cols::Union{Vector{T1},T1}, action_cols::Union{Vector{T2},T3}, - grouping_cols::Union{Vector{T3},T3} = Vector{String}(), + grouping_cols::Union{Vector{T3},T3} = Vector{Symbol}(), + check_parameter_rejections::Union{Nothing, CheckRejections} = nothing, + id_separator::String = "__", verbose::Bool = true, ) where {T1<:Union{String,Symbol},T2<:Union{String,Symbol},T3<:Union{String,Symbol}} @@ -22,14 +24,16 @@ function create_model( if !(input_cols isa Vector) input_cols = [input_cols] end + input_cols = Symbol.(input_cols) + if !(action_cols isa Vector) action_cols = [action_cols] end + action_cols = Symbol.(action_cols) + if !(grouping_cols isa Vector) grouping_cols = [grouping_cols] end - input_cols = Symbol.(input_cols) - action_cols = Symbol.(action_cols) grouping_cols = Symbol.(grouping_cols) #Run checks for the model specifications @@ -45,114 +49,82 @@ function create_model( ## Extract data ## #One matrix per agent, for inputs and actions separately + #FIXME: put inputs into correct format inputs = [Array(agent_data[:, input_cols]) for agent_data in groupby(data, grouping_cols)] actions = [Array(agent_data[:, action_cols]) for agent_data in groupby(data, grouping_cols)] + #Extract agent id's as combined symbols in a vector + agent_ids = [Symbol(join(string.(Tuple(row)), id_separator)) for row in eachrow(unique(data[!, grouping_cols]))] + #Create a full model combining the agent model and the statistical model - return full_model(agent_model, statistical_model, inputs, actions) + return full_model(agent_model, statistical_model, inputs, actions, agent_ids, check_parameter_rejections = check_parameter_rejections) end -################################################################### -### FUNCTION FOR DOING FULL AGENT AND STATISTICAL MODEL COMBINE ### -################################################################### +#################################################################### +### FUNCTION FOR DOING FULL AGENT AND STATISTICAL MODEL COMBINED ### +#################################################################### @model function full_model( agent::Agent, - statistical_model::DynamicPPL.Model, - inputs::Array{IA}, - actions::Array{AA}, - multiple_inputs::Bool = size(first(inputs), 2) > 1, - multiple_actions::Bool = size(first(actions), 2) > 1, -) where {IAR<:Union{Real,Missing},AAR<:Union{Real,Missing},IA<:Array{IAR},AA<:Array{AAR}} - - #Check whether errors occur - try - - #Generate the agent parameters from the statistical model - @submodel statistical_model_return = statistical_model - - #Extract the agent parameters - agents_parameters = statistical_model_return.agent_parameters - - #Initialize a vector for storing the states of the agents - parameters_per_agent = Vector{Dict}(undef, length(agents_parameters)) - - ## For each agent ## - for (agent_idx, agent_parameters) in enumerate(agents_parameters) - - #Set the agent parameters - set_parameters!(agent, agent_parameters) - reset!(agent) - - ## Construct input iterator ## - #If there is only one input - if !multiple_inputs - #Iterate over inputs one at a time - input_iterator = enumerate(inputs[agent_idx]) - else - #Iterate over rows of inputs - input_iterator = enumerate(Tuple.(eachrow(inputs[agent_idx]))) - end - - #Go through each timestep - for (timestep, input) in input_iterator - - ## Sample actions ## - - #Get the action probability distributions from the action model - action_distribution = agent.action_model(agent, input) - - #If there is only one action - if !multiple_actions - - #Sample the action from the probability distribution - @inbounds actions[agent_idx][timestep] ~ action_distribution - - #Save the action to the agent in case it needs it in the future - @inbounds update_states!( - agent, - "action", - ad_val.(actions[agent_idx][timestep]), - ) - - #If there are multiple actions - else - #Go through each separate action - for (action_idx, single_distribution) in enumerate(action_distribution) - - #Sample the action from the probability distribution - @inbounds actions[agent_idx][timestep, action_idx] ~ - single_distribution - end - - #Add the actions to the agent in case it needs it in the future - @inbounds update_states!( - agent, - "action", - ad_val.(actions[agent_idx][timestep, :]), - ) - end - end - #Save the parameters of the agent - parameters_per_agent[agent_idx] = get_parameters(agent) - end - - #Return agents' parameters and tracked states - return GeneratedQuantitites( - parameters_per_agent, - statistical_model_return.statistical_values, - ) - - #If an error occurs - catch error - #If it is of the custom errortype RejectParameters - if error isa RejectParameters - #Make Turing reject the sample - Turing.@addlogprob!(-Inf) - else - #Otherwise, just throw the error - rethrow(error) - end + population_model::DynamicPPL.Model, + inputs_per_agent::Vector{I}, + actions_per_agent::Vector{A}, + agent_ids::Vector{Symbol}; + check_parameter_rejections::Nothing = nothing, +) where {I<:Vector, R<:Real, A1 <:Union{R,Union{Missing,R}}, A<:Array{A1}} + + #Generate the agent parameters from the statistical model + @submodel population_values = population_model + + ## For each agent ## + for (agent_id, agent_parameters, agent_inputs, agent_actions) in zip(agent_ids, population_values.agent_parameters, inputs_per_agent, actions_per_agent) + + @submodel prefix = agent_id agent_model(agent, agent_parameters, agent_inputs, agent_actions) end + + #Return agents' parameters and tracked states + return statistical_model_return end + + +############################################################################# +### WRAPPER FUNCTION FOR FULL_MODEL FOR CHECKING FOR PARAMETER REJECTIONS ### +############################################################################# + +# @model function full_model( +# agent::Agent, +# statistical_model::DynamicPPL.Model, +# inputs_per_agent::Array{IA}, +# actions_per_agent::Array{AA}; +# agent_ids::Vector{Union{Symbol,Vector{Symbol}}}, +# check_parameter_rejections::CheckRejections, +# ) where {IAR<:Union{Real,Missing},AAR<:Union{Real,Missing},IA<:Array{IAR},AA<:Array{AAR}} + +# #Check whether errors occur +# try + +# #Run the full model +# @submodel generated_quantities = full_model( +# agent, +# statistical_model, +# inputs_per_agent, +# actions_per_agent; +# agent_ids = agent_ids, +# check_parameter_rejections = nothing, +# ) + +# return generated_quantities + +# #If an error occurs +# catch error +# #If it is of the custom errortype RejectParameters +# if error isa RejectParameters +# #Make Turing reject the sample +# Turing.@addlogprob!(-Inf) +# else +# #Otherwise, just throw the error +# rethrow(error) +# end +# end +# end \ No newline at end of file diff --git a/src/fitting/helper_functions.jl b/src/fitting/helper_functions.jl index 59a39be..f98ec21 100644 --- a/src/fitting/helper_functions.jl +++ b/src/fitting/helper_functions.jl @@ -18,9 +18,9 @@ function ad_val(x::Real) end -############################################### -#### FUNCTION FOR CHECKING A CREATED MODEL #### -############################################### +############################################################ +#### FUNCTION FOR RENAMING THE CHAINS OF A FITTED MODEL #### +############################################################ function rename_chains( chains::Chains, model::DynamicPPL.Model, @@ -45,7 +45,10 @@ function check_model( verbose::Bool = true, ) where {T1<:Union{String,Symbol},T2<:Union{String,Symbol},T3<:Union{String,Symbol}} - #Run the check of the statistical model check_statistical_model(statistical_model.args...; verbose = verbose, agent = agent) + #TODO: Make check for whether the agent model outputs the right amount of actions / accepts the right amoiunts of inputs + + #Run the check of the statistical model + check_statistical_model(statistical_model.args...; verbose = verbose, agent = agent) #Check that user-specified columns exist in the dataset if any(grouping_cols .∉ Ref(Symbol.(names(data)))) diff --git a/src/fitting/simple_statistical_model.jl b/src/fitting/simple_statistical_model.jl index b42bbe1..aa13a3f 100644 --- a/src/fitting/simple_statistical_model.jl +++ b/src/fitting/simple_statistical_model.jl @@ -40,7 +40,7 @@ function create_model( input_cols::Union{Vector{T1},T1}, action_cols::Union{Vector{T2},T2}, grouping_cols::Union{Vector{T3},T3} = Vector{String}(), - verbose::Bool = true, + kwargs..., ) where { T<:Union{String,Tuple,Any}, D<:Distribution, @@ -63,7 +63,7 @@ function create_model( input_cols = input_cols, action_cols = action_cols, grouping_cols = grouping_cols, - verbose = verbose, + kwargs..., ) end @@ -81,6 +81,7 @@ function rename_chains( agent_parameters::Vector{Dict{Any,Real}}, ) where {T<:Union{String,Tuple,Any},D<:Distribution,C<:Union{String,Symbol},I<:Int} + #Make sure grouping columns are a vector if !(grouping_cols isa Vector{C}) grouping_cols = C[grouping_cols] end diff --git a/src/fitting/single_agent_statistical_model.jl b/src/fitting/single_agent_statistical_model.jl index fdc5240..42ad89a 100644 --- a/src/fitting/single_agent_statistical_model.jl +++ b/src/fitting/single_agent_statistical_model.jl @@ -25,14 +25,13 @@ function create_model( prior::Dict{T,D}, inputs::Array{T1}, actions::Array{T2}; - verbose::Bool = true, + kwargs..., ) where { T<:Union{String,Tuple,Any}, D<:Distribution, T1<:Union{Real,Missing}, T2<:Union{Real,Missing}, } - #Create column names input_cols = map(x -> "input$x", 1:size(inputs, 2)) action_cols = map(x -> "action$x", 1:size(actions, 2)) @@ -52,7 +51,7 @@ function create_model( input_cols = input_cols, action_cols = action_cols, grouping_cols = grouping_cols, - verbose = verbose, + kwargs..., ) end diff --git a/src/structs.jl b/src/structs.jl index fdf2d53..d88250d 100644 --- a/src/structs.jl +++ b/src/structs.jl @@ -12,7 +12,7 @@ Base.@kwdef mutable struct Agent save_history::Bool = true end -#TYPE FOR RETURNING THINGS FROM STATISTICAL MDOEL ### +#TYPE FOR RETURNING OUTCOMES OF STATISTICAL MDOEL ### struct StatisticalModelReturn agent_parameters::Vector{Dict} statistical_values::Any @@ -21,11 +21,6 @@ end StatisticalModelReturn(agent_parameters::Vector{D}) where {D<:Dict} = StatisticalModelReturn(agent_parameters, nothing) -#TYPE FOR RETURNING GENERATED QUANTITIES -struct GeneratedQuantitites - agents_parameters::Vector{Dict} - statistical_values::Union{Some{Any},Nothing} -end ### FOR THE GREATER FITMODEL FUNCTION mutable struct FitModelResults @@ -41,7 +36,8 @@ struct RejectParameters <: Exception errortext::Any end - +struct CheckRejections +end From 1561d480028e51030be16264633c458c033e3dcf Mon Sep 17 00:00:00 2001 From: Peter Thestrup Waade Date: Wed, 2 Oct 2024 17:47:07 +0200 Subject: [PATCH 4/7] statistical_model -> population_model & agent_model now functional --- src/ActionModels.jl | 6 +- src/fitting/agent_model.jl | 2 +- src/fitting/create_model.jl | 42 ++++++---- src/fitting/fit_model.jl | 6 +- src/fitting/helper_functions.jl | 6 +- .../independent_agents_population_model.jl} | 8 +- .../single_agent_population_model.jl} | 8 +- test/testsuite/create_model_tests.jl | 77 ++++++++++--------- 8 files changed, 87 insertions(+), 68 deletions(-) rename src/fitting/{simple_statistical_model.jl => population_models/independent_agents_population_model.jl} (96%) rename src/fitting/{single_agent_statistical_model.jl => population_models/single_agent_population_model.jl} (96%) diff --git a/src/ActionModels.jl b/src/ActionModels.jl index 3930133..2531577 100644 --- a/src/ActionModels.jl +++ b/src/ActionModels.jl @@ -10,7 +10,7 @@ using ProgressMeter, Distributed #TODO: get rid of this (only needed for paramet #Export functions export Agent, RejectParameters, InitialStateParameter, ParameterGroup export init_agent, premade_agent, warn_premade_defaults, multiple_actions, check_agent -export simple_statistical_model, +export independent_agents_population_model, create_model, fit_model, parameter_recovery, single_recovery export plot_parameter_distribution, plot_predictive_simulation, plot_trajectory, plot_trajectory! @@ -46,8 +46,8 @@ include("create_agent/check_agent.jl") #Functions for fitting agents to data include("fitting/create_model.jl") include("fitting/agent_model.jl") -include("fitting/simple_statistical_model.jl") -include("fitting/single_agent_statistical_model.jl") +include("fitting/population_models/independent_agents_population_model.jl") +include("fitting/population_models/single_agent_population_model.jl") include("fitting/fit_model.jl") include("fitting/parameter_recovery.jl") include("fitting/helper_functions.jl") diff --git a/src/fitting/agent_model.jl b/src/fitting/agent_model.jl index 3c9c143..6833771 100644 --- a/src/fitting/agent_model.jl +++ b/src/fitting/agent_model.jl @@ -38,7 +38,7 @@ end reset!(agent) #Initialize a matrix for storing the action probability distributions - action_distributions = Matrix(undef, size(actions)) + action_distributions = Matrix(undef, size(actions)...) #Go through each timestep for (timestep, (input, action)) in enumerate(zip(inputs, Tuple.(eachrow(actions)))) diff --git a/src/fitting/create_model.jl b/src/fitting/create_model.jl index 461d825..ce90e4f 100644 --- a/src/fitting/create_model.jl +++ b/src/fitting/create_model.jl @@ -3,7 +3,7 @@ ########################################################################################################### function create_model( agent::Agent, - statistical_model::DynamicPPL.Model, + population_model::DynamicPPL.Model, data::DataFrame; input_cols::Union{Vector{T1},T1}, action_cols::Union{Vector{T2},T3}, @@ -39,27 +39,41 @@ function create_model( #Run checks for the model specifications check_model( agent, - statistical_model, + population_model, data; input_cols = input_cols, action_cols = action_cols, grouping_cols = grouping_cols, verbose = verbose, ) - + + ## Extract data ## - #One matrix per agent, for inputs and actions separately - #FIXME: put inputs into correct format - inputs = - [Array(agent_data[:, input_cols]) for agent_data in groupby(data, grouping_cols)] - actions = - [Array(agent_data[:, action_cols]) for agent_data in groupby(data, grouping_cols)] + #If there is only one input column + if length(input_cols) == 1 + #Inputs are a vector of vectors of <:reals + inputs = [Vector(agent_data[!,first(input_cols)]) for agent_data in groupby(data, grouping_cols)] + else + #Otherwise, they are a vector of vectors of tuples + inputs = [Tuple.(eachrow(agent_data[!,input_cols])) for agent_data in groupby(data, grouping_cols)] + end + + #If there is only one action column + if length(action_cols) == 1 + #Actions are a vector of arrays (vectors if there is only one action, matrices if there are multiple) + actions = + [Vector(agent_data[!, first(action_cols)]) for agent_data in groupby(data, grouping_cols)] + else + #Actions are a vector of arrays (vectors if there is only one action, matrices if there are multiple) + actions = + [Array(agent_data[!, action_cols]) for agent_data in groupby(data, grouping_cols)] + end #Extract agent id's as combined symbols in a vector agent_ids = [Symbol(join(string.(Tuple(row)), id_separator)) for row in eachrow(unique(data[!, grouping_cols]))] #Create a full model combining the agent model and the statistical model - return full_model(agent_model, statistical_model, inputs, actions, agent_ids, check_parameter_rejections = check_parameter_rejections) + return full_model(agent_model, population_model, inputs, actions, agent_ids, check_parameter_rejections = check_parameter_rejections) end #################################################################### @@ -80,11 +94,11 @@ end ## For each agent ## for (agent_id, agent_parameters, agent_inputs, agent_actions) in zip(agent_ids, population_values.agent_parameters, inputs_per_agent, actions_per_agent) - @submodel prefix = agent_id agent_model(agent, agent_parameters, agent_inputs, agent_actions) + @submodel prefix = "$agent_id" agent_model(agent, agent_parameters, agent_inputs, agent_actions) end #Return agents' parameters and tracked states - return statistical_model_return + return population_values end @@ -94,7 +108,7 @@ end # @model function full_model( # agent::Agent, -# statistical_model::DynamicPPL.Model, +# population_model::DynamicPPL.Model, # inputs_per_agent::Array{IA}, # actions_per_agent::Array{AA}; # agent_ids::Vector{Union{Symbol,Vector{Symbol}}}, @@ -107,7 +121,7 @@ end # #Run the full model # @submodel generated_quantities = full_model( # agent, -# statistical_model, +# population_model, # inputs_per_agent, # actions_per_agent; # agent_ids = agent_ids, diff --git a/src/fitting/fit_model.jl b/src/fitting/fit_model.jl index 2c98638..642cdb7 100644 --- a/src/fitting/fit_model.jl +++ b/src/fitting/fit_model.jl @@ -46,7 +46,7 @@ end #################################################################### # function fit_model( # agent::Agent, -# statistical_model::Union{M,P}, +# population_model::Union{M,P}, # data::DataFrame; # parallelization::Union{Nothing,AbstractMCMC.AbstractMCMCEnsemble} = nothing, # extract_quantities::Bool = true, @@ -54,14 +54,14 @@ end # ) where {M<:DynamicPPL.Model,T<:Union{String,Tuple,Any},D<:Distribution,P<:Dict{T,D}} # #Create a full model combining the agent model and the statistical model -# model = create_model(agent, statistical_model, data) +# model = create_model(agent, population_model, data) # #Fit the model # results = fit_model(model; parallelization = parallelization, sampler_kwargs...) # #Add tracked model # results.tracked_model = -# create_model(agent, statistical_model, data, track_states = true) +# create_model(agent, population_model, data, track_states = true) # #Extract tracked states # results.agent_parameters, results.agent_states, results.statistical_values = diff --git a/src/fitting/helper_functions.jl b/src/fitting/helper_functions.jl index f98ec21..6a7f730 100644 --- a/src/fitting/helper_functions.jl +++ b/src/fitting/helper_functions.jl @@ -28,7 +28,7 @@ function rename_chains( grouping_cols::Union{Vector{C},C}, ) where {C<:Union{String,Symbol}} #This will multiple dispatch on the type of statistical model - rename_chains(chains, data, grouping_cols, model.args.statistical_model.args...) + rename_chains(chains, data, grouping_cols, model.args.population_model.args...) end @@ -37,7 +37,7 @@ end ############################################### function check_model( agent::Agent, - statistical_model::DynamicPPL.Model, + population_model::DynamicPPL.Model, data::DataFrame; input_cols::Union{Vector{T1},T1}, action_cols::Union{Vector{T2},T3}, @@ -48,7 +48,7 @@ function check_model( #TODO: Make check for whether the agent model outputs the right amount of actions / accepts the right amoiunts of inputs #Run the check of the statistical model - check_statistical_model(statistical_model.args...; verbose = verbose, agent = agent) + check_population_model(population_model.args...; verbose = verbose, agent = agent) #Check that user-specified columns exist in the dataset if any(grouping_cols .∉ Ref(Symbol.(names(data)))) diff --git a/src/fitting/simple_statistical_model.jl b/src/fitting/population_models/independent_agents_population_model.jl similarity index 96% rename from src/fitting/simple_statistical_model.jl rename to src/fitting/population_models/independent_agents_population_model.jl index aa13a3f..7d24953 100644 --- a/src/fitting/simple_statistical_model.jl +++ b/src/fitting/population_models/independent_agents_population_model.jl @@ -2,7 +2,7 @@ ####################################################################################################### ### SIMPLE STATISTICAL MODEL WHERE AGENTS ARE INDEPENDENT AND THEIR PARAMETERS HAVE THE SAME PRIORS ### ####################################################################################################### -@model function simple_statistical_model( +@model function independent_agents_population_model( prior::Dict{T,D}, n_agents::I, agent_parameters::Vector{Dict{Any,Real}} = [Dict{Any,Real}() for _ = 1:n_agents], @@ -53,12 +53,12 @@ function create_model( n_agents = length(groupby(data, grouping_cols)) #Create a statistical model where the agents are independent and sampled from the same prior - statistical_model = simple_statistical_model(prior, n_agents) + population_model = independent_agents_population_model(prior, n_agents) #Create a full model combining the agent model and the statistical model return create_model( agent, - statistical_model, + population_model, data; input_cols = input_cols, action_cols = action_cols, @@ -142,7 +142,7 @@ end ################################################################# ####### CHECKS TO BE MADE FOR THE SIMPLE STATISTICAL MODEL ###### ################################################################# -function check_statistical_model( +function check_population_model( #Arguments from statistical model prior::Dict{T,D}, n_agents::I, diff --git a/src/fitting/single_agent_statistical_model.jl b/src/fitting/population_models/single_agent_population_model.jl similarity index 96% rename from src/fitting/single_agent_statistical_model.jl rename to src/fitting/population_models/single_agent_population_model.jl index 42ad89a..58cc91f 100644 --- a/src/fitting/single_agent_statistical_model.jl +++ b/src/fitting/population_models/single_agent_population_model.jl @@ -1,7 +1,7 @@ ########################################## ### STATISTICAL MODEL FOR SINGLE AGENT ### ########################################## -@model function single_statistical_model( +@model function single_agent_population_model( prior::Dict{T,D}, ) where {T<:Union{String,Tuple,Any},D<:Distribution} @@ -41,12 +41,12 @@ function create_model( data = DataFrame(hcat(inputs, actions), vcat(input_cols, action_cols)) #Create the single-agent statistical model - statistical_model = single_statistical_model(prior) + population_model = single_agent_population_model(prior) #Create a full model combining the agent model and the statistical model return create_model( agent, - statistical_model, + population_model, data; input_cols = input_cols, action_cols = action_cols, @@ -109,7 +109,7 @@ end ############################################################ ### CHECKS TO MAKE FOR THE SINGLE-AGENT STAISTICAL MODEL ### ############################################################ -function check_statistical_model( +function check_population_model( #Arguments from statistical model prior::Dict{T,D}; #Arguments from the agent diff --git a/test/testsuite/create_model_tests.jl b/test/testsuite/create_model_tests.jl index cffd3fa..b9ec0e8 100644 --- a/test/testsuite/create_model_tests.jl +++ b/test/testsuite/create_model_tests.jl @@ -65,18 +65,18 @@ using ActionModels, DataFrames renamed_model = rename_chains(fitted_model, model, data, :ID) #Create model with tracking states - model_tracked = create_model( - agent, - prior, - data, - input_cols = :inputs, - action_cols = :actions, - grouping_cols = :ID, - ) + # model_tracked = create_model( + # agent, + # prior, + # data, + # input_cols = :inputs, + # action_cols = :actions, + # grouping_cols = :ID, + # ) #Extract quantities - agent_parameters, agent_states, statistical_values = - extract_quantities(fitted_model, model_tracked) + # agent_parameters, agent_states, statistical_values = + # extract_quantities(fitted_model, model_tracked) end @testset "no grouping cols" begin @@ -102,19 +102,19 @@ using ActionModels, DataFrames #Rename chains renamed_model = rename_chains(fitted_model, model, data, [:ID, :category]) - #Create model with tracking states - model_tracked = create_model( - agent, - prior, - data, - input_cols = :inputs, - action_cols = :actions, - grouping_cols = [:ID, :category], - ) - - #Extract quantities - agent_parameters, agent_states, statistical_values = - extract_quantities(fitted_model, model_tracked) + # #Create model with tracking states + # model_tracked = create_model( + # agent, + # prior, + # data, + # input_cols = :inputs, + # action_cols = :actions, + # grouping_cols = [:ID, :category], + # ) + + # #Extract quantities + # agent_parameters, agent_states, statistical_values = + # extract_quantities(fitted_model, model_tracked) end @testset "missing actions" begin @@ -140,23 +140,28 @@ using ActionModels, DataFrames #Rename chains renamed_model = rename_chains(fitted_model, model, data, :ID) - #Create model with tracking states - model_tracked = create_model( - agent, - prior, - data, - input_cols = :inputs, - action_cols = :actions, - grouping_cols = :ID, - ) - - #Extract quantities - agent_parameters, agent_states, statistical_values = - extract_quantities(fitted_model, model_tracked) + # #Create model with tracking states + # model_tracked = create_model( + # agent, + # prior, + # data, + # input_cols = :inputs, + # action_cols = :actions, + # grouping_cols = :ID, + # ) + + # #Extract quantities + # agent_parameters, agent_states, statistical_values = + # extract_quantities(fitted_model, model_tracked) end @testset "custom statistical model" begin + end + + @testset "multiple actions, missing actions" begin + + end @testset "multiple actions" begin From 9109dfa0708e5568e54f1a0752428cfba49d4e12 Mon Sep 17 00:00:00 2001 From: Peter Thestrup Waade Date: Wed, 2 Oct 2024 21:41:28 +0200 Subject: [PATCH 5/7] further optimizied agent_models by putting all action samples in one arraydist --- src/fitting/agent_model.jl | 162 +++++++++++++++++++++------ src/fitting/create_model.jl | 66 +++-------- src/structs.jl | 2 + test/testsuite/create_model_tests.jl | 2 +- 4 files changed, 149 insertions(+), 83 deletions(-) diff --git a/src/fitting/agent_model.jl b/src/fitting/agent_model.jl index 6833771..ce75033 100644 --- a/src/fitting/agent_model.jl +++ b/src/fitting/agent_model.jl @@ -1,69 +1,109 @@ ############################################### ### WITH SINGLE ACTION / NO MISSING ACTIONS ### ############################################### -@model function agent_model(agent::Agent, parameters::D, inputs::I, actions::Vector{R}) where {D<:Dict, I<:Vector, R<:Real} - - #Set the agent parameters - set_parameters!(agent, parameters) - reset!(agent) +@model function agent_models(agent::Agent, agent_ids::Vector{Symbol}, parameters_per_agent::Vector{D}, inputs_per_agent::Vector{I}, actions_per_agent::Vector{Vector{R}}, actions_flattened::Vector{R}, missing_actions::Nothing) where {D<:Dict, I<:Vector, R<:Real} + #TODO: Could use a list comprehension here to make it more efficient #Initialize a vector for storing the action probability distributions - action_distributions = Vector(undef, length(inputs)) + action_distributions = Vector(undef, length(actions_flattened)) - #Go through each timestep - for (timestep, (input, action)) in enumerate(zip(inputs, actions)) + #Initialize action index + action_idx = 0 + + #Go through each agent + for (agent_parameters, agent_inputs, agent_actions) in zip(parameters_per_agent, inputs_per_agent, actions_per_agent) - #Get the action probability distributions from the action model - @inbounds action_distributions[timestep] = agent.action_model(agent, input) + #Set the agent parameters + set_parameters!(agent, agent_parameters) + reset!(agent) - #Store the agent's action in the agent - update_states!(agent, "action", action) - end + #Go through each timestep + for (input, action) in zip(agent_inputs, agent_actions) + + #Increment one action index + action_idx += 1 + #Get the action probability distributions from the action model + @inbounds action_distributions[action_idx] = agent.action_model(agent, input) + + #Store the agent's action in the agent + update_states!(agent, "action", action) + end + end + #Make sure the action distributions are stored as a concrete type (by constructing a new vector) action_distributions = [dist for dist in action_distributions] #Sample the actions from the probability distributions - actions ~ arraydist(action_distributions) + actions_flattened ~ arraydist(action_distributions) end + ################################################## ### WITH MULTIPLE ACTIONS / NO MISSING ACTIONS ### ################################################## +@model function agent_models(agent::Agent, agent_ids::Vector{Symbol}, parameters_per_agent::Vector{D}, inputs_per_agent::Vector{I}, actions_per_agent::Vector{Matrix{R}}, actions_flattened::Matrix{R}, missing_actions::Nothing) where {D<:Dict, I<:Vector, R<:Real} -@model function agent_model(agent::Agent, parameters::D, inputs::I, actions::Matrix{R}) where {D<:Dict, I<:Vector, R<:Real} + #Initialize a vector for storing the action probability distributions + action_distributions = Matrix(undef, size(actions_flattened)...) - #Set the agent parameters - set_parameters!(agent, parameters) - reset!(agent) + #Initialize action index + action_idx = 0 + + #Go through each agent + for (agent_parameters, agent_inputs, agent_actions) in zip(parameters_per_agent, inputs_per_agent, actions_per_agent) - #Initialize a matrix for storing the action probability distributions - action_distributions = Matrix(undef, size(actions)...) + #Set the agent parameters + set_parameters!(agent, agent_parameters) + reset!(agent) - #Go through each timestep - for (timestep, (input, action)) in enumerate(zip(inputs, Tuple.(eachrow(actions)))) + #Go through each timestep + for (input, action) in zip(agent_inputs, Tuple.(eachrow(agent_actions))) - #Get the action probability distributions from the action model - action_distributions[timestep, :] = agent.action_model(agent, input) #TODO: can use @inbounds here when there's a check for whether the right amount of actions are used + #Increment one action index + action_idx += 1 - #Store the agent's action in the agent - update_states!(agent, "action", action) - end + #Get the action probability distributions from the action model + @inbounds action_distributions[action_idx, :] = agent.action_model(agent, input) - #Make sure the action distributions are stored as a concrete type (by constructing a new matrix) + #Store the agent's action in the agent + update_states!(agent, "action", action) + end + end + + #Make sure the action distributions are stored as a concrete type (by constructing a new vector) action_distributions = [dist for dist in action_distributions] #Sample the actions from the probability distributions - actions ~ arraydist(action_distributions) + actions_flattened ~ arraydist(action_distributions) +end + + + + + + + + + +############################################ +### WITH MISSING ACTIONS - SUPERFUNCTION ### +############################################ +@model function agent_models(agent::Agent, agent_ids::Vector{Symbol}, parameters_per_agent::Vector{D}, inputs_per_agent::Vector{I}, actions_per_agent::Vector{A}, actions_flattened::A, missing_actions::MissingActions) where {D<:Dict, I<:Vector, A<:Array} + + #For each agent + for (agent_id, agent_parameters, agent_inputs, agent_actions) in zip(agent_ids, parameters_per_agent, inputs_per_agent, actions_per_agent) + + #Fit it to the data + @submodel prefix = "$agent_id" agent_model(agent, agent_parameters, agent_inputs, agent_actions) + end end ################################################# ### WITH SINGLE ACTION / WITH MISSING ACTIONS ### ################################################# - @model function agent_model(agent::Agent, parameters::D, inputs::I, actions::Vector{Union{Missing, R}}) where {D<:Dict, I<:Vector, R<:Real} - @show "yep" #Set the agent parameters set_parameters!(agent, parameters) reset!(agent) @@ -86,7 +126,6 @@ end end end - #################################################### ### WITH MULTIPLE ACTIONS / WITH MISSING ACTIONS ### #################################################### @@ -120,4 +159,63 @@ end ) #TODO: can use @inbounds here when there's a check for whether the right amount of actions are produced end +end + +############################################### +### WITH SINGLE ACTION / NO MISSING ACTIONS ### +############################################### +@model function agent_model(agent::Agent, parameters::D, inputs::I, actions::Vector{R}) where {D<:Dict, I<:Vector, R<:Real} + + #Set the agent parameters + set_parameters!(agent, parameters) + reset!(agent) + + #Initialize a vector for storing the action probability distributions + action_distributions = Vector(undef, length(inputs)) + + #Go through each timestep + for (timestep, (input, action)) in enumerate(zip(inputs, actions)) + + #Get the action probability distributions from the action model + @inbounds action_distributions[timestep] = agent.action_model(agent, input) + + #Store the agent's action in the agent + update_states!(agent, "action", action) + end + + #Make sure the action distributions are stored as a concrete type (by constructing a new vector) + action_distributions = [dist for dist in action_distributions] + + #Sample the actions from the probability distributions + actions ~ arraydist(action_distributions) +end + +################################################## +### WITH MULTIPLE ACTIONS / NO MISSING ACTIONS ### +################################################## + +@model function agent_model(agent::Agent, parameters::D, inputs::I, actions::Matrix{R}) where {D<:Dict, I<:Vector, R<:Real} + + #Set the agent parameters + set_parameters!(agent, parameters) + reset!(agent) + + #Initialize a matrix for storing the action probability distributions + action_distributions = Matrix(undef, size(actions)...) + + #Go through each timestep + for (timestep, (input, action)) in enumerate(zip(inputs, Tuple.(eachrow(actions)))) + + #Get the action probability distributions from the action model + action_distributions[timestep, :] = agent.action_model(agent, input) #TODO: can use @inbounds here when there's a check for whether the right amount of actions are used + + #Store the agent's action in the agent + update_states!(agent, "action", action) + end + + #Make sure the action distributions are stored as a concrete type (by constructing a new matrix) + action_distributions = [dist for dist in action_distributions] + + #Sample the actions from the probability distributions + actions ~ arraydist(action_distributions) end \ No newline at end of file diff --git a/src/fitting/create_model.jl b/src/fitting/create_model.jl index ce90e4f..27b2a78 100644 --- a/src/fitting/create_model.jl +++ b/src/fitting/create_model.jl @@ -72,8 +72,17 @@ function create_model( #Extract agent id's as combined symbols in a vector agent_ids = [Symbol(join(string.(Tuple(row)), id_separator)) for row in eachrow(unique(data[!, grouping_cols]))] + ## Determine whether any actions are missing ## + if actions isa Vector{A} where {R<:Real, A<:Array{Union{Missing, R}}} + #If there are missing actions + missing_actions = MissingActions() + elseif actions isa Vector{A} where {R<:Real, A<:Array{R}} + #If there are no missing actions + missing_actions = nothing + end + #Create a full model combining the agent model and the statistical model - return full_model(agent_model, population_model, inputs, actions, agent_ids, check_parameter_rejections = check_parameter_rejections) + return full_model(agent_model, population_model, inputs, actions, agent_ids, missing_actions = missing_actions, check_parameter_rejections = check_parameter_rejections) end #################################################################### @@ -85,60 +94,17 @@ end inputs_per_agent::Vector{I}, actions_per_agent::Vector{A}, agent_ids::Vector{Symbol}; + missing_actions::Union{Nothing, MissingActions} = MissingActions(), check_parameter_rejections::Nothing = nothing, + actions_flattened::A = vcat(actions_per_agent...) ) where {I<:Vector, R<:Real, A1 <:Union{R,Union{Missing,R}}, A<:Array{A1}} #Generate the agent parameters from the statistical model @submodel population_values = population_model - ## For each agent ## - for (agent_id, agent_parameters, agent_inputs, agent_actions) in zip(agent_ids, population_values.agent_parameters, inputs_per_agent, actions_per_agent) - - @submodel prefix = "$agent_id" agent_model(agent, agent_parameters, agent_inputs, agent_actions) - end + #Generate the agent's behavior + @submodel agent_models(agent, agent_ids, population_values.agent_parameters, inputs_per_agent, actions_per_agent, actions_flattened, missing_actions) - #Return agents' parameters and tracked states + #Return values fron the population model (agent parameters and oher values) return population_values -end - - -############################################################################# -### WRAPPER FUNCTION FOR FULL_MODEL FOR CHECKING FOR PARAMETER REJECTIONS ### -############################################################################# - -# @model function full_model( -# agent::Agent, -# population_model::DynamicPPL.Model, -# inputs_per_agent::Array{IA}, -# actions_per_agent::Array{AA}; -# agent_ids::Vector{Union{Symbol,Vector{Symbol}}}, -# check_parameter_rejections::CheckRejections, -# ) where {IAR<:Union{Real,Missing},AAR<:Union{Real,Missing},IA<:Array{IAR},AA<:Array{AAR}} - -# #Check whether errors occur -# try - -# #Run the full model -# @submodel generated_quantities = full_model( -# agent, -# population_model, -# inputs_per_agent, -# actions_per_agent; -# agent_ids = agent_ids, -# check_parameter_rejections = nothing, -# ) - -# return generated_quantities - -# #If an error occurs -# catch error -# #If it is of the custom errortype RejectParameters -# if error isa RejectParameters -# #Make Turing reject the sample -# Turing.@addlogprob!(-Inf) -# else -# #Otherwise, just throw the error -# rethrow(error) -# end -# end -# end \ No newline at end of file +end \ No newline at end of file diff --git a/src/structs.jl b/src/structs.jl index d88250d..1cf3166 100644 --- a/src/structs.jl +++ b/src/structs.jl @@ -39,6 +39,8 @@ end struct CheckRejections end +struct MissingActions +end """ diff --git a/test/testsuite/create_model_tests.jl b/test/testsuite/create_model_tests.jl index b9ec0e8..db4dda4 100644 --- a/test/testsuite/create_model_tests.jl +++ b/test/testsuite/create_model_tests.jl @@ -161,7 +161,7 @@ using ActionModels, DataFrames @testset "multiple actions, missing actions" begin - + end @testset "multiple actions" begin From f27e62e32ec8bf9f840ea82c4a3a89869bbc9735 Mon Sep 17 00:00:00 2001 From: Peter Thestrup Waade Date: Thu, 3 Oct 2024 08:37:11 +0200 Subject: [PATCH 6/7] minor --- src/fitting/agent_model.jl | 1 - src/fitting/create_model.jl | 2 +- test/testsuite/create_model_tests.jl | 30 +++++++++++++++++++--------- 3 files changed, 22 insertions(+), 11 deletions(-) diff --git a/src/fitting/agent_model.jl b/src/fitting/agent_model.jl index ce75033..96745e0 100644 --- a/src/fitting/agent_model.jl +++ b/src/fitting/agent_model.jl @@ -148,7 +148,6 @@ end actions[timestep, action_idx] ~ single_distribution #TODO: can use @inbounds here when there's a check for whether the right amount of actions are produced - #TODO: Could use arraydist here if this was formatted as a vector of vectors (probably not!) end #Add the actions to the agent in case it needs it in the future diff --git a/src/fitting/create_model.jl b/src/fitting/create_model.jl index 27b2a78..162bc32 100644 --- a/src/fitting/create_model.jl +++ b/src/fitting/create_model.jl @@ -7,7 +7,7 @@ function create_model( data::DataFrame; input_cols::Union{Vector{T1},T1}, action_cols::Union{Vector{T2},T3}, - grouping_cols::Union{Vector{T3},T3} = Vector{Symbol}(), + grouping_cols::Union{Vector{T3},T3}, check_parameter_rejections::Union{Nothing, CheckRejections} = nothing, id_separator::String = "__", verbose::Bool = true, diff --git a/test/testsuite/create_model_tests.jl b/test/testsuite/create_model_tests.jl index db4dda4..dfa7b69 100644 --- a/test/testsuite/create_model_tests.jl +++ b/test/testsuite/create_model_tests.jl @@ -79,8 +79,24 @@ using ActionModels, DataFrames # extract_quantities(fitted_model, model_tracked) end + @testset "custom statistical model" begin + + end + @testset "no grouping cols" begin + #Create model + model = create_model( + agent, + prior, + data, + input_cols = :inputs, + action_cols = :actions, + grouping_cols = Symbol[] + ) + #Fit model + fitted_model = + sample(model, sampler, n_iterations; n_chains = n_chains, sampling_kwargs...) end @testset "multiple grouping cols" begin @@ -155,15 +171,6 @@ using ActionModels, DataFrames # extract_quantities(fitted_model, model_tracked) end - @testset "custom statistical model" begin - - end - - @testset "multiple actions, missing actions" begin - - - end - @testset "multiple actions" begin #Action model with multiple actions @@ -195,6 +202,11 @@ using ActionModels, DataFrames fitted_model = sample(model, sampler, n_iterations; n_chains = n_chains, sampling_kwargs...) + end + + @testset "multiple actions, missing actions" begin + + end @testset "multiple inputs" begin From f53d1110ec0df040cb45d59d59567514967fb887 Mon Sep 17 00:00:00 2001 From: Peter Thestrup Waade Date: Thu, 3 Oct 2024 10:36:19 +0200 Subject: [PATCH 7/7] 0.6.3 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 25df1be..8d0023d 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ActionModels" uuid = "320cf53b-cc3b-4b34-9a10-0ecb113566a3" authors = ["Peter Thestrup Waade ptw@cas.au.dk", "Anna Hedvig Møller hedvig.2808@gmail.com", "Jacopo Comoglio jacopo.comoglio@gmail.com", "Christoph Mathys chmathys@cas.au.dk"] -version = "0.6.2" +version = "0.6.3" [deps] DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"