|
14 | 14 |
|
15 | 15 | import enum
|
16 | 16 | import importlib
|
| 17 | +import math |
17 | 18 | from copy import deepcopy
|
18 | 19 | from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
19 | 20 |
|
@@ -2862,17 +2863,31 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int
|
2862 | 2863 | dtype=int_dtype,
|
2863 | 2864 | )
|
2864 | 2865 | if input_name == "code_b":
|
| 2866 | + # default value from https://github.com/deepseek-ai/Janus/blob/1daa72fa409002d40931bd7b36a9280362469ead/janus/models/vq_model.py#L42 |
| 2867 | + z_channels = getattr(self.normalized_config.config.params, "z_channels", 256) |
| 2868 | + patch_size = int(math.sqrt(z_channels)) |
| 2869 | + # default value from https://github.com/deepseek-ai/Janus/blob/1daa72fa409002d40931bd7b36a9280362469ead/generation_inference.py#L63 |
| 2870 | + generated_image_size = getattr(self.normalized_config.config.params, "img_size", 384) |
| 2871 | + latent_height = int(generated_image_size // patch_size) |
| 2872 | + latent_width = int(generated_image_size // patch_size) |
2865 | 2873 | return self.random_int_tensor(

2866 |
| - [self.batch_size, 576], |
| 2874 | + [self.batch_size, int(latent_height * latent_width)], |
2867 | 2875 | max_value=self.normalized_config.config.params.image_token_size,
|
2868 | 2876 | framework=framework,
|
2869 | 2877 | dtype=int_dtype,
|
2870 | 2878 | )
|
2871 | 2879 | if input_name == "image_shape":
|
2872 | 2880 | import torch
|
| 2881 | + # default value from https://github.com/deepseek-ai/Janus/blob/1daa72fa409002d40931bd7b36a9280362469ead/janus/models/vq_model.py#L42 |
| 2882 | + z_channels = getattr(self.normalized_config.config.params, "z_channels", 256) |
| 2883 | + patch_size = int(math.sqrt(z_channels)) |
| 2884 | + # default value from https://github.com/deepseek-ai/Janus/blob/1daa72fa409002d40931bd7b36a9280362469ead/generation_inference.py#L63 |
| 2885 | + generated_image_size = getattr(self.normalized_config.config.params, "img_size", 384) |
| 2886 | + latent_height = int(generated_image_size // patch_size) |
| 2887 | + latent_width = int(generated_image_size // patch_size) |
2873 | 2888 |

2874 | 2889 | return torch.tensor(

2875 |
| - [self.batch_size, self.normalized_config.config.params.n_embed, 24, 24], dtype=torch.int64 |
| 2890 | + [self.batch_size, self.normalized_config.config.params.n_embed, latent_height, latent_width], dtype=torch.int64 |
2876 | 2891 | )
|
2877 | 2892 | if input_name == "hidden_state":
|
2878 | 2893 | return self.random_float_tensor(
|
|
0 commit comments