@@ -61,8 +61,9 @@ DDIMScheduler::DDIMScheduler(const Config& scheduler_config)
61
61
OPENVINO_THROW (" 'beta_schedule' must be one of 'LINEAR' or 'SCALED_LINEAR'. Please, add support of other types" );
62
62
}
63
63
64
- // TODO: Rescale for zero SNR
65
- // if (m_config.rescale_betas_zero_snr) {betas = rescale_zero_terminal_snr(betas)}
64
+ if (m_config.rescale_betas_zero_snr ) {
65
+ rescale_zero_terminal_snr (betas);
66
+ }
66
67
67
68
std::transform (betas.begin (), betas.end (), std::back_inserter (alphas), [] (float b) { return 1 .0f - b; });
68
69
@@ -159,13 +160,12 @@ std::map<std::string, ov::Tensor> DDIMScheduler::step(ov::Tensor noise_pred, ov:
159
160
}
160
161
}
161
162
162
- // TODO: Clip or threshold "predicted x_0"
163
- // if m_config.thresholding:
164
- // pred_original_sample = _threshold_sample(pred_original_sample)
165
- // elif m_config.clip_sample:
166
- // pred_original_sample = pred_original_sample.clamp(
167
- // -self.config.clip_sample_range, self.config.clip_sample_range
168
- // )
163
+ // TODO: support m_config.thresholding
164
+ OPENVINO_ASSERT (!m_config.thresholding ,
165
+ " Parameter 'thresholding' is not supported. Please, add support." );
166
+ // TODO: support m_config.clip_sample
167
+ OPENVINO_ASSERT (!m_config.clip_sample ,
168
+ " Parameter 'clip_sample' is not supported. Please, add support." );
169
169
170
170
// compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
171
171
std::vector<float > pred_sample_direction (pred_epsilon.size ());
@@ -197,5 +197,45 @@ void DDIMScheduler::scale_model_input(ov::Tensor sample, size_t inference_step)
197
197
return ;
198
198
}
199
199
200
+ void DDIMScheduler::rescale_zero_terminal_snr (std::vector<float >& betas) {
201
+ // Convert betas to alphas_bar_sqrt
202
+ std::vector<float > alphas, alphas_bar_sqrt;
203
+ for (float b : betas) {
204
+ alphas.push_back (1 .0f - b);
205
+ }
206
+
207
+ for (size_t i = 1 ; i <= alphas.size (); ++i) {
208
+ float alpha_cumprod =
209
+ std::accumulate (std::begin (alphas), std::begin (alphas) + i, 1.0 , std::multiplies<float >{});
210
+ alphas_bar_sqrt.push_back (std::sqrt (alpha_cumprod));
211
+ }
212
+
213
+ float alphas_bar_sqrt_0 = alphas_bar_sqrt[0 ];
214
+ float alphas_bar_sqrt_T = alphas_bar_sqrt[alphas_bar_sqrt.size () - 1 ];
215
+
216
+ for (float & x : alphas_bar_sqrt) {
217
+ // Shift so the last timestep is zero.
218
+ x = x - alphas_bar_sqrt_T;
219
+ // Scale so the first timestep is back to the old value.
220
+ x *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T);
221
+ // Revert sqrt
222
+ x = std::pow (x, 2 );
223
+ }
224
+
225
+ // Revert cumprod
226
+ std::vector<float > end = alphas_bar_sqrt, begin = alphas_bar_sqrt;
227
+ end.erase (end.begin ());
228
+ begin.pop_back ();
229
+
230
+ alphas[0 ] = alphas_bar_sqrt[0 ];
231
+ for (size_t i = 1 ; i < alphas.size (); ++i) {
232
+ alphas[i] = end[i - 1 ] / begin[i - 1 ];
233
+ }
234
+
235
+ std::transform (alphas.begin (), alphas.end (), betas.begin (), [](float x) {
236
+ return (1 - x);
237
+ });
238
+ }
239
+
200
240
} // namespace genai
201
241
} // namespace ov
0 commit comments