diff --git a/optd-persistent/src/cost_model/orm.rs b/optd-persistent/src/cost_model/orm.rs index 9c068ae..74c03f1 100644 --- a/optd-persistent/src/cost_model/orm.rs +++ b/optd-persistent/src/cost_model/orm.rs @@ -534,7 +534,7 @@ impl CostModelStorageLayer for BackendManager { epoch_id: Option, ) -> StorageResult<()> { assert!(cost.is_some() || estimated_statistic.is_some()); - // TODO: should we do the following checks in the production environment? + // TODO: we shouldn't do the following checks in the production environment. let expr_exists = PhysicalExpression::find_by_id(physical_expression_id) .one(&self.db) .await?; @@ -561,30 +561,54 @@ impl CostModelStorageLayer for BackendManager { } } - let epoch_id = match epoch_id { - Some(id) => id, - None => { - // When init, please make sure there is at least one epoch in the Event table. - let latest_epoch_id = Event::find() - .order_by_desc(event::Column::EpochId) - .one(&self.db) - .await? - .unwrap(); - latest_epoch_id.epoch_id - } - }; - let transaction = self.db.begin().await?; - let valid_cost = PlanCost::find() - .filter(plan_cost::Column::PhysicalExpressionId.eq(physical_expression_id)) - .filter(plan_cost::Column::EpochId.eq(epoch_id)) - .filter(plan_cost::Column::IsValid.eq(true)) - .one(&transaction) - .await?; + /* + The `store_cost` logic is as follows: + 1. If the epoch_id is provided, we should update the cost with the corresponding epoch_id, + or insert a new record if it doesn't exist. + 2. If the epoch_id is not provided, we cannot directly use the latest epoch_id, since in the + plan_cost table, for the current physical expression, there may be a valid cost with a lower + epoch_id, since the update_stats function updates unrelated stats. So we need to handle the + epoch_id in following logics: + 1) If a valid cost is already in the plan_cost table, we use the same epoch_id. + 2) If there is no valid cost in the plan_cost table, or there is no record, we use the + latest epoch_id. + */ + // TODO: We should add some integration tests to fully test the above logic + let epoch_id_data; + let existed_cost; + if let Some(epoch_id) = epoch_id { + epoch_id_data = epoch_id; + existed_cost = PlanCost::find() + .filter(plan_cost::Column::PhysicalExpressionId.eq(physical_expression_id)) + .filter(plan_cost::Column::EpochId.eq(epoch_id)) + .one(&transaction) + .await?; + } else { + existed_cost = PlanCost::find() + .filter(plan_cost::Column::PhysicalExpressionId.eq(physical_expression_id)) + .filter(plan_cost::Column::IsValid.eq(true)) + .order_by_desc(plan_cost::Column::EpochId) + .one(&transaction) + .await?; + if existed_cost.is_none() { + epoch_id_data = { + // When init, please make sure there is at least one epoch in the Event table. + let latest_epoch_id = Event::find() + .order_by_desc(event::Column::EpochId) + .one(&self.db) + .await? + .unwrap(); + latest_epoch_id.epoch_id + } + } else { + epoch_id_data = existed_cost.clone().unwrap().epoch_id; + } + } - if valid_cost.is_some() { - let mut new_cost: plan_cost::ActiveModel = valid_cost.unwrap().into(); + if existed_cost.is_some() { + let mut new_cost: plan_cost::ActiveModel = existed_cost.unwrap().into(); let mut update = false; if cost.is_some() { let input_cost = sea_orm::ActiveValue::Set(Some(json!({ @@ -604,12 +628,25 @@ impl CostModelStorageLayer for BackendManager { } } if update { + assert!(new_cost.epoch_id.is_unchanged()); let _ = PlanCost::update(new_cost).exec(&transaction).await?; } } else { + // TODO: we shouldn't do the following checks in the production environment. + // This check may be easy to violate, so consider removing epoch_id input parameter. + let latest_cost = PlanCost::find() + .filter(plan_cost::Column::PhysicalExpressionId.eq(physical_expression_id)) + .order_by_desc(plan_cost::Column::EpochId) + .one(&transaction) + .await?; + if latest_cost.is_some() { + assert!(latest_cost.clone().unwrap().epoch_id < epoch_id_data); + assert!(!latest_cost.clone().unwrap().is_valid); + } + let new_cost = plan_cost::ActiveModel { physical_expression_id: sea_orm::ActiveValue::Set(physical_expression_id), - epoch_id: sea_orm::ActiveValue::Set(epoch_id), + epoch_id: sea_orm::ActiveValue::Set(epoch_id_data), cost: sea_orm::ActiveValue::Set( cost.map(|c| json!({"compute_cost": c.compute_cost, "io_cost": c.io_cost})), ), @@ -1035,6 +1072,18 @@ mod tests { .create_new_epoch("source".to_string(), "data".to_string()) .await .unwrap(); + let stat = Stat { + stat_type: StatType::TableRowCount, + stat_value: json!(10), + attr_ids: vec![], + table_id: Some(1), + name: "row_count".to_owned(), + }; + let res = backend_manager + .update_stats(stat, EpochOption::Existed(epoch_id)) + .await; + assert!(res.is_ok()); + let physical_expression_id = 1; let cost = Cost { compute_cost: 42.0, @@ -1102,6 +1151,18 @@ mod tests { .create_new_epoch("source".to_string(), "data".to_string()) .await .unwrap(); + let stat = Stat { + stat_type: StatType::TableRowCount, + stat_value: json!(10), + attr_ids: vec![], + table_id: Some(1), + name: "row_count".to_owned(), + }; + let res = backend_manager + .update_stats(stat, EpochOption::Existed(epoch_id)) + .await; + assert!(res.is_ok()); + let physical_expression_id = 1; let cost = Cost { compute_cost: 42.0, @@ -1148,6 +1209,18 @@ mod tests { .create_new_epoch("source".to_string(), "data".to_string()) .await .unwrap(); + let stat = Stat { + stat_type: StatType::TableRowCount, + stat_value: json!(10), + attr_ids: vec![], + table_id: Some(1), + name: "row_count".to_owned(), + }; + let res = backend_manager + .update_stats(stat, EpochOption::Existed(epoch_id)) + .await; + assert!(res.is_ok()); + let physical_expression_id = 1; let estimated_statistic = 42.0; let _ = backend_manager