|
| 1 | +// Copyright (C) 2023-2024 Intel Corporation |
| 2 | +// SPDX-License-Identifier: Apache-2.0 |
| 3 | + |
#include "text2image/schedulers/ddim.hpp"

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <fstream>
#include <iterator>
#include <numeric>
#include <random>
#include <string>
#include <vector>

#include "text2image/numpy_utils.hpp"
#include "utils.hpp"
| 12 | + |
| 13 | +namespace ov { |
| 14 | +namespace genai { |
| 15 | + |
| 16 | +DDIMScheduler::Config::Config(const std::string& scheduler_config_path) { |
| 17 | + std::ifstream file(scheduler_config_path); |
| 18 | + OPENVINO_ASSERT(file.is_open(), "Failed to open ", scheduler_config_path); |
| 19 | + |
| 20 | + nlohmann::json data = nlohmann::json::parse(file); |
| 21 | + using utils::read_json_param; |
| 22 | + |
| 23 | + read_json_param(data, "num_train_timesteps", num_train_timesteps); |
| 24 | + read_json_param(data, "beta_start", beta_start); |
| 25 | + read_json_param(data, "beta_end", beta_end); |
| 26 | + read_json_param(data, "beta_schedule", beta_schedule); |
| 27 | + read_json_param(data, "trained_betas", trained_betas); |
| 28 | + read_json_param(data, "clip_sample", clip_sample); |
| 29 | + read_json_param(data, "set_alpha_to_one", set_alpha_to_one); |
| 30 | + read_json_param(data, "steps_offset", steps_offset); |
| 31 | + read_json_param(data, "prediction_type", prediction_type); |
| 32 | + read_json_param(data, "thresholding", thresholding); |
| 33 | + read_json_param(data, "dynamic_thresholding_ratio", dynamic_thresholding_ratio); |
| 34 | + read_json_param(data, "clip_sample_range", clip_sample_range); |
| 35 | + read_json_param(data, "sample_max_value", sample_max_value); |
| 36 | + read_json_param(data, "timestep_spacing", timestep_spacing); |
| 37 | + read_json_param(data, "rescale_betas_zero_snr", rescale_betas_zero_snr); |
| 38 | +} |
| 39 | + |
| 40 | +DDIMScheduler::DDIMScheduler(const std::string scheduler_config_path) |
| 41 | + : DDIMScheduler(Config(scheduler_config_path)) { |
| 42 | +} |
| 43 | + |
| 44 | +DDIMScheduler::DDIMScheduler(const Config& scheduler_config) |
| 45 | + : m_config(scheduler_config) { |
| 46 | + |
| 47 | + std::vector<float> alphas, betas; |
| 48 | + |
| 49 | + using numpy_utils::linspace; |
| 50 | + |
| 51 | + if (!m_config.trained_betas.empty()) { |
| 52 | + betas = m_config.trained_betas; |
| 53 | + } else if (m_config.beta_schedule == BetaSchedule::LINEAR) { |
| 54 | + betas = linspace<float>(m_config.beta_start, m_config.beta_end, m_config.num_train_timesteps); |
| 55 | + } else if (m_config.beta_schedule == BetaSchedule::SCALED_LINEAR) { |
| 56 | + float start = std::sqrt(m_config.beta_start); |
| 57 | + float end = std::sqrt(m_config.beta_end); |
| 58 | + betas = linspace<float>(start, end, m_config.num_train_timesteps); |
| 59 | + std::for_each(betas.begin(), betas.end(), [] (float & x) { x *= x; }); |
| 60 | + } else { |
| 61 | + OPENVINO_THROW("'beta_schedule' must be one of 'LINEAR' or 'SCALED_LINEAR'. Please, add support of other types"); |
| 62 | + } |
| 63 | + |
| 64 | + // TODO: Rescale for zero SNR |
| 65 | + // if (m_config.rescale_betas_zero_snr) {betas = rescale_zero_terminal_snr(betas)} |
| 66 | + |
| 67 | + std::transform(betas.begin(), betas.end(), std::back_inserter(alphas), [] (float b) { return 1.0f - b; }); |
| 68 | + |
| 69 | + for (size_t i = 1; i <= alphas.size(); i++) { |
| 70 | + float alpha_cumprod = |
| 71 | + std::accumulate(std::begin(alphas), std::begin(alphas) + i, 1.0, std::multiplies<float>{}); |
| 72 | + m_alphas_cumprod.push_back(alpha_cumprod); |
| 73 | + } |
| 74 | + |
| 75 | + m_final_alpha_cumprod = m_config.set_alpha_to_one ? 1 : m_alphas_cumprod[0]; |
| 76 | +} |
| 77 | + |
| 78 | +void DDIMScheduler::set_timesteps(size_t num_inference_steps) { |
| 79 | + m_timesteps.clear(); |
| 80 | + |
| 81 | + OPENVINO_ASSERT(num_inference_steps <= m_config.num_train_timesteps, |
| 82 | + "`num_inference_steps` cannot be larger than `m_config.num_train_timesteps`"); |
| 83 | + |
| 84 | + m_num_inference_steps = num_inference_steps; |
| 85 | + |
| 86 | + switch (m_config.timestep_spacing) { |
| 87 | + case TimestepSpacing::LINSPACE: |
| 88 | + { |
| 89 | + using numpy_utils::linspace; |
| 90 | + float end = static_cast<float>(m_config.num_train_timesteps - 1); |
| 91 | + auto linspaced = linspace<float>(0.0f, end, num_inference_steps, true); |
| 92 | + for (auto it = linspaced.rbegin(); it != linspaced.rend(); ++it) { |
| 93 | + m_timesteps.push_back(static_cast<int64_t>(std::round(*it))); |
| 94 | + } |
| 95 | + break; |
| 96 | + } |
| 97 | + case TimestepSpacing::LEADING: |
| 98 | + { |
| 99 | + size_t step_ratio = m_config.num_train_timesteps / m_num_inference_steps; |
| 100 | + for (size_t i = num_inference_steps - 1; i != -1; --i) { |
| 101 | + m_timesteps.push_back(i * step_ratio + m_config.steps_offset); |
| 102 | + } |
| 103 | + break; |
| 104 | + } |
| 105 | + case TimestepSpacing::TRAILING: |
| 106 | + { |
| 107 | + float step_ratio = static_cast<float>(m_config.num_train_timesteps) / static_cast<float>(m_num_inference_steps); |
| 108 | + for (float i = m_config.num_train_timesteps; i > 0; i-=step_ratio){ |
| 109 | + m_timesteps.push_back(static_cast<int64_t>(std::round(i)) - 1); |
| 110 | + } |
| 111 | + break; |
| 112 | + } |
| 113 | + default: |
| 114 | + OPENVINO_THROW("Unsupported value for 'timestep_spacing'"); |
| 115 | + } |
| 116 | +} |
| 117 | + |
| 118 | +std::map<std::string, ov::Tensor> DDIMScheduler::step(ov::Tensor noise_pred, ov::Tensor latents, size_t inference_step) { |
| 119 | + // noise_pred - model_output |
| 120 | + // latents - sample |
| 121 | + // inference_step |
| 122 | + |
| 123 | + size_t timestep = get_timesteps()[inference_step]; |
| 124 | + |
| 125 | + // get previous step value (=t-1) |
| 126 | + int prev_timestep = timestep - m_config.num_train_timesteps / m_num_inference_steps; |
| 127 | + |
| 128 | + // compute alphas, betas |
| 129 | + float alpha_prod_t = m_alphas_cumprod[timestep]; |
| 130 | + float alpha_prod_t_prev = (prev_timestep >= 0) ? m_alphas_cumprod[prev_timestep] : m_final_alpha_cumprod; |
| 131 | + float beta_prod_t = 1 - alpha_prod_t; |
| 132 | + |
| 133 | + // compute predicted original sample from predicted noise also called |
| 134 | + // "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf |
| 135 | + std::vector<float> pred_original_sample, pred_epsilon; |
| 136 | + float pos_val, pe_val; |
| 137 | + for (size_t j = 0; j < noise_pred.get_size(); j++) { |
| 138 | + switch (m_config.prediction_type) { |
| 139 | + case PredictionType::EPSILON: |
| 140 | + pos_val = (latents.data<float>()[j] - std::sqrt(beta_prod_t) * noise_pred.data<float>()[j]) / std::sqrt(alpha_prod_t); |
| 141 | + pe_val = noise_pred.data<float>()[j]; |
| 142 | + pred_original_sample.push_back(pos_val); |
| 143 | + pred_epsilon.push_back(pe_val); |
| 144 | + break; |
| 145 | + case PredictionType::SAMPLE: |
| 146 | + pos_val = noise_pred.data<float>()[j]; |
| 147 | + pe_val = (latents.data<float>()[j] - std::sqrt(alpha_prod_t) * pos_val) / std::sqrt(beta_prod_t); |
| 148 | + pred_original_sample.push_back(pos_val); |
| 149 | + pred_epsilon.push_back(pe_val); |
| 150 | + break; |
| 151 | + case PredictionType::V_PREDICTION: |
| 152 | + pos_val = std::sqrt(alpha_prod_t) * latents.data<float>()[j] - std::sqrt(beta_prod_t) * noise_pred.data<float>()[j]; |
| 153 | + pe_val = std::sqrt(alpha_prod_t) * noise_pred.data<float>()[j] + std::sqrt(beta_prod_t) * latents.data<float>()[j]; |
| 154 | + pred_original_sample.push_back(pos_val); |
| 155 | + pred_epsilon.push_back(pe_val); |
| 156 | + break; |
| 157 | + default: |
| 158 | + OPENVINO_THROW("Unsupported value for 'PredictionType'"); |
| 159 | + } |
| 160 | + } |
| 161 | + |
| 162 | + // TODO: Clip or threshold "predicted x_0" |
| 163 | + // if m_config.thresholding: |
| 164 | + // pred_original_sample = _threshold_sample(pred_original_sample) |
| 165 | + // elif m_config.clip_sample: |
| 166 | + // pred_original_sample = pred_original_sample.clamp( |
| 167 | + // -self.config.clip_sample_range, self.config.clip_sample_range |
| 168 | + // ) |
| 169 | + |
| 170 | + // compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf |
| 171 | + std::vector<float> pred_sample_direction(pred_epsilon.size()); |
| 172 | + std::transform(pred_epsilon.begin(), pred_epsilon.end(), pred_sample_direction.begin(), [alpha_prod_t_prev](auto x) { |
| 173 | + return std::sqrt(1 - alpha_prod_t_prev) * x; |
| 174 | + }); |
| 175 | + |
| 176 | + // compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf |
| 177 | + ov::Tensor prev_sample(latents.get_element_type(), latents.get_shape()); |
| 178 | + float* prev_sample_data = prev_sample.data<float>(); |
| 179 | + for (size_t i = 0; i < prev_sample.get_size(); ++i) { |
| 180 | + prev_sample_data[i] = std::sqrt(alpha_prod_t_prev) * pred_original_sample[i] + pred_sample_direction[i]; |
| 181 | + } |
| 182 | + |
| 183 | + std::map<std::string, ov::Tensor> result{{"latent", prev_sample}}; |
| 184 | + |
| 185 | + return result; |
| 186 | +} |
| 187 | + |
| 188 | +std::vector<int64_t> DDIMScheduler::get_timesteps() const { |
| 189 | + return m_timesteps; |
| 190 | +} |
| 191 | + |
| 192 | +float DDIMScheduler::get_init_noise_sigma() const { |
| 193 | + return 1.0f; |
| 194 | +} |
| 195 | + |
| 196 | +void DDIMScheduler::scale_model_input(ov::Tensor sample, size_t inference_step) { |
| 197 | + return; |
| 198 | +} |
| 199 | + |
| 200 | +} // namespace genai |
| 201 | +} // namespace ov |
0 commit comments