diff --git a/generate.py b/generate.py index 42210a5a..1f6c5e69 100755 --- a/generate.py +++ b/generate.py @@ -33,7 +33,7 @@ def generate_images(network_pkl, seeds, truncation_psi, outdir, class_idx, dlate if dlatents_npz is not None: print(f'Generating images from dlatents file "{dlatents_npz}"') dlatents = np.load(dlatents_npz)['dlatents'] - assert dlatents.shape[1:] == (18, 512) # [N, 18, 512] + assert dlatents.shape[-1] == 512 # [N, M, 512] imgs = Gs.components.synthesis.run(dlatents, output_transform=dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)) for i, img in enumerate(imgs): fname = f'{outdir}/dlatent{i:02d}.png'