Skip to content

Commit e530cb7

Browse files
authored
DDIM: rescale_betas_zero_snr support, add asserts for params (openvinotoolkit#899)
`rescale_betas_zero_snr` param support for `bghira/pseudo-journey-v2` model. prompt: `cyberpunk cityscape like Tokyo New York with tall buildings at dusk golden hour cinematic lighting` ![image](https://github.com/user-attachments/assets/5ae8a2cd-797a-42f3-895c-5d438f69150f)
1 parent 6961842 commit e530cb7

File tree

2 files changed

+50
-9
lines changed

2 files changed

+50
-9
lines changed

src/cpp/src/text2image/schedulers/ddim.cpp

+49-9
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,9 @@ DDIMScheduler::DDIMScheduler(const Config& scheduler_config)
6161
OPENVINO_THROW("'beta_schedule' must be one of 'LINEAR' or 'SCALED_LINEAR'. Please, add support of other types");
6262
}
6363

64-
// TODO: Rescale for zero SNR
65-
// if (m_config.rescale_betas_zero_snr) {betas = rescale_zero_terminal_snr(betas)}
64+
if (m_config.rescale_betas_zero_snr) {
65+
rescale_zero_terminal_snr(betas);
66+
}
6667

6768
std::transform(betas.begin(), betas.end(), std::back_inserter(alphas), [] (float b) { return 1.0f - b; });
6869

@@ -159,13 +160,12 @@ std::map<std::string, ov::Tensor> DDIMScheduler::step(ov::Tensor noise_pred, ov:
159160
}
160161
}
161162

162-
// TODO: Clip or threshold "predicted x_0"
163-
// if m_config.thresholding:
164-
// pred_original_sample = _threshold_sample(pred_original_sample)
165-
// elif m_config.clip_sample:
166-
// pred_original_sample = pred_original_sample.clamp(
167-
// -self.config.clip_sample_range, self.config.clip_sample_range
168-
// )
163+
// TODO: support m_config.thresholding
164+
OPENVINO_ASSERT(!m_config.thresholding,
165+
"Parameter 'thresholding' is not supported. Please, add support.");
166+
// TODO: support m_config.clip_sample
167+
OPENVINO_ASSERT(!m_config.clip_sample,
168+
"Parameter 'clip_sample' is not supported. Please, add support.");
169169

170170
// compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
171171
std::vector<float> pred_sample_direction(pred_epsilon.size());
@@ -197,5 +197,45 @@ void DDIMScheduler::scale_model_input(ov::Tensor sample, size_t inference_step)
197197
return;
198198
}
199199

200+
/**
 * Rescales `betas` in place so that the last timestep has zero terminal SNR,
 * i.e. the cumulative product of alphas (alpha_t = 1 - beta_t) at the final
 * timestep becomes exactly zero while its value at the first timestep is
 * preserved.
 *
 * Steps:
 *   1. betas -> sqrt(cumprod(1 - betas))
 *   2. shift so the last entry is zero, scale so the first entry is unchanged
 *   3. square to get the rescaled cumulative product
 *   4. undo the cumulative product and convert back to betas
 *
 * @param betas Beta schedule; overwritten with the rescaled schedule.
 *              An empty schedule is left untouched.
 *
 * NOTE(review): when the first and last cumulative products coincide (e.g. a
 * single-element schedule) the scale factor divides by zero and the result is
 * not finite, same as before this change - confirm callers never pass that.
 */
void DDIMScheduler::rescale_zero_terminal_snr(std::vector<float>& betas) {
    // Nothing to rescale; also protects the front()/back() accesses below.
    if (betas.empty()) {
        return;
    }

    // Convert betas to sqrt(alphas_cumprod) with a single running product
    // (one pass instead of re-accumulating every prefix).
    std::vector<float> alphas_bar;
    alphas_bar.reserve(betas.size());
    float alpha_cumprod = 1.0f;
    for (float b : betas) {
        const float alpha = 1.0f - b;
        alpha_cumprod *= alpha;
        alphas_bar.push_back(std::sqrt(alpha_cumprod));
    }

    const float alphas_bar_sqrt_0 = alphas_bar.front();
    const float alphas_bar_sqrt_T = alphas_bar.back();
    // Loop-invariant factor that restores the first timestep to its old value.
    const float scale = alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T);

    for (float& x : alphas_bar) {
        // Shift so the last timestep is zero.
        x = x - alphas_bar_sqrt_T;
        // Scale so the first timestep is back to the old value.
        x *= scale;
        // Revert sqrt
        x = x * x;
    }

    // Revert cumprod (alpha_i = alphas_bar[i] / alphas_bar[i - 1]) and write
    // the result back as betas (beta_i = 1 - alpha_i).
    betas[0] = 1.0f - alphas_bar[0];
    for (size_t i = 1; i < betas.size(); ++i) {
        const float alpha = alphas_bar[i] / alphas_bar[i - 1];
        betas[i] = 1.0f - alpha;
    }
}
239+
200240
} // namespace genai
201241
} // namespace ov

src/cpp/src/text2image/schedulers/ddim.hpp

+1
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ class DDIMScheduler : public IScheduler {
5353
size_t m_num_inference_steps;
5454
std::vector<int64_t> m_timesteps;
5555

56+
// Rescales the given beta schedule in place (see the definition in ddim.cpp);
// called from the constructor when `rescale_betas_zero_snr` is set in the config.
void rescale_zero_terminal_snr(std::vector<float>& betas);
5657
};
5758

5859
} // namespace genai

0 commit comments

Comments
 (0)