|
19 | 19 |
|
20 | 20 | import re
|
21 | 21 | from abc import ABCMeta
|
22 |
| -from typing import Any, Callable, Iterable, Iterator, Sequence, TypeVar, Union |
| 22 | +from inspect import getsource |
| 23 | +from typing import ( |
| 24 | + TYPE_CHECKING, |
| 25 | + Any, |
| 26 | + Callable, |
| 27 | + Iterable, |
| 28 | + Iterator, |
| 29 | + Sequence, |
| 30 | + TypeVar, |
| 31 | + Union, |
| 32 | +) |
| 33 | + |
| 34 | +from aenum import Enum |
| 35 | + |
| 36 | +# This is needed because for some reason pyright does not understand that Enum |
| 37 | +# is a class (probably because Enum does weird things to the Enum class) |
| 38 | +if TYPE_CHECKING: |
| 39 | + from enum import Enum |
23 | 40 |
|
24 | 41 | import numpy as np
|
25 | 42 | import numpy.typing as npt
|
@@ -215,3 +232,35 @@ def __init__(self, func: Callable[..., Any]):
|
215 | 232 |
|
216 | 233 | def __get__(self, instance: object, owner: object):
|
217 | 234 | return self.fget(owner)
|
| 235 | + |
| 236 | + |
| 237 | +def _get_doc(enum: type[Any], member: str): |
| 238 | + src = getsource(enum) |
| 239 | + member_pointer = src.find(member) |
| 240 | + docstr_start = member_pointer + src[member_pointer:].find('"""') + 3 |
| 241 | + docstr_end = docstr_start + src[docstr_start:].find('"""') |
| 242 | + return src[docstr_start:docstr_end] |
| 243 | + |
| 244 | + |
| 245 | +class MessageEnum(Enum): |
| 246 | + """Enum subclass allowing you to access the docstring of the members of your |
| 247 | + enum through the ``message`` property. |
| 248 | +
|
| 249 | + Example: |
| 250 | + >>> class A(MessageEnum): |
| 251 | + ... '''an enum''' |
| 252 | + ... VALUE1 = auto() |
| 253 | + ... '''member VALUE1''' |
| 254 | + ... VALUE2 = auto() |
| 255 | + ... '''member VALUE2''' |
| 256 | + >>> A.VALUE1.message |
| 257 | + 'member VALUE2' |
| 258 | +
|
| 259 | + """ |
| 260 | + |
| 261 | + message: str |
| 262 | + |
| 263 | + def __init__(self, *args: Any, **kwds: dict[str, Any]) -> None: |
| 264 | + super().__init__(*args, **kwds) |
| 265 | + for member in type(self).__members__: |
| 266 | + type(self).__members__[member].message = _get_doc(type(self), member) |
0 commit comments