|
1 |
| -import os |
2 |
| -import requests |
3 |
| -import yaml |
4 |
| -from functools import reduce |
5 |
| -from huggingface_hub import hf_hub_download |
6 |
| -from typing import Any, Optional, Union |
7 |
| - |
8 |
| - |
9 |
| -class LpotConfig: |
10 |
| - |
11 |
| - def __init__( |
12 |
| - self, |
13 |
| - config_path: str, |
14 |
| - save_path: Optional[str] = None, |
15 |
| - overwrite: Optional[bool] = False, |
16 |
| - ): |
17 |
| - """ |
18 |
| - Args: |
19 |
| - config_path (:obj:`str`): |
20 |
| - Path to the YAML configuration file used to control the tuning behavior. |
21 |
| - save_path (:obj:`str`, `optional`): |
22 |
| - Path used to save the configuration file. |
23 |
| - overwrite (:obj:`bool`, `optional`): |
24 |
| - Whether or not overwrite the configuration file when the latter is modified and saved. |
25 |
| - Returns: |
26 |
| - config: LpotConfig object. |
27 |
| - """ |
28 |
| - |
29 |
| - self.path = config_path |
30 |
| - self.config = self._read_config() |
31 |
| - self.save_path = save_path |
32 |
| - self.overwrite = overwrite |
33 |
| - |
34 |
| - def _read_config(self): |
35 |
| - with open(self.path, 'r') as f: |
36 |
| - try: |
37 |
| - config = yaml.safe_load(f) |
38 |
| - except yaml.YAMLError as exc: |
39 |
| - print(exc) |
40 |
| - return config |
41 |
| - |
42 |
| - def get_config(self, keys: str): |
43 |
| - return reduce(lambda d, key: d.get(key) if d else None, keys.split("."), self.config) |
44 |
| - |
45 |
| - def set_config(self, keys: str, value: Any): |
46 |
| - d = self.config |
47 |
| - keys = keys.split('.') |
48 |
| - for key in keys[:-1]: |
49 |
| - d = d.setdefault(key, {}) |
50 |
| - d[keys[-1]] = value |
51 |
| - self._save_pretrained() |
52 |
| - |
53 |
| - def _save_pretrained(self): |
54 |
| - if self.save_path is None and not self.overwrite: |
55 |
| - raise ValueError("Needs either path or overwrite set to True.") |
56 |
| - |
57 |
| - self.path = self.save_path if self.save_path is not None else self.path |
58 |
| - with open(self.path, "w") as f: |
59 |
| - yaml.dump(self.config, f) |
60 |
| - |
61 |
| - @classmethod |
62 |
| - def from_pretrained( |
63 |
| - cls, |
64 |
| - config_name_or_path: Union[str, os.PathLike], |
65 |
| - config_name: str, |
66 |
| - cache_dir: Optional[Union[str, os.PathLike]] = None, |
67 |
| - **config_kwargs |
68 |
| - ): |
69 |
| - """ |
70 |
| - Instantiate a LpotConfig object from a configuration file which can either be hosted on |
71 |
| - huggingface.co or from a local directory path. |
72 |
| -
|
73 |
| - Args: |
74 |
| - config_name_or_path (:obj:`Union[str, os.PathLike]`): |
75 |
| - Repository name in the Hub or path to a local directory containing the configuration file. |
76 |
| - config_name (:obj:`str`): |
77 |
| - Name of the configuration file. |
78 |
| - cache_dir (:obj:`Union[str, os.PathLike]`, `optional`): |
79 |
| - Path to a directory in which a downloaded configuration should be cached if the standard cache should |
80 |
| - not be used. |
81 |
| - config_kwargs (:obj:`Dict`, `optional`): |
82 |
| - config_kwargs will be passed to the LpotConfig object during initialization. |
83 |
| - Returns: |
84 |
| - config: LpotConfig object. |
85 |
| - """ |
86 |
| - |
87 |
| - revision = None |
88 |
| - if len(config_name_or_path.split("@")) == 2: |
89 |
| - config_name_or_path, revision = config_name_or_path.split("@") |
90 |
| - |
91 |
| - if os.path.isdir(config_name_or_path) and config_name in os.listdir(config_name_or_path): |
92 |
| - config_file = os.path.join(config_name_or_path, config_name) |
93 |
| - else: |
94 |
| - try: |
95 |
| - config_file = hf_hub_download( |
96 |
| - repo_id=config_name_or_path, |
97 |
| - filename=config_name, |
98 |
| - revision=revision, |
99 |
| - cache_dir=cache_dir, |
100 |
| - ) |
101 |
| - except requests.exceptions.RequestException: |
102 |
| - raise ValueError(f"{config_name} NOT FOUND in HuggingFace Hub") |
103 |
| - |
104 |
| - config = cls(config_file, **config_kwargs) |
105 |
| - return config |
106 |
| -
|
| 1 | +import os |
| 2 | +import requests |
| 3 | +import yaml |
| 4 | +from functools import reduce |
| 5 | +from huggingface_hub import hf_hub_download |
| 6 | +from typing import Any, Optional, Union |
| 7 | + |
| 8 | + |
| 9 | +class LpotConfig: |
| 10 | + |
| 11 | + def __init__( |
| 12 | + self, |
| 13 | + config_path: str, |
| 14 | + save_path: Optional[str] = None, |
| 15 | + overwrite: Optional[bool] = False, |
| 16 | + ): |
| 17 | + """ |
| 18 | + Args: |
| 19 | + config_path (:obj:`str`): |
| 20 | + Path to the YAML configuration file used to control the tuning behavior. |
| 21 | + save_path (:obj:`str`, `optional`): |
| 22 | + Path used to save the configuration file. |
| 23 | + overwrite (:obj:`bool`, `optional`): |
| 24 | + Whether or not overwrite the configuration file when the latter is modified and saved. |
| 25 | + Returns: |
| 26 | + config: LpotConfig object. |
| 27 | + """ |
| 28 | + |
| 29 | + self.path = config_path |
| 30 | + self.config = self._read_config() |
| 31 | + self.save_path = save_path |
| 32 | + self.overwrite = overwrite |
| 33 | + |
| 34 | + def _read_config(self): |
| 35 | + with open(self.path, 'r') as f: |
| 36 | + try: |
| 37 | + config = yaml.safe_load(f) |
| 38 | + except yaml.YAMLError as exc: |
| 39 | + print(exc) |
| 40 | + return config |
| 41 | + |
| 42 | + def get_config(self, keys: str): |
| 43 | + return reduce(lambda d, key: d.get(key) if d else None, keys.split("."), self.config) |
| 44 | + |
| 45 | + def set_config(self, keys: str, value: Any): |
| 46 | + d = self.config |
| 47 | + keys = keys.split('.') |
| 48 | + for key in keys[:-1]: |
| 49 | + d = d.setdefault(key, {}) |
| 50 | + d[keys[-1]] = value |
| 51 | + self._save_pretrained() |
| 52 | + |
| 53 | + def _save_pretrained(self): |
| 54 | + if self.save_path is None and not self.overwrite: |
| 55 | + raise ValueError("Needs either path or overwrite set to True.") |
| 56 | + |
| 57 | + self.path = self.save_path if self.save_path is not None else self.path |
| 58 | + with open(self.path, "w") as f: |
| 59 | + yaml.dump(self.config, f) |
| 60 | + |
| 61 | + @classmethod |
| 62 | + def from_pretrained( |
| 63 | + cls, |
| 64 | + config_name_or_path: Union[str, os.PathLike], |
| 65 | + config_name: str, |
| 66 | + cache_dir: Optional[Union[str, os.PathLike]] = None, |
| 67 | + **config_kwargs |
| 68 | + ): |
| 69 | + """ |
| 70 | + Instantiate a LpotConfig object from a configuration file which can either be hosted on |
| 71 | + huggingface.co or from a local directory path. |
| 72 | +
|
| 73 | + Args: |
| 74 | + config_name_or_path (:obj:`Union[str, os.PathLike]`): |
| 75 | + Repository name in the Hub or path to a local directory containing the configuration file. |
| 76 | + config_name (:obj:`str`): |
| 77 | + Name of the configuration file. |
| 78 | + cache_dir (:obj:`Union[str, os.PathLike]`, `optional`): |
| 79 | + Path to a directory in which a downloaded configuration should be cached if the standard cache should |
| 80 | + not be used. |
| 81 | + config_kwargs (:obj:`Dict`, `optional`): |
| 82 | + config_kwargs will be passed to the LpotConfig object during initialization. |
| 83 | + Returns: |
| 84 | + config: LpotConfig object. |
| 85 | + """ |
| 86 | + |
| 87 | + revision = None |
| 88 | + if len(config_name_or_path.split("@")) == 2: |
| 89 | + config_name_or_path, revision = config_name_or_path.split("@") |
| 90 | + |
| 91 | + if os.path.isdir(config_name_or_path) and config_name in os.listdir(config_name_or_path): |
| 92 | + config_file = os.path.join(config_name_or_path, config_name) |
| 93 | + else: |
| 94 | + try: |
| 95 | + config_file = hf_hub_download( |
| 96 | + repo_id=config_name_or_path, |
| 97 | + filename=config_name, |
| 98 | + revision=revision, |
| 99 | + cache_dir=cache_dir, |
| 100 | + ) |
| 101 | + except requests.exceptions.RequestException: |
| 102 | + raise ValueError(f"{config_name} NOT FOUND in HuggingFace Hub") |
| 103 | + |
| 104 | + config = cls(config_file, **config_kwargs) |
| 105 | + return config |
| 106 | + |
0 commit comments