From de63877140cdf2e39147e1ca8bec8b6c1b9ba761 Mon Sep 17 00:00:00 2001 From: Manal ML Date: Mon, 27 Jan 2025 14:57:55 +0000 Subject: [PATCH 1/4] Added generation config validation via Pydantic --- setup.py | 2 +- src/transformers/dependency_versions_table.py | 2 +- .../generation/configuration_utils.py | 666 +++++++++--------- src/transformers/utils/import_utils.py | 18 +- 4 files changed, 347 insertions(+), 341 deletions(-) diff --git a/setup.py b/setup.py index 0fce910aa89d..69ed29ac4144 100644 --- a/setup.py +++ b/setup.py @@ -146,7 +146,7 @@ "protobuf", "psutil", "pyyaml>=5.1", - "pydantic", + "pydantic>=2.0.0", "pytest>=7.2.0,<8.0.0", "pytest-asyncio", "pytest-timeout", diff --git a/src/transformers/dependency_versions_table.py b/src/transformers/dependency_versions_table.py index 6edc38351670..013027ec699e 100644 --- a/src/transformers/dependency_versions_table.py +++ b/src/transformers/dependency_versions_table.py @@ -52,7 +52,7 @@ "protobuf": "protobuf", "psutil": "psutil", "pyyaml": "pyyaml>=5.1", - "pydantic": "pydantic", + "pydantic": "pydantic>=2.0.0", "pytest": "pytest>=7.2.0,<8.0.0", "pytest-asyncio": "pytest-asyncio", "pytest-timeout": "pytest-timeout", diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index 3f142ce77298..6555e86fa91c 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -20,7 +20,9 @@ import warnings from abc import ABC, abstractmethod from dataclasses import dataclass, is_dataclass -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, ClassVar, Dict, List, Literal, Optional, Tuple, Union + +from pydantic import BaseModel, ConfigDict, PrivateAttr, confloat, conint, model_validator from .. import __version__ from ..configuration_utils import PretrainedConfig @@ -95,7 +97,216 @@ class GenerationMode(ExplicitEnum): GROUP_BEAM_SEARCH = "group_beam_search" -class GenerationConfig(PushToHubMixin): +@dataclass +class CompileConfig(object): + """ + Class that holds arguments relative to `torch.compile` behavior, when using automatic compilation in `generate`. + See [`torch.compile`](https://pytorch.org/docs/stable/generated/torch.compile.html) for more details on the arguments. + + Args: + fullgraph (`bool`, *optional*, defaults to `True`): + If `True`, requires that the whole forward be capturable in a single graph. + dynamic (`bool` or `None`, *optional*): + Whether to try to use dynamic shape graphs. + backend (`str` or `Callable`, *optional*, defaults to `"inductor"`): + Backend to be used. + mode (`str`, *optional*, defaults to `"reduce-overhead"`): + Controls balance between performance and overhead. + options (`dict`, *optional*): + A dictionary of options to pass to the backend. + + Examples: + ```python + >>> from transformers import AutoModelForCausalLM, AutoTokenizer, CompileConfig + + >>> tokenizer = AutoTokenizer.from_pretrained('google/gemma-2-2b') + >>> model = AutoModelForCausalLM.from_pretrained('google/gemma-2-2b').cuda() + + >>> # Automatic compile configuration, used with static cache + >>> compile_config = CompileConfig(dynamic=True) + + >>> # Generation with static cache and compile config + >>> input = tokenizer.encode("Hello there, how", return_tensors="pt").cuda() + >>> output = model.generate( + ... input, do_sample=False, max_new_tokens=300, cache_implementation="static", compile_config=compile_config + ... 
) + >>> output_text = tokenizer.batch_decode(output, skip_special_tokens=True)[0] + ``` + """ + + fullgraph: bool = True + dynamic: Optional[bool] = None + backend: Union[str, Callable] = "inductor" + mode: str = "reduce-overhead" + options: Optional[dict] = None + + def to_dict(self) -> Dict[str, Any]: + """Serializes this instance to a Python dictionary.""" + return copy.deepcopy(self.__dict__) + + +@dataclass +class BaseWatermarkingConfig(ABC): + """Generic watermarking config""" + + @classmethod + def from_dict(cls, config_dict, **kwargs): + """ + Constructs a BaseWatermarkingConfig instance from a dictionary of parameters. + + Args: + config_dict (Dict[str, Any]): Dictionary containing configuration parameters. + **kwargs: Additional keyword arguments to override dictionary values. + + Returns: + BaseWatermarkingConfig: Instance of BaseWatermarkingConfig constructed from the dictionary. + """ + config = cls(**config_dict) + to_remove = [] + for key, value in kwargs.items(): + if hasattr(config, key): + setattr(config, key, value) + to_remove.append(key) + for key in to_remove: + kwargs.pop(key, None) + return config + + def to_json_file(self, json_file_path: Union[str, os.PathLike]): + """ + Save this instance to a JSON file. + + Args: + json_file_path (Union[str, os.PathLike]): Path to the JSON file in which this configuration instance's parameters will be saved. + """ + with open(json_file_path, "w", encoding="utf-8") as writer: + config_dict = self.to_dict() + json_string = json.dumps(config_dict, indent=2, sort_keys=True) + "\n" + + writer.write(json_string) + + def to_dict(self) -> Dict[str, Any]: + """ + Serializes this instance to a Python dictionary. + + Returns: + Dict[str, Any]: Dictionary of all the attributes that make up this configuration instance. + """ + output = copy.deepcopy(self.__dict__) + return output + + def __iter__(self): + for attr, value in copy.deepcopy(self.__dict__).items(): + yield attr, value + + def __repr__(self): + return f"{self.__class__.__name__} {self.to_json_string()}" + + def to_json_string(self): + """ + Serializes this instance to a JSON formatted string. + + Returns: + str: JSON formatted string representing the configuration instance. + """ + return json.dumps(self.__dict__, indent=2) + "\n" + + def update(self, **kwargs): + """ + Update the configuration attributes with new values. + + Args: + **kwargs: Keyword arguments representing configuration attributes and their new values. + """ + for key, value in kwargs.items(): + if hasattr(self, key): + setattr(self, key, value) + + @abstractmethod + def validate(self): ... + + @abstractmethod + def construct_processor(self, vocab_size): ... + + +@dataclass +class WatermarkingConfig(BaseWatermarkingConfig): + """ + Class that holds arguments for watermark generation and should be passed into `GenerationConfig` during `generate`. + See [this paper](https://arxiv.org/abs/2306.04634) for more details on the arguments. + + Accepts the following keys: + - greenlist_ratio (`float`): + Used for watermarking. The ratio of "green" tokens used to the vocabulary size. Defaults to 0.25. + - bias (`float`): + Used with watermarking. The bias added to the selected "green" tokens' logits. Defaults to 2.0. + - hashing_key (`int`): + Hashing key used for watermarking. Defaults to 15485863 (the millionth prime). + - seeding_scheme (`str`): + Algorithm to use for watermarking. 
Accepts values: + - "lefthash" (default): "green" tokens selection depend on the last token (Algorithm 2 from the paper) + - "selfhash": "green" tokens selection depends on the current token itself (Algorithm 3 from the paper) + The downside of this scheme is that it considers all possible next tokens and can be slower than "lefthash". + - context_width(`int`): + The context length of previous tokens to use in seeding. Higher context length makes watermarking more robust. + """ + + def __init__( + self, + greenlist_ratio: Optional[float] = 0.25, + bias: Optional[float] = 2.0, + hashing_key: Optional[int] = 15485863, + seeding_scheme: Optional[str] = "lefthash", + context_width: Optional[int] = 1, + ): + self.greenlist_ratio = greenlist_ratio + self.bias = bias + self.hashing_key = hashing_key + self.seeding_scheme = seeding_scheme + self.context_width = context_width + + def validate(self): + watermark_missing_arg_msg = ( + "Some of the keys in `watermarking_config` are defined incorrectly. `{key}` should be {correct_value}` " + "but found {found_value}" + ) + if self.seeding_scheme not in ["selfhash", "lefthash"]: + raise ValueError( + watermark_missing_arg_msg.format( + key="seeding_scheme", + correct_value="[`selfhash`, `lefthash`]", + found_value=self.seeding_scheme, + ), + ) + if not 0.0 <= self.greenlist_ratio <= 1.0: + raise ValueError( + watermark_missing_arg_msg.format( + key="greenlist_ratio", + correct_value="in range between 0.0 and 1.0", + found_value=self.seeding_scheme, + ), + ) + if not self.context_width >= 1: + raise ValueError( + watermark_missing_arg_msg.format( + key="context_width", + correct_value="a positive integer", + found_value=self.context_width, + ), + ) + + def construct_processor(self, vocab_size: int, device) -> "WatermarkLogitsProcessor": + return WatermarkLogitsProcessor( + vocab_size=vocab_size, + device=device, + greenlist_ratio=self.greenlist_ratio, + bias=self.bias, + hashing_key=self.hashing_key, + seeding_scheme=self.seeding_scheme, + context_width=self.context_width, + ) + + +class GenerationConfig(BaseModel, PushToHubMixin): # no-format """ Class that holds a configuration for a generation task. A `generate` call supports the following generation methods @@ -278,9 +489,9 @@ class GenerationConfig(PushToHubMixin): A list of pairs of integers which indicates a mapping from generation indices to token indices that will be forced before sampling. For example, `[[1, 123]]` means the second generated token will always be a token of index 123. - sequence_bias (`Dict[Tuple[int], float]`, *optional*)): - Dictionary that maps a sequence of tokens to its bias term. Positive biases increase the odds of the - sequence being selected, while negative biases do the opposite. Check + sequence_bias (`List[List[List[int], float]]`, *optional*)): + List of pairs that contain a non-empty list of token ids and a float bias , e.g, `[[[32, 69], -1.7], [[92], 0.2]]`. + Positive biases increase the odds of the sequence being selected, while negative biases do the opposite. Check [`~generation.SequenceBiasLogitsProcessor`] for further documentation and examples. token_healing (`bool`, *optional*, defaults to `False`): Heal tail tokens of prompts by replacing them with their appropriate extensions. @@ -386,125 +597,116 @@ class GenerationConfig(PushToHubMixin): present in `generate`'s signature will be used in the model forward pass. 
""" - extra_output_flags = ("output_attentions", "output_hidden_states", "output_scores", "output_logits") - - def __init__(self, **kwargs): - # Parameters that control the length of the output - self.max_length = kwargs.pop("max_length", 20) - self.max_new_tokens = kwargs.pop("max_new_tokens", None) - self.min_length = kwargs.pop("min_length", 0) - self.min_new_tokens = kwargs.pop("min_new_tokens", None) - self.early_stopping = kwargs.pop("early_stopping", False) - self.max_time = kwargs.pop("max_time", None) - self.stop_strings = kwargs.pop("stop_strings", None) - - # Parameters that control the generation strategy used - self.do_sample = kwargs.pop("do_sample", False) - self.num_beams = kwargs.pop("num_beams", 1) - self.num_beam_groups = kwargs.pop("num_beam_groups", 1) - self.penalty_alpha = kwargs.pop("penalty_alpha", None) - self.dola_layers = kwargs.pop("dola_layers", None) - - # Parameters that control the cache - self.use_cache = kwargs.pop("use_cache", True) - self.cache_implementation = kwargs.pop("cache_implementation", None) - self.cache_config = kwargs.pop("cache_config", None) - if self.cache_implementation is not None and self.cache_implementation in CACHE_CONFIG_MAPPING: - cache_config_class = CACHE_CONFIG_MAPPING[self.cache_implementation] - if isinstance(self.cache_config, dict): - self.cache_config = cache_config_class.from_dict(self.cache_config) - self.return_legacy_cache = kwargs.pop("return_legacy_cache", None) - - # Parameters for manipulation of the model output logits - self.temperature = kwargs.pop("temperature", 1.0) - self.top_k = kwargs.pop("top_k", 50) - self.top_p = kwargs.pop("top_p", 1.0) - self.min_p = kwargs.pop("min_p", None) - self.typical_p = kwargs.pop("typical_p", 1.0) - self.epsilon_cutoff = kwargs.pop("epsilon_cutoff", 0.0) - self.eta_cutoff = kwargs.pop("eta_cutoff", 0.0) - self.diversity_penalty = kwargs.pop("diversity_penalty", 0.0) - self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0) - self.encoder_repetition_penalty = kwargs.pop("encoder_repetition_penalty", 1.0) - self.length_penalty = kwargs.pop("length_penalty", 1.0) - self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0) - self.bad_words_ids = kwargs.pop("bad_words_ids", None) - self.force_words_ids = kwargs.pop("force_words_ids", None) - self.renormalize_logits = kwargs.pop("renormalize_logits", False) - self.constraints = kwargs.pop("constraints", None) - self.forced_bos_token_id = kwargs.pop("forced_bos_token_id", None) - self.forced_eos_token_id = kwargs.pop("forced_eos_token_id", None) - self.remove_invalid_values = kwargs.pop("remove_invalid_values", False) - self.exponential_decay_length_penalty = kwargs.pop("exponential_decay_length_penalty", None) - self.suppress_tokens = kwargs.pop("suppress_tokens", None) - self.begin_suppress_tokens = kwargs.pop("begin_suppress_tokens", None) - self.forced_decoder_ids = kwargs.pop("forced_decoder_ids", None) - self.sequence_bias = kwargs.pop("sequence_bias", None) - self.token_healing = kwargs.pop("token_healing", False) - self.guidance_scale = kwargs.pop("guidance_scale", None) - self.low_memory = kwargs.pop("low_memory", None) - watermarking_config = kwargs.pop("watermarking_config", None) - if watermarking_config is None: - self.watermarking_config = None - elif isinstance(watermarking_config, BaseWatermarkingConfig): - self.watermarking_config = watermarking_config - else: - self.watermarking_config = WatermarkingConfig.from_dict(watermarking_config) - - # Parameters that define the output variables of `generate` 
- self.num_return_sequences = kwargs.pop("num_return_sequences", 1) - self.output_attentions = kwargs.pop("output_attentions", False) - self.output_hidden_states = kwargs.pop("output_hidden_states", False) - self.output_scores = kwargs.pop("output_scores", False) - self.output_logits = kwargs.pop("output_logits", None) - self.return_dict_in_generate = kwargs.pop("return_dict_in_generate", False) - - # Special tokens that can be used at generation time - self.pad_token_id = kwargs.pop("pad_token_id", None) - self.bos_token_id = kwargs.pop("bos_token_id", None) - self.eos_token_id = kwargs.pop("eos_token_id", None) - - # Generation parameters exclusive to encoder-decoder models - self.encoder_no_repeat_ngram_size = kwargs.pop("encoder_no_repeat_ngram_size", 0) - self.decoder_start_token_id = kwargs.pop("decoder_start_token_id", None) - - # Assistant generation - self.is_assistant = False - self.num_assistant_tokens = kwargs.pop("num_assistant_tokens", 20) - self.num_assistant_tokens_schedule = kwargs.pop("num_assistant_tokens_schedule", "constant") - self.assistant_confidence_threshold = kwargs.pop("assistant_confidence_threshold", 0.4) - self.prompt_lookup_num_tokens = kwargs.pop("prompt_lookup_num_tokens", None) - self.max_matching_ngram_size = kwargs.pop("max_matching_ngram_size", None) - self.assistant_early_exit = kwargs.pop("assistant_early_exit", None) - ## assistant generation for different tokenizers, the windows size for assistant/target model - self.assistant_lookbehind = kwargs.pop("assistant_lookbehind", 10) - self.target_lookbehind = kwargs.pop("target_lookbehind", 10) - - # Performances - self.compile_config = kwargs.pop("compile_config", CompileConfig()) - - # Wild card - self.generation_kwargs = kwargs.pop("generation_kwargs", {}) - - # The remaining attributes do not parametrize `.generate()`, but are informative and/or used by the hub - # interface. 
-        self._from_model_config = kwargs.pop("_from_model_config", False)
-        self._commit_hash = kwargs.pop("_commit_hash", None)
-        self.transformers_version = kwargs.pop("transformers_version", __version__)
-
-        # Additional attributes without default values
-        if not self._from_model_config:
-            # we don't want to copy values from the model config if we're initializing a `GenerationConfig` from a
-            # model's default configuration file
-            for key, value in kwargs.items():
-                try:
-                    setattr(self, key, value)
-                except AttributeError as err:
-                    logger.error(f"Can't set {key} with value {value} for {self}")
-                    raise err
-
-        # Validate the values of the attributes
-        self.validate(is_init=True)
+    extra_output_flags: ClassVar[Tuple[str, ...]] = (
+        "output_attentions",
+        "output_hidden_states",
+        "output_scores",
+        "output_logits",
+    )
+    model_config = ConfigDict(extra="allow", strict=True, revalidate_instances="subclass-instances")
+
+    # Parameters that control the length of the output
+    max_length: Optional[conint(ge=0)] = 20
+    max_new_tokens: Optional[conint(gt=0)] = None
+    min_length: Optional[conint(ge=0)] = 0
+    min_new_tokens: Optional[conint(ge=0)] = None
+    early_stopping: Optional[Union[bool, Literal["never"]]] = False
+    max_time: Optional[confloat(ge=0.0)] = None
+    stop_strings: Optional[Union[str, List[str]]] = None
+
+    # Parameters that control the generation strategy used
+    do_sample: bool = False
+    num_beams: Optional[conint(ge=1)] = 1
+    num_beam_groups: Optional[conint(ge=1)] = 1
+    penalty_alpha: Optional[confloat(ge=0.0)] = None
+    dola_layers: Optional[Union[str, List[int]]] = None
+
+    # Parameters that control the cache
+    use_cache: bool = True
+    cache_implementation: Optional[str] = None
+    cache_config: Optional[Dict] = None
+    return_legacy_cache: Optional[bool] = None
+
+    @model_validator(mode="after")
+    def _coerce_cache_config(self):
+        # Runs per instance: a `cache_config` passed as a dict is converted to the config class that
+        # matches `cache_implementation` (a class-body `if` would only run once, at class creation,
+        # when `cache_implementation` is still the `None` default).
+        if self.cache_implementation is not None and self.cache_implementation in CACHE_CONFIG_MAPPING:
+            cache_config_class = CACHE_CONFIG_MAPPING[self.cache_implementation]
+            if isinstance(self.cache_config, dict):
+                self.cache_config = cache_config_class.from_dict(self.cache_config)
+        return self
+
+    # Parameters for manipulation of the model output logits
+    temperature: Optional[confloat(ge=0.0, le=2.0)] = 1.0
+    top_k: Optional[conint(ge=0)] = 50
+    top_p: Optional[confloat(ge=0.0, le=1.0)] = 1.0
+    min_p: Optional[confloat(ge=0.0, le=1.0)] = None
+    typical_p: Optional[confloat(ge=0.0, le=1.0)] = 1.0
+    epsilon_cutoff: Optional[confloat(ge=0.0, le=1.0)] = 0.0
+    eta_cutoff: Optional[confloat(ge=0.0, le=1.0)] = 0.0
+    diversity_penalty: Optional[confloat(ge=0.0)] = 0.0
+    repetition_penalty: Optional[confloat(ge=1.0)] = 1.0
+    encoder_repetition_penalty: Optional[confloat(ge=1.0)] = 1.0
+    length_penalty: float = 1.0
+    no_repeat_ngram_size: Optional[conint(ge=0)] = 0
+    bad_words_ids: Optional[List[List[int]]] = None
+    force_words_ids: Optional[List[Union[List[int], List[List[int]]]]] = None
+    renormalize_logits: Optional[bool] = False
+    constraints: Optional[List[Dict]] = None
+    forced_bos_token_id: Optional[conint(ge=0)] = None
+    forced_eos_token_id: Optional[Union[conint(ge=0), List[conint(ge=0)]]] = None
+    remove_invalid_values: Optional[bool] = False
+    exponential_decay_length_penalty: Optional[Union[tuple, list]] = None
+    suppress_tokens: Optional[Union[int, List[int]]] = None
+    begin_suppress_tokens: Optional[List[int]] = None
+    forced_decoder_ids: Optional[List[List[int]]] = None
+    sequence_bias: Optional[Union[List[Union[List[int], float]], List[List[Union[List[int], float]]]]] = None
+    token_healing: Optional[bool] = False
+    guidance_scale: Optional[confloat(ge=1.0)] = None
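+
+    # NOTE: `conint(...)` / `confloat(...)` are pydantic constrained types, so the numeric bounds declared
+    # above are enforced automatically each time a config is constructed or revalidated, replacing the
+    # manual range checks that `validate()` used to perform for these fields.
+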
+    low_memory: Optional[bool] = None
+    watermarking_config: Optional[Union[dict, "WatermarkingConfig", "SynthIDTextWatermarkingConfig"]] = None
+
+    @model_validator(mode="after")
+    def _coerce_watermarking_config(self):
+        # Runs per instance: a plain dict is promoted to `WatermarkingConfig`, while instances of
+        # `BaseWatermarkingConfig` subclasses pass through unchanged.
+        if self.watermarking_config is not None and not isinstance(self.watermarking_config, BaseWatermarkingConfig):
+            self.watermarking_config = WatermarkingConfig.from_dict(self.watermarking_config)
+        return self
+
+    # Parameters that define the output variables of `generate`
+    num_return_sequences: Optional[conint(ge=1)] = 1
+    output_attentions: Optional[bool] = False
+    output_hidden_states: Optional[bool] = False
+    output_scores: Optional[bool] = False
+    output_logits: Optional[bool] = None
+    return_dict_in_generate: Optional[bool] = False
+
+    # Special tokens that can be used at generation time
+    pad_token_id: Optional[int] = None
+    bos_token_id: Optional[conint(ge=0)] = None
+    eos_token_id: Optional[Union[conint(ge=0), List[conint(ge=0)]]] = None
+
+    # Generation parameters exclusive to encoder-decoder models
+    encoder_no_repeat_ngram_size: Optional[conint(ge=0)] = 0
+    decoder_start_token_id: Optional[Union[conint(ge=0), List[conint(ge=0)]]] = None
+
+    # Assistant generation
+    is_assistant: bool = False
+    num_assistant_tokens: Optional[conint(ge=1)] = 20
+    num_assistant_tokens_schedule: Optional[str] = "constant"
+    assistant_confidence_threshold: Optional[confloat(ge=0.0, le=1.0)] = 0.4
+    prompt_lookup_num_tokens: Optional[conint(ge=0)] = None
+    max_matching_ngram_size: Optional[conint(ge=0)] = None
+    assistant_early_exit: Optional[conint(ge=0)] = None
+    ## assistant generation for different tokenizers, the window size for assistant/target model
+    assistant_lookbehind: Optional[conint(ge=0)] = 10
+    target_lookbehind: Optional[conint(ge=0)] = 10
+
+    # Performances
+    compile_config: Optional["CompileConfig"] = CompileConfig()
+
+    # Wild card
+    generation_kwargs: Optional[Dict] = {}
+
+    # The remaining attributes do not parametrize `.generate()`, but are informative and/or used by the hub
+    # interface.
+    _from_model_config = PrivateAttr(default=False)
+    _commit_hash = PrivateAttr(default=None)
+    transformers_version: str = __version__
 
     def __hash__(self):
         return hash(self.to_json_string(ignore_metadata=True))
@@ -582,6 +784,7 @@ def get_generation_mode(self, assistant_model: Optional["PreTrainedModel"] = Non
         )
         return generation_mode
 
+    @model_validator(mode="after")
     def validate(self, is_init=False):
         """
         Validates the values of the attributes of the [`GenerationConfig`] instance. Raises exceptions in the presence
@@ -595,11 +798,6 @@ def validate(self, is_init=False):
             Whether the validation is performed during the initialization of the instance.
         """
-        # Validation of individual attributes
-        if self.early_stopping not in {True, False, "never"}:
-            raise ValueError(f"`early_stopping` must be a boolean or 'never', but is {self.early_stopping}.")
-        if self.max_new_tokens is not None and self.max_new_tokens <= 0:
-            raise ValueError(f"`max_new_tokens` must be greater than 0, but is {self.max_new_tokens}.")
         if self.pad_token_id is not None and self.pad_token_id < 0:
             warnings.warn(
                 f"`pad_token_id` should be positive but got {self.pad_token_id}. This will cause errors when batch "
@@ -836,6 +1034,7 @@ def validate(self, is_init=False):
                 f"Argument `{arg}` is not a valid argument of `GenerationConfig`. It should be passed to "
                 "`generate()` (or a pipeline) directly."
) + return self def save_pretrained( self, @@ -1327,167 +1526,6 @@ def update(self, **kwargs): return unused_kwargs -@dataclass -class BaseWatermarkingConfig(ABC): - """Generic watermarking config""" - - @classmethod - def from_dict(cls, config_dict, **kwargs): - """ - Constructs a BaseWatermarkingConfig instance from a dictionary of parameters. - - Args: - config_dict (Dict[str, Any]): Dictionary containing configuration parameters. - **kwargs: Additional keyword arguments to override dictionary values. - - Returns: - BaseWatermarkingConfig: Instance of BaseWatermarkingConfig constructed from the dictionary. - """ - config = cls(**config_dict) - to_remove = [] - for key, value in kwargs.items(): - if hasattr(config, key): - setattr(config, key, value) - to_remove.append(key) - for key in to_remove: - kwargs.pop(key, None) - return config - - def to_json_file(self, json_file_path: Union[str, os.PathLike]): - """ - Save this instance to a JSON file. - - Args: - json_file_path (Union[str, os.PathLike]): Path to the JSON file in which this configuration instance's parameters will be saved. - """ - with open(json_file_path, "w", encoding="utf-8") as writer: - config_dict = self.to_dict() - json_string = json.dumps(config_dict, indent=2, sort_keys=True) + "\n" - - writer.write(json_string) - - def to_dict(self) -> Dict[str, Any]: - """ - Serializes this instance to a Python dictionary. - - Returns: - Dict[str, Any]: Dictionary of all the attributes that make up this configuration instance. - """ - output = copy.deepcopy(self.__dict__) - return output - - def __iter__(self): - for attr, value in copy.deepcopy(self.__dict__).items(): - yield attr, value - - def __repr__(self): - return f"{self.__class__.__name__} {self.to_json_string()}" - - def to_json_string(self): - """ - Serializes this instance to a JSON formatted string. - - Returns: - str: JSON formatted string representing the configuration instance. - """ - return json.dumps(self.__dict__, indent=2) + "\n" - - def update(self, **kwargs): - """ - Update the configuration attributes with new values. - - Args: - **kwargs: Keyword arguments representing configuration attributes and their new values. - """ - for key, value in kwargs.items(): - if hasattr(self, key): - setattr(self, key, value) - - @abstractmethod - def validate(self): ... - - @abstractmethod - def construct_processor(self, vocab_size): ... - - -@dataclass -class WatermarkingConfig(BaseWatermarkingConfig): - """ - Class that holds arguments for watermark generation and should be passed into `GenerationConfig` during `generate`. - See [this paper](https://arxiv.org/abs/2306.04634) for more details on the arguments. - - Accepts the following keys: - - greenlist_ratio (`float`): - Used for watermarking. The ratio of "green" tokens used to the vocabulary size. Defaults to 0.25. - - bias (`float`): - Used with watermarking. The bias added to the selected "green" tokens' logits. Defaults to 2.0. - - hashing_key (`int`): - Hashing key used for watermarking. Defaults to 15485863 (the millionth prime). - - seeding_scheme (`str`): - Algorithm to use for watermarking. Accepts values: - - "lefthash" (default): "green" tokens selection depend on the last token (Algorithm 2 from the paper) - - "selfhash": "green" tokens selection depends on the current token itself (Algorithm 3 from the paper) - The downside of this scheme is that it considers all possible next tokens and can be slower than "lefthash". - - context_width(`int`): - The context length of previous tokens to use in seeding. 
Higher context length makes watermarking more robust. - """ - - def __init__( - self, - greenlist_ratio: Optional[float] = 0.25, - bias: Optional[float] = 2.0, - hashing_key: Optional[int] = 15485863, - seeding_scheme: Optional[str] = "lefthash", - context_width: Optional[int] = 1, - ): - self.greenlist_ratio = greenlist_ratio - self.bias = bias - self.hashing_key = hashing_key - self.seeding_scheme = seeding_scheme - self.context_width = context_width - - def validate(self): - watermark_missing_arg_msg = ( - "Some of the keys in `watermarking_config` are defined incorrectly. `{key}` should be {correct_value}` " - "but found {found_value}" - ) - if self.seeding_scheme not in ["selfhash", "lefthash"]: - raise ValueError( - watermark_missing_arg_msg.format( - key="seeding_scheme", - correct_value="[`selfhash`, `lefthash`]", - found_value=self.seeding_scheme, - ), - ) - if not 0.0 <= self.greenlist_ratio <= 1.0: - raise ValueError( - watermark_missing_arg_msg.format( - key="greenlist_ratio", - correct_value="in range between 0.0 and 1.0", - found_value=self.seeding_scheme, - ), - ) - if not self.context_width >= 1: - raise ValueError( - watermark_missing_arg_msg.format( - key="context_width", - correct_value="a positive integer", - found_value=self.context_width, - ), - ) - - def construct_processor(self, vocab_size: int, device) -> "WatermarkLogitsProcessor": - return WatermarkLogitsProcessor( - vocab_size=vocab_size, - device=device, - greenlist_ratio=self.greenlist_ratio, - bias=self.bias, - hashing_key=self.hashing_key, - seeding_scheme=self.seeding_scheme, - context_width=self.context_width, - ) - - @dataclass class SynthIDTextWatermarkingConfig(BaseWatermarkingConfig): """ @@ -1576,51 +1614,3 @@ def construct_processor(self, vocab_size: int, device) -> "WatermarkLogitsProces skip_first_ngram_calls=self.skip_first_ngram_calls, debug_mode=self.debug_mode, ) - - -@dataclass -class CompileConfig(object): - """ - Class that holds arguments relative to `torch.compile` behavior, when using automatic compilation in `generate`. - See [`torch.compile`](https://pytorch.org/docs/stable/generated/torch.compile.html) for more details on the arguments. - - Args: - fullgraph (`bool`, *optional*, defaults to `True`): - If `True`, requires that the whole forward be capturable in a single graph. - dynamic (`bool` or `None`, *optional*): - Whether to try to use dynamic shape graphs. - backend (`str` or `Callable`, *optional*, defaults to `"inductor"`): - Backend to be used. - mode (`str`, *optional*, defaults to `"reduce-overhead"`): - Controls balance between performance and overhead. - options (`dict`, *optional*): - A dictionary of options to pass to the backend. - - Examples: - ```python - >>> from transformers import AutoModelForCausalLM, AutoTokenizer, CompileConfig - - >>> tokenizer = AutoTokenizer.from_pretrained('google/gemma-2-2b') - >>> model = AutoModelForCausalLM.from_pretrained('google/gemma-2-2b').cuda() - - >>> # Automatic compile configuration, used with static cache - >>> compile_config = CompileConfig(dynamic=True) - - >>> # Generation with static cache and compile config - >>> input = tokenizer.encode("Hello there, how", return_tensors="pt").cuda() - >>> output = model.generate( - ... input, do_sample=False, max_new_tokens=300, cache_implementation="static", compile_config=compile_config - ... 
)
-    >>> output_text = tokenizer.batch_decode(output, skip_special_tokens=True)[0]
-    ```
-    """
-
-    fullgraph: bool = True
-    dynamic: Optional[bool] = None
-    backend: Union[str, Callable] = "inductor"
-    mode: str = "reduce-overhead"
-    options: Optional[dict] = None
-
-    def to_dict(self) -> Dict[str, Any]:
-        """Serializes this instance to a Python dictionary."""
-        return copy.deepcopy(self.__dict__)
diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py
index ac07281b3d33..dc7dc43b53ff 100755
--- a/src/transformers/utils/import_utils.py
+++ b/src/transformers/utils/import_utils.py
@@ -94,7 +94,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
 XLA_FSDPV2_MIN_VERSION = "2.2.0"
 HQQ_MIN_VERSION = "0.2.1"
 VPTQ_MIN_VERSION = "0.0.4"
-
+PYDANTIC_MIN_VERSION = "2.0.0"
 
 _accelerate_available, _accelerate_version = _is_package_available("accelerate", return_version=True)
 _apex_available = _is_package_available("apex")
@@ -167,6 +167,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
 _pytesseract_available = _is_package_available("pytesseract")
 _pytest_available = _is_package_available("pytest")
 _pytorch_quantization_available = _is_package_available("pytorch_quantization")
+_pydantic_available, _pydantic_version = _is_package_available("pydantic", return_version=True)
 _rjieba_available = _is_package_available("rjieba")
 _sacremoses_available = _is_package_available("sacremoses")
 _safetensors_available = _is_package_available("safetensors")
@@ -1284,6 +1285,20 @@ def is_triton_available():
     return _triton_available
 
 
+def is_pydantic_available(min_version: str = PYDANTIC_MIN_VERSION):
+    return _pydantic_available and version.parse(_pydantic_version) >= version.parse(min_version)
+
+
+# docstyle-ignore
+PYDANTIC_IMPORT_ERROR = f"""
+{{0}} requires the pydantic library >= {PYDANTIC_MIN_VERSION} but it was not found in your environment.
+You can install or update it with:
+```
+pip install -U pydantic
+```
+Please note that you may need to restart your runtime after installation.
+"""
+
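+# NOTE: `is_pydantic_available` and `PYDANTIC_IMPORT_ERROR` are registered in `BACKENDS_MAPPING` below,
+# so pydantic-gated code can surface this message through the existing `requires_backends` helper.
+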
 # docstyle-ignore
 AV_IMPORT_ERROR = """
 {0} requires the PyAv library but it was not found in your environment.
 You can install it with:
 ```
 pip install av
 ```
 """
 
@@ -1665,6 +1680,7 @@ def is_triton_available():
     ("protobuf", (is_protobuf_available, PROTOBUF_IMPORT_ERROR)),
     ("pyctcdecode", (is_pyctcdecode_available, PYCTCDECODE_IMPORT_ERROR)),
     ("pytesseract", (is_pytesseract_available, PYTESSERACT_IMPORT_ERROR)),
+    ("pydantic", (is_pydantic_available, PYDANTIC_IMPORT_ERROR)),
     ("sacremoses", (is_sacremoses_available, SACREMOSES_IMPORT_ERROR)),
     ("pytorch_quantization", (is_pytorch_quantization_available, PYTORCH_QUANTIZATION_IMPORT_ERROR)),
     ("sentencepiece", (is_sentencepiece_available, SENTENCEPIECE_IMPORT_ERROR)),

From 22aa9a7a828e4d249e374a18c437e656c8ddd6ad Mon Sep 17 00:00:00 2001
From: Manal ML
Date: Tue, 28 Jan 2025 15:34:59 +0000
Subject: [PATCH 2/4] Fix GenerationConfig inheriting classes

---
 .../models/bark/generation_configuration_bark.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/src/transformers/models/bark/generation_configuration_bark.py b/src/transformers/models/bark/generation_configuration_bark.py
index 036c9caa83ba..f928ec878523 100644
--- a/src/transformers/models/bark/generation_configuration_bark.py
+++ b/src/transformers/models/bark/generation_configuration_bark.py
@@ -15,7 +15,7 @@
 """BARK model generation configuration"""
 
 import copy
-from typing import Dict
+from typing import ClassVar, Dict
 
 from ...generation.configuration_utils import GenerationConfig
 from ...utils import logging
@@ -25,7 +25,7 @@
 
 class BarkSemanticGenerationConfig(GenerationConfig):
-    model_type = "semantic"
+    model_type: ClassVar[str] = "semantic"
 
     def __init__(
         self,
@@ -116,7 +116,7 @@ def __init__(
 
 class BarkCoarseGenerationConfig(GenerationConfig):
-    model_type = "coarse_acoustics"
+    model_type: ClassVar[str] = "coarse_acoustics"
 
     def __init__(
         self,
@@ -196,7 +196,7 @@ def __init__(
 
 class BarkFineGenerationConfig(GenerationConfig):
-    model_type = "fine_acoustics"
+    model_type: ClassVar[str] = "fine_acoustics"
 
     def __init__(
         self,
@@ -239,8 +239,8 @@ def validate(self, **kwargs):
 
 class BarkGenerationConfig(GenerationConfig):
-    model_type = "bark"
-    is_composition = True
+    model_type: ClassVar[str] = "bark"
+    is_composition: ClassVar[bool] = True
 
     # TODO (joao): nested from_dict

From d2721ac32f0a927939f475e98b4184020f0fab78 Mon Sep 17 00:00:00 2001
From: Manal ML
Date: Thu, 13 Feb 2025 18:48:08 +0100
Subject: [PATCH 3/4] Update types

---
 src/transformers/generation/configuration_utils.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py
index 6555e86fa91c..f6410201c337 100644
--- a/src/transformers/generation/configuration_utils.py
+++ b/src/transformers/generation/configuration_utils.py
@@ -20,7 +20,7 @@
 import warnings
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, is_dataclass
-from typing import TYPE_CHECKING, Any, Callable, ClassVar, Dict, List, Literal, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, Final, List, Literal, Optional, Tuple, Union
 
 from pydantic import BaseModel, ConfigDict, PrivateAttr, confloat, conint, model_validator
 
@@ -597,7 +597,7 @@ class GenerationConfig(BaseModel, PushToHubMixin):
         present in `generate`'s signature will be used in the model forward pass.
""" - extra_output_flags: ClassVar[Tuple[str]] = ( + extra_output_flags: Final[Tuple[str]] = ( "output_attentions", "output_hidden_states", "output_scores", @@ -619,7 +619,7 @@ class GenerationConfig(BaseModel, PushToHubMixin): num_beams: Optional[conint(ge=1)] = 1 num_beam_groups: Optional[conint(ge=1)] = 1 penalty_alpha: Optional[confloat(ge=0.0)] = None - dola_layers: Optional[Union[str, List[int]]] = None + dola_layers: Optional[Union[Literal["low", "high"], List[int]]] = None # Parameters that control the cache use_cache: bool = True From e8c5f17ecdd9c35593aa30df0dbb3b728aad37c4 Mon Sep 17 00:00:00 2001 From: Manal ML Date: Fri, 21 Feb 2025 05:17:36 +0000 Subject: [PATCH 4/4] Fix constraints --- .../generation/configuration_utils.py | 23 ++++++++++--------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index f6410201c337..efec6efbce3f 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -52,6 +52,7 @@ if is_torch_available(): from ..cache_utils import ( + CacheConfig, HQQQuantizedCache, HybridCache, MambaCache, @@ -606,7 +607,7 @@ class GenerationConfig(BaseModel, PushToHubMixin): model_config = ConfigDict(extra="allow", strict=True, revalidate_instances="subclass-instances") # Parameters that control the length of the output - max_length: Optional[conint(ge=0)] = 20 + max_length: Optional[conint(ge=1)] = 20 max_new_tokens: Optional[conint(gt=0)] = None min_length: Optional[conint(ge=0)] = 0 min_new_tokens: Optional[conint(ge=0)] = None @@ -624,7 +625,7 @@ class GenerationConfig(BaseModel, PushToHubMixin): # Parameters that control the cache use_cache: bool = True cache_implementation: Optional[str] = None - cache_config: Optional[Dict] = None + cache_config: Optional[Union[dict, "CacheConfig"]] = None if cache_implementation is not None and cache_implementation in CACHE_CONFIG_MAPPING: cache_config_class = CACHE_CONFIG_MAPPING[cache_implementation] if isinstance(cache_config, dict): @@ -632,16 +633,16 @@ class GenerationConfig(BaseModel, PushToHubMixin): return_legacy_cache: Optional[bool] = None # Parameters for manipulation of the model output logits - temperature: Optional[confloat(ge=0.0, le=2.0)] = 1.0 - top_k: Optional[conint(ge=0)] = 50 + temperature: Optional[confloat(ge=0.0)] = 1.0 + top_k: Optional[conint(ge=1)] = 50 top_p: Optional[confloat(ge=0.0, le=1.0)] = 1.0 min_p: Optional[confloat(ge=0.0, le=1.0)] = None typical_p: Optional[confloat(ge=0.0, le=1.0)] = 1.0 epsilon_cutoff: Optional[confloat(ge=0.0, le=1.0)] = 0.0 eta_cutoff: Optional[confloat(ge=0.0, le=1.0)] = 0.0 diversity_penalty: Optional[confloat(ge=0.0)] = 0.0 - repetition_penalty: Optional[confloat(ge=1.0)] = 1.0 - encoder_repetition_penalty: Optional[confloat(ge=1.0)] = 1.0 + repetition_penalty: Optional[confloat(ge=0.0)] = 1.0 + encoder_repetition_penalty: Optional[confloat(ge=0.0)] = 1.0 length_penalty: float = 1.0 no_repeat_ngram_size: Optional[conint(ge=0)] = 0 bad_words_ids: Optional[List[List[int]]] = None @@ -678,7 +679,7 @@ class GenerationConfig(BaseModel, PushToHubMixin): # Special tokens that can be used at generation time pad_token_id: Optional[int] = None bos_token_id: Optional[conint(ge=0)] = None - eos_token_id: Optional[Union[conint(ge=0), List[conint(ge=0)]]] = None + eos_token_id: Optional[Union[int, List[int]]] = None # Generation parameters exclusive to encoder-decoder models 
encoder_no_repeat_ngram_size: Optional[conint(ge=0)] = 0 @@ -689,12 +690,12 @@ class GenerationConfig(BaseModel, PushToHubMixin): num_assistant_tokens: Optional[conint(ge=1)] = 20 num_assistant_tokens_schedule: Optional[str] = "constant" assistant_confidence_threshold: Optional[confloat(ge=0.0, le=1.0)] = 0.4 - prompt_lookup_num_tokens: Optional[conint(ge=0)] = None - max_matching_ngram_size: Optional[conint(ge=0)] = None + prompt_lookup_num_tokens: Optional[conint(ge=1)] = None + max_matching_ngram_size: Optional[conint(ge=1)] = None assistant_early_exit: Optional[conint(ge=0)] = None ## assistant generation for different tokenizers, the windows size for assistant/target model - assistant_lookbehind: Optional[conint(ge=0)] = 10 - target_lookbehind: Optional[conint(ge=0)] = 10 + assistant_lookbehind: Optional[conint(ge=1)] = 10 + target_lookbehind: Optional[conint(ge=1)] = 10 # Performances compile_config: Optional["CompileConfig"] = CompileConfig()
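
A quick end-state sketch of the behavior this series targets (illustrative only; it assumes this branch is installed together with pydantic >= 2.0.0, and that out-of-range arguments now fail field validation with a `pydantic.ValidationError` rather than the `ValueError`s previously raised by `validate()`):

```python
from pydantic import ValidationError

from transformers import GenerationConfig

# Constrained fields are checked by pydantic as the model is built.
config = GenerationConfig(temperature=0.7, top_k=50, max_new_tokens=128)

try:
    GenerationConfig(max_length=0)  # rejected by conint(ge=1) after PATCH 4/4
except ValidationError as err:
    print(err)  # reports which field violated which constraint
```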