Skip to content

Commit

Permalink
finalized plots, changed reexport of Turing to reexport of Distributions
Browse files Browse the repository at this point in the history
  • Loading branch information
PTWaade committed Oct 8, 2024
1 parent adddc97 commit fafe6c7
Show file tree
Hide file tree
Showing 6 changed files with 73 additions and 70 deletions.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ version = "0.6.3"
AxisArrays = "39de3d68-74b9-583c-8d2d-e117c070f3a9"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Expand All @@ -16,14 +17,15 @@ ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"

[compat]
AxisArrays = "0.4"
DataFrames = "1.6"
Distributed = "1"
Distributions = "0.25"
ForwardDiff = "0.10"
Logging = "1"
ProgressMeter = "1"
RecipesBase = "1.3"
Reexport = "1"
ReverseDiff = "1.15"
Turing = "0.34"
AxisArrays = "0.4"
julia = "1.10"
10 changes: 4 additions & 6 deletions src/ActionModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,16 @@ module ActionModels

#Load packages
using Reexport
@reexport using Turing
#using Distributions, DataFrames, RecipesBase, Logging
using DataFrames, RecipesBase, ReverseDiff, Logging, AxisArrays
using Turing: Distributions, DynamicPPL, ForwardDiff, AutoReverseDiff, AbstractMCMC
using Turing, ReverseDiff, DataFrames, AxisArrays, RecipesBase, Logging
using ProgressMeter, Distributed #TODO: get rid of this (only needed for parameter recovery)
@reexport using Distributions
using Turing: DynamicPPL, ForwardDiff, AutoReverseDiff, AbstractMCMC
#Export functions
export Agent, RejectParameters, InitialStateParameter, ParameterGroup
export init_agent, premade_agent, warn_premade_defaults, multiple_actions, check_agent
export independent_agents_population_model,
create_model, fit_model, parameter_recovery, single_recovery
export plot_parameters,
plot_trajectories, plot_trajectory, plot_trajectory!
export plot_parameters, plot_trajectories, plot_trajectory, plot_trajectory!
export get_history,
get_states,
get_parameters,
Expand Down
3 changes: 2 additions & 1 deletion src/fitting/helper_functions/get_estimates.jl
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,8 @@ function get_estimates(
end

# Add the value to the row
row[Symbol(join((string(agent), string(state)), id_separator))] = median_value
row[Symbol(join((string(agent), string(state)), id_separator))] =
median_value
end

#Add the timestep to the row
Expand Down
2 changes: 1 addition & 1 deletion src/plots/plot_parameters.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
@userplot struct Plot_Parameters{T}
@userplot struct Plot_Parameters{T<:Tuple{Chains,Chains}}
args::T
end
"""
Expand Down
66 changes: 35 additions & 31 deletions src/plots/plot_trajectories.jl
Original file line number Diff line number Diff line change
@@ -1,37 +1,38 @@
@userplot struct Plot_Trajectories{T<:Tuple{AxisArrays.AxisArray{
Union{Missing,Float64},
5,
Array{Union{Missing,Float64},5},
Tuple{
AxisArrays.Axis{:agent,Vector{Symbol}},
AxisArrays.Axis{:state,Vector{Symbol}},
AxisArrays.Axis{:timestep,UnitRange{Int64}},
AxisArrays.Axis{:sample,UnitRange{Int64}},
AxisArrays.Axis{:chain,UnitRange{Int64}},
},
},
}}
@userplot struct Plot_Trajectories{
T<:Tuple{
AxisArrays.AxisArray{
Union{Missing,Float64},
5,
Array{Union{Missing,Float64},5},
Tuple{
AxisArrays.Axis{:agent,Vector{Symbol}},
AxisArrays.Axis{:state,Vector{Symbol}},
AxisArrays.Axis{:timestep,UnitRange{Int64}},
AxisArrays.Axis{:sample,UnitRange{Int64}},
AxisArrays.Axis{:chain,UnitRange{Int64}},
},
},
},
}
args::T
end

"""
"""
plot_trajectories
@recipe function f(
plt::Plot_Trajectories,
sample_color::Union{String,Symbol} = :gray,
sample_alpha::Real = 0.1,
sample_linewidth::Real = 0.5,

summary_function::Function = median,
summary_alpha::Real = 1,
summary_color::Union{String,Symbol} = :red,
summary_linewidth::Real = 1,

plot_width::Int = 800,
plot_height::Int = 600,
subplot_titles = Dict()
)
plt::Plot_Trajectories,
sample_color::Union{String,Symbol} = :gray,
sample_alpha::Real = 0.1,
sample_linewidth::Real = 0.5,
summary_function::Function = median,
summary_alpha::Real = 1,
summary_color::Union{String,Symbol} = :red,
summary_linewidth::Real = 1,
plot_width::Int = 800,
plot_height::Int = 600,
subplot_titles = Dict(),
)

#Extract trajectories
trajectories = plt.args[1]
Expand Down Expand Up @@ -72,8 +73,8 @@ plot_trajectories
#Set the font size
legendfontsize --> 15

#Plot samples for each agent
for agent_id in agent_ids
#Plot samples for each agent
for agent_id in agent_ids
#For each chain and sample
for chain in chains
for sample in samples
Expand All @@ -99,7 +100,10 @@ plot_trajectories
for agent_id in agent_ids

#Get vector of point estimates
summary_values = [summary_function(trajectories[agent_id, state_key, timestep+1, :, :]) for timestep in timesteps]
summary_values = [
summary_function(trajectories[agent_id, state_key, timestep+1, :, :])
for timestep in timesteps
]

#Plot the summary value
@series begin
Expand All @@ -113,4 +117,4 @@ plot_trajectories
end
end
end
end
end
58 changes: 28 additions & 30 deletions test/testsuite/create_model_tests.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
using Test
using StatsPlots
using ActionModels, DataFrames
using AxisArrays
using AxisArrays, Turing


@testset "fitting tests" begin
Expand Down Expand Up @@ -111,8 +111,7 @@ using AxisArrays
estimates_dict = get_estimates(agent_parameters, Dict)

#Extract state trajectories
state_trajectories =
get_trajectories(model, fitted_model, ["value", "action"])
state_trajectories = get_trajectories(model, fitted_model, ["value", "action"])
trajectory_estimates_df = get_estimates(state_trajectories)

#Check that the learning rates are estimated right
Expand Down Expand Up @@ -146,10 +145,9 @@ using AxisArrays
prior_chains = sample(model, Prior(), n_iterations; sampling_kwargs...)
renamed_prior_chains = rename_chains(prior_chains, model)

plot_parameters(prior_chains, renamed_model)
plot_parameters(renamed_prior_chains, renamed_model)

prior_trajectories =
get_trajectories(model, prior_chains, ["value", "action"])
prior_trajectories = get_trajectories(model, prior_chains, ["value", "action"])
plot_trajectories(prior_trajectories)
end

Expand Down Expand Up @@ -397,28 +395,25 @@ end

function plot_trajectory(
trajectories = AxisArrays.AxisArray{
Union{Missing,Float64},
5,
Array{Union{Missing,Float64},5},
Tuple{
AxisArrays.Axis{:agent,Vector{Symbol}},
AxisArrays.Axis{:state,Vector{Symbol}},
AxisArrays.Axis{:timestep,UnitRange{Int64}},
AxisArrays.Axis{:sample,UnitRange{Int64}},
AxisArrays.Axis{:chain,UnitRange{Int64}},
},
Union{Missing,Float64},
5,
Array{Union{Missing,Float64},5},
Tuple{
AxisArrays.Axis{:agent,Vector{Symbol}},
AxisArrays.Axis{:state,Vector{Symbol}},
AxisArrays.Axis{:timestep,UnitRange{Int64}},
AxisArrays.Axis{:sample,UnitRange{Int64}},
AxisArrays.Axis{:chain,UnitRange{Int64}},
},

sample_color::Union{String,Symbol} = :gray,
sample_alpha::Real = 0.1,
sample_linewidth::Real = 1,

summary_function::Function = median,
summary_alpha::Real = 1,
summary_color::Union{String,Symbol} = :red,
summary_linewidth::Real = 2,

)
},
sample_color::Union{String,Symbol} = :gray,
sample_alpha::Real = 0.1,
sample_linewidth::Real = 1,
summary_function::Function = median,
summary_alpha::Real = 1,
summary_color::Union{String,Symbol} = :red,
summary_linewidth::Real = 2,
)

#Extract dimensions
agent_ids, state_keys, timesteps, samples, chains = trajectories.axes
Expand All @@ -427,7 +422,7 @@ function plot_trajectory(
plots = Vector(undef, length(state_keys))

#For each state
for (state_idx,state_key) in enumerate(state_keys)
for (state_idx, state_key) in enumerate(state_keys)
#Initialize plot
plots[state_idx] = plot()

Expand Down Expand Up @@ -458,7 +453,10 @@ function plot_trajectory(
for agent_id in agent_ids

#Get vector of medians
summary_values = [summary_function(trajectories[agent_id, state_key, timestep+1, :, :]) for timestep in timesteps]
summary_values = [
summary_function(trajectories[agent_id, state_key, timestep+1, :, :])
for timestep in timesteps
]

#Plot the summary value
plot!(
Expand All @@ -478,4 +476,4 @@ function plot_trajectory(

#Plot all plots
plot(plots..., layput = (length(state_keys), 1))
end
end

0 comments on commit fafe6c7

Please sign in to comment.