Skip to content

Commit

Permalink
finalized format for get_estimates
Browse files Browse the repository at this point in the history
  • Loading branch information
PTWaade committed Oct 8, 2024
1 parent 3278773 commit d0d7a58
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 47 deletions.
1 change: 1 addition & 0 deletions src/ActionModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ end
include("structs.jl")

const id_separator = "."
const id_column_separator = ":"
const tuple_separator = "__"

#Functions for creating agents
Expand Down
31 changes: 12 additions & 19 deletions src/fitting/create_model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,25 +77,18 @@ function create_model(
]
end

#If there is only one grouping column
if length(grouping_cols) == 1
#Use the values of that column as ID
agent_ids = [Symbol(i) for i in unique(data[!, first(grouping_cols)])]
else
#Otherwise, use combinations of the column name and their values
agent_ids = [
Symbol(
join(
[
string(col_name) * "" * string(row[col_name]) for
col_name in grouping_cols
],
id_separator,
),
) for row in eachrow(unique(data[!, grouping_cols]))
]

end
#Create agent ids from the grouping columns and their values
agent_ids = [
Symbol(
join(
[
string(col_name) * id_column_separator * string(row[col_name]) for
col_name in grouping_cols
],
id_separator,
),
) for row in eachrow(unique(data[!, grouping_cols]))
]

## Determine whether any actions are missing ##
if actions isa Vector{A} where {R<:Real,A<:Array{Union{Missing,R}}}
Expand Down
72 changes: 46 additions & 26 deletions src/fitting/helper_functions/get_estimates.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,29 +45,41 @@ function get_estimates(
agents = agent_parameters.axes[1]
parameters = agent_parameters.axes[2]

#Construct grouping column names
grouping_cols = [Symbol(first(split(i, id_column_separator))) for i in split(string(first(agents)), id_separator)]

# Initialize an empty DataFrame
df = DataFrame(Dict(Symbol(parameter) => Float64[] for parameter in parameters))
df[!, :agent] = Symbol[]
#Add grouping colnames
for column_name in grouping_cols
df[!, column_name] = String[]
end

# Populate the DataFrame with median values
for agent in agents
for agent_id in agents
row = Dict()
for parameter in parameters
# Extract the values for the current agent and parameter across samples and chains
values = agent_parameters[agent, parameter, :, :]
values = agent_parameters[agent_id, parameter, :, :]
# Calculate the median value
median_value = summary_function(values)
# Add the median value to the row
row[Symbol(parameter)] = median_value
end
#Add an agent id to the row
row[:agent] = agent

#Split agent ids
split_agent_ids = split(string(agent_id), id_separator)
#Add them to the row
for (agent_id_part, column_name) in zip(split_agent_ids, grouping_cols)
row[column_name] = string(split(agent_id_part, id_column_separator)[2])
end

# Add the row to the DataFrame
push!(df, row)
end

# Reorder the columns to have agent_id as the first column
select!(df, :agent, names(df)[1:end-1]...)
# Reorder the columns to have the agent ids as the first columns
select!(df, grouping_cols, names(df)[1:end-length(grouping_cols)]...)

return df
end
Expand Down Expand Up @@ -121,9 +133,6 @@ end






###########################################################################################
###### FUNCTION FOR GENERATING SUMMARIZED VARIABLES FROM AN AGENT_PARAMETERS AXISARRAY ####
###########################################################################################
Expand All @@ -148,31 +157,36 @@ function get_estimates(
states = state_trajectories.axes[2]
timesteps = state_trajectories.axes[3]

# Initialize an empty DataFrame
#Construct grouping column names
grouping_cols = [Symbol(first(split(i, id_column_separator))) for i in split(string(first(agents)), id_separator)]

# Initialize an empty DataFrame with the states, the grouping columns and the timestep
df = DataFrame(
Dict(
begin
#Join tuples
if state isa Tuple
state = join(state, tuple_separator)
end

#Join the agent and the state
Symbol(join((string(agent), string(state)), id_separator)) => Float64[]
end for (state, agent) in Iterators.product(states, agents)
state => Float64[]
end for state in states
),
)
for column_name in grouping_cols
df[!, column_name] = String[]
end
df[!, :timestep] = Int[]


# Populate the DataFrame with median values
for timestep in timesteps
row = Dict()

for agent in agents
for agent_id in agents

for timestep in timesteps
row = Dict()

for state in states
# Extract the state for the current agent and state, at the current timestep
values = state_trajectories[agent, state, timestep+1, :, :]
values = state_trajectories[agent_id, state, timestep+1, :, :]
# Calculate the point estimate
median_value = summary_function(values)

Expand All @@ -182,20 +196,26 @@ function get_estimates(
end

# Add the value to the row
row[Symbol(join((string(agent), string(state)), id_separator))] =
median_value
row[state] = median_value
end

#Split agent ids
split_agent_ids = split(string(agent_id), id_separator)
#Add them to the row
for (agent_id_part, column_name) in zip(split_agent_ids, grouping_cols)
row[column_name] = string(split(agent_id_part, id_column_separator)[2])
end

#Add the timestep to the row
row[:timestep] = timestep

# Add the row to the DataFrame
push!(df, row, promote = true)
end

# Add the row to the DataFrame
push!(df, row, promote = true)
end

# Reorder the columns to have the grouping columns and the timestep as the first columns
select!(df, :timestep, names(df)[1:end-1]...)
select!(df, vcat(grouping_cols,[:timestep]), names(df)[1:end-(length(grouping_cols)+1)]...)

return df
end
11 changes: 9 additions & 2 deletions test/testsuite/create_model_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ using Test
using StatsPlots
using ActionModels, DataFrames
using AxisArrays, Turing
using Turing: AutoReverseDiff


@testset "fitting tests" begin
Expand Down Expand Up @@ -191,16 +192,22 @@ using AxisArrays, Turing

#Fit model
fitted_model = sample(model, sampler, n_iterations; sampling_kwargs...)
#Rename chains
renamed_model = rename_chains(fitted_model, model)

#Extract quantities
agent_parameters = extract_quantities(model, fitted_model)
estimates_df = get_estimates(agent_parameters)

#Extract state trajectories
state_trajectories = get_trajectories(model, fitted_model, ["value", "action"])
trajectory_estimates_df = get_estimates(state_trajectories)


#Check that the learning rates are estimated right
@test estimates_df[!, :learning_rate] == sort(estimates_df[!, :learning_rate])

#Rename chains
renamed_model = rename_chains(fitted_model, model)

end

@testset "missing actions" begin
Expand Down

0 comments on commit d0d7a58

Please sign in to comment.