
Decorators

Miscellaneous decorators used throughout the library.

Functions

typecheck(func_=None, **types)

Decorator to enforce type checking for a function or method. There are two ways to call this: either explicitly passing argument types to the decorator, or letting it infer them using type annotations in the function that will be decorated. We allow both usage methods since older versions of Python lack type annotations, and also because I feel the annotation syntax can hurt readability.

Ported from htools to avoid an extra dependency.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `func_` | function | The function to decorate. When the decorator is called with manually specified types, this is `None`. The trailing underscore lets `func` remain a valid keyword argument for the wrapped function. | `None` |
| `types` | type or tuple of types | Optional keyword arguments mapping parameter names to expected types. Use standard types rather than importing from the `typing` library, as subscripted generics are not supported (e.g. `typing.List[str]` will not work; `typing.List` will, but at that point there is no benefit over the standard `list`). | `{}` |
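The restriction on subscripted generics comes from `isinstance` itself, which `typecheck` uses for every check. A quick standard-library illustration (note that `list[str]` requires Python 3.9+):

```python
import typing

# Plain types, tuples of types, and unsubscripted typing aliases are all
# accepted as the second argument of isinstance().
assert isinstance(3.1, (int, float))
assert isinstance(['a'], typing.List)

# Subscripted generics are rejected by isinstance() with a TypeError, so a
# decorator built on isinstance() cannot check them.
rejected = []
for hint in (typing.List[str], list[str]):
    try:
        isinstance(['a'], hint)
    except TypeError as err:
        rejected.append(hint)
        print(f'{hint!r} -> TypeError: {err}')
```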

Examples:

In the first example, we specify types directly in the decorator. Notice that they can be single types or tuples of types. You can choose to specify types for all arguments or just a subset.

```
@typecheck(x=float, y=(int, float), iters=int, verbose=bool)
def process(x, y, z, iters=5, verbose=True):
    print(f'z = {z}')
    for i in range(iters):
        if verbose: print(f'Iteration {i}...')
        x *= y
    return x
```

```
>>> process(3.1, 4.5, 0, 2.0)
TypeError: iters must be <class 'int'>, not <class 'float'>.
>>> process(3.1, 4, 'a', 1, False)
z = a
12.4
```
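Under the hood, the wrapper binds the call's arguments to the function's signature with `inspect.Signature.bind` and runs `isinstance` on each value it has a type for. The standalone sketch below mirrors that check (the helper `find_type_errors` is hypothetical, not part of roboduck) and highlights two consequences: `BoundArguments.arguments` only holds arguments the caller actually passed, so default values are never checked, and because `bool` is a subclass of `int`, `True` satisfies an `int` requirement.

```python
from inspect import signature


def process(x, y, z, iters=5, verbose=True):
    return x


TYPES = {'x': float, 'y': (int, float), 'iters': int, 'verbose': bool}


def find_type_errors(func, types, *args, **kwargs):
    """Hypothetical helper: list the type violations typecheck would report."""
    # Only explicitly passed arguments appear here; defaults are not applied.
    bound = signature(func).bind(*args, **kwargs).arguments
    return [f'{name} must be {expected}, not {type(bound[name])}.'
            for name, expected in types.items()
            if name in bound and not isinstance(bound[name], expected)]


print(find_type_errors(process, TYPES, 3.1, 4.5, 0, 2.0))
# ["iters must be <class 'int'>, not <class 'float'>."]
print(find_type_errors(process, TYPES, 3.1, 4, 'a', True))  # [] (bool is an int)
print(find_type_errors(process, TYPES, 3.1, 4, 'a'))        # [] (defaults unchecked)
```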

Alternatively, you can let the decorator infer types using annotations in the function that is to be decorated. The example below behaves equivalently to the explicit example shown above. Note that annotations regarding the returned value are ignored.

```
@typecheck
def process(x: float, y: (int, float), z, iters: int = 5, verbose: bool = True):
    print(f'z = {z}')
    for i in range(iters):
        if verbose: print(f'Iteration {i}...')
        x *= y
    return x
```

```
>>> process(3.1, 4.5, 0, 2.0)
TypeError: iters must be <class 'int'>, not <class 'float'>.
>>> process(3.1, 4, 'a', 1, False)
z = a
12.4
```
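When no types are passed, `typecheck` builds the same mapping from the parameters' annotations, skipping any parameter whose annotation is `inspect.Parameter.empty`. Return annotations never appear in `Signature.parameters`, which is why they are ignored. A minimal, equivalent reproduction of that inference step:

```python
from inspect import Parameter, signature


def process(x: float, y: (int, float), z, iters: int = 5,
            verbose: bool = True) -> float:
    return x


# Keep only annotated parameters; the return annotation is not a parameter.
inferred = {name: param.annotation
            for name, param in signature(process).parameters.items()
            if param.annotation is not Parameter.empty}
print(inferred)
```

One caveat: in a module that uses `from __future__ import annotations`, annotations are stored as strings, so this mapping would contain values like `'float'`, which `isinstance` rejects with a `TypeError`.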
Source code in lib/roboduck/decorators.py
def typecheck(func_=None, **types):
    """Decorator to enforce type checking for a function or method. There are
    two ways to call this: either explicitly passing argument types to the
    decorator, or letting it infer them using type annotations in the function
    that will be decorated. We allow both usage methods since older
    versions of Python lack type annotations, and also because I feel the
    annotation syntax can hurt readability.

    Ported from [htools](https://github.com/hdmamin/htools) to avoid extra
    dependency.

    Parameters
    ----------
    func_ : function
        The function to decorate. When using decorator with
        manually-specified types, this is None. Underscore is used so that
        `func` can still be used as a valid keyword argument for the wrapped
        function.
    types : type
        Optional way to specify variable types. Use standard types rather than
        importing from the typing library, as subscripted generics are not
        supported (e.g. typing.List[str] will not work; typing.List will but at
        that point there is no benefit over the standard `list`).

    Examples
    --------
    In the first example, we specify types directly in the decorator. Notice
    that they can be single types or tuples of types. You can choose to
    specify types for all arguments or just a subset.

    ```
    @typecheck(x=float, y=(int, float), iters=int, verbose=bool)
    def process(x, y, z, iters=5, verbose=True):
        print(f'z = {z}')
        for i in range(iters):
            if verbose: print(f'Iteration {i}...')
            x *= y
        return x
    ```

    >>> process(3.1, 4.5, 0, 2.0)
    TypeError: iters must be <class 'int'>, not <class 'float'>.

    >>> process(3.1, 4, 'a', 1, False)
    z = a
    12.4

    Alternatively, you can let the decorator infer types using annotations
    in the function that is to be decorated. The example below behaves
    equivalently to the explicit example shown above. Note that annotations
    regarding the returned value are ignored.

    ```
    @typecheck
    def process(x:float, y:(int, float), z, iters:int=5, verbose:bool=True):
        print(f'z = {z}')
        for i in range(iters):
            if verbose: print(f'Iteration {i}...')
            x *= y
        return x
    ```

    >>> process(3.1, 4.5, 0, 2.0)
    TypeError: iters must be <class 'int'>, not <class 'float'>.

    >>> process(3.1, 4, 'a', 1, False)
    z = a
    12.4
    """
    # Case 1: Pass keyword args to decorator specifying types.
    if not func_:
        return partial(typecheck, **types)
    # Case 2: Infer types from annotations. Skip if Case 1 already occurred.
    elif not types:
        types = {k: v.annotation
                 for k, v in signature(func_).parameters.items()
                 if not v.annotation == Parameter.empty}

    @wraps(func_)
    def wrapper(*args, **kwargs):
        sig = signature(wrapper)
        try:
            fargs = sig.bind(*args, **kwargs).arguments
        except TypeError as e:
            # Default error message is not very helpful if we don't handle this
            # case separately.
            expected_positional = [name for name, p in sig.parameters.items()
                                   if 'positional' in str(p.kind).lower()]
            if args and not expected_positional:
                raise TypeError(
                    'Received positional arg(s) but expected none. Expected '
                    f'arguments: {list(sig.parameters)}'
                )
            else:
                raise e
        for k, v in types.items():
            if k in fargs and not isinstance(fargs[k], v):
                raise TypeError(
                    f'{k} must be {str(v)}, not {type(fargs[k])}.'
                )
        return func_(*args, **kwargs)
    return wrapper

add_kwargs(func, fields, hide_fields=(), strict=False)

Decorator that adds parameters into the signature and docstring of a function that accepts **kwargs.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `func` | function | Function to decorate. | required |
| `fields` | list[str] | Names of params to insert into the signature and docstring. | required |
| `hide_fields` | list[str] | Names of params already in the function's signature that we want to hide. To use a non-empty value here, you must set `strict=True` and the param must have a default value, as this is what will be used in all subsequent calls. | `()` |
| `strict` | bool | If True, we do two things: (1) when the decorated function is called, check that the user provided all expected arguments; (2) enable the use of the `hide_fields` param. | `False` |

Returns:

| Type | Description |
|------|-------------|
| function | The wrapped function with the updated signature. |
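The signature rewrite relies on `inspect.signature` honoring a function's `__signature__` attribute. The sketch below applies the same steps to a toy function (the function `greet` and the field names `tone` and `lang` are hypothetical): drop `**kwargs`, append keyword-only parameters, and assign the result to `__signature__`. It also suggests why `strict=True` exists: the new signature is metadata only, so the wrapper still accepts a call that omits the injected fields, while binding against the new signature (which is what `typecheck` does) rejects it.

```python
from functools import wraps
from inspect import Parameter, signature


def greet(name, **kwargs):
    return f'Hello {name}: {kwargs}'


@wraps(greet)
def wrapper(*args, **kwargs):
    return greet(*args, **kwargs)


sig = signature(wrapper)  # resolves to greet's signature via __wrapped__
params = dict(sig.parameters)
params.pop('kwargs')  # remove the **kwargs catch-all
params.update({field: Parameter(field, Parameter.KEYWORD_ONLY)
               for field in ('tone', 'lang')})
wrapper.__signature__ = sig.replace(parameters=list(params.values()))

print(signature(wrapper))  # (name, *, tone, lang)
print(wrapper('Ann'))      # Hello Ann: {}  (nothing enforces the new params)
try:
    signature(wrapper).bind('Ann')
except TypeError as err:
    bind_error = str(err)
    print(bind_error)      # names the missing 'tone' argument
```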
Source code in lib/roboduck/decorators.py
def add_kwargs(func, fields, hide_fields=(), strict=False):
    """Decorator that adds parameters into the signature and docstring of a
    function that accepts **kwargs.

    Parameters
    ----------
    func : function
        Function to decorate.
    fields : list[str]
        Names of params to insert into signature + docstring.
    hide_fields : list[str]
        Names of params that are *already* in the function's signature that
        we want to hide. To use a non-empty value here, we must set strict=True
        and the param must have a default value, as this is what will be used
        in all subsequent calls.
    strict : bool
        If True, we do two things:
        1. On decorated function call, check that the user provided all
        expected arguments.
        2. Enable the use of the `hide_fields` param.

    Returns
    -------
    function
    """
    # Hide_fields must have default values in existing function. They will not
    # show up in the new docstring and the user will not be able to pass in a
    # value when calling the new function - it will always use the default.
    # To set different defaults, you can pass in a partial rather than a
    # function as the first arg here.
    @wraps(func)
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)

    if hide_fields and not strict:
        raise ValueError(
            'You must set strict=True when providing one or more '
            'hide_fields. Otherwise the user can still pass in those args.'
        )
    sig = signature(wrapper)
    params_ = {k: v for k, v in sig.parameters.items()}

    # Remove any fields we want to hide.
    for field in hide_fields:
        if field not in params_:
            warnings.warn(f'No need to hide field {field} because it\'s not '
                          'in the existing function signature.')
        elif params_.pop(field).default == Parameter.empty:
            raise TypeError(
                f'Field "{field}" is not a valid hide_field because it has '
                'no default value in the original function.'
            )

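    # Default of None avoids an AttributeError when there is no `kwargs` param.
    kwargs_param = params_.pop('kwargs', None)
    if getattr(kwargs_param, 'kind', None) != Parameter.VAR_KEYWORD: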
        raise TypeError(f'Function {func} must accept **kwargs.')
    new_params = {
        field: Parameter(field, Parameter.KEYWORD_ONLY)
        for field in fields
    }
    overlap = set(new_params) & set(params_)
    if overlap:
        raise RuntimeError(
            f'Some of the kwargs you tried to inject into {func} already '
            'exist in its signature. This is not allowed because it\'s '
            'unclear how to resolve default values and parameter type.'
        )

    params_.update(new_params)
    wrapper.__signature__ = sig.replace(parameters=params_.values())
    if strict:
        # In practice langchain checks for this anyway if we ask for a
        # completion, but outside of that context we need typecheck
        # because otherwise we could provide no kwargs and _func wouldn't
        # complain. Just use generic type because we only care that a value is
        # provided.
        wrapper = typecheck(wrapper, **{f: object for f in fields})
    return wrapper

store_class_defaults(cls=None, attr_filter=None)

Class decorator that stores default values of class attributes (can be all or a subset). Default here refers to the value at class definition time. Mutable defaults should be okay since we deepcopy them, but are probably still riskier to use than immutable defaults.

Examples:

```
@store_class_defaults(attr_filter=lambda x: x.startswith('last_'))
class Foo:
    last_bar = 3
    last_baz = 'abc'
    other = True
```

```
>>> Foo._class_defaults
{'last_bar': 3, 'last_baz': 'abc'}
```

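Alternatively, use the decorator without parentheses to store the values of all class attributes (except dunder names) at definition time, although this is usually unnecessary. If you do provide an `attr_filter`, it must be passed as a keyword argument.

The decorator also adds a `reset_class_vars` classmethod, so `Foo.reset_class_vars()` resets every stored class attribute to its default value.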
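The snapshot-and-reset pattern can be sketched in a few lines. The `store_defaults` decorator below is a simplified, hypothetical stand-in for `store_class_defaults` (no `attr_filter`, no name-collision checks). One design point worth noting: it deep-copies both when taking the snapshot and again on every reset, so in-place edits to a mutable attribute after a reset cannot leak back into the stored default.

```python
from copy import deepcopy


def store_defaults(target):
    """Hypothetical, simplified stand-in for store_class_defaults."""
    # Snapshot every non-dunder class attribute at decoration time.
    target._class_defaults = {
        name: deepcopy(value) for name, value in vars(target).items()
        if not (name.startswith('__') and name.endswith('__'))
    }

    @classmethod
    def reset_class_vars(cls):
        for name, value in cls._class_defaults.items():
            # Copy again so later in-place edits can't corrupt the snapshot.
            setattr(cls, name, deepcopy(value))

    target.reset_class_vars = reset_class_vars
    return target


@store_defaults
class Settings:
    retries = 3
    history = []


Settings.retries = 10
Settings.history.append('run 1')
Settings.reset_class_vars()
print(Settings.retries, Settings.history)  # 3 []
```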

Source code in lib/roboduck/decorators.py
def store_class_defaults(cls=None, attr_filter=None):
    """Class decorator that stores default values of class attributes (can be
    all or a subset). Default here refers to the value at class definition
    time. Mutable defaults should be okay since we deepcopy them, but are
    probably still riskier to use than immutable defaults.

    Examples
    --------
    ```
    @store_class_defaults(attr_filter=lambda x: x.startswith('last_'))
    class Foo:
        last_bar = 3
        last_baz = 'abc'
        other = True
    ```

    >>> Foo._class_defaults
    {'last_bar': 3, 'last_baz': 'abc'}

    Or use the decorator without parentheses to store all values at definition
    time. This is usually unnecessary. If you do provide an attr_filter, it
    must be a named argument.

    Foo.reset_class_vars() will reset all relevant class vars to their
    default values.
    """
    if cls is None:
        return partial(store_class_defaults, attr_filter=attr_filter)
    if not isinstance(cls, type):
        raise TypeError(
            f'cls arg in store_class_defaults decorator has type {type(cls)} '
            f'but expected type `type`, i.e. a class. You may be passing in '
            f'an attr_filter as a positional arg which is not allowed - it '
            f'must be a named arg if provided.'
        )
    if not attr_filter:
        def attr_filter(x):
            # Usually returns True, just False for some magic methods that are
            # not easily filtered out otherwise.
            return not (x.startswith('__') and x.endswith('__'))
    defaults = {}
    for k, v in _classvars(cls).items():
        if attr_filter(k):
            defaults[k] = deepcopy(v)

    name = '_class_defaults'
    if hasattr(cls, name):
        raise AttributeError(
            f'Class {cls} already has attribute {name}. store_class_defaults '
            'decorator would overwrite that. Exiting.'
        )
    setattr(cls, name, defaults)

    @classmethod
    def reset_class_vars(cls):
        """Reset all default class attributes to their defaults."""
        for k, v in cls._class_defaults.items():
            try:
                # Copy again so in-place edits after a reset can't alter the
                # stored default.
                setattr(cls, k, deepcopy(v))
            except Exception as e:
                warnings.warn(f'Could not reset class attribute {k} to its '
                              f'default value:\n\n{e}')

    meth_name = 'reset_class_vars'
    if hasattr(cls, meth_name):
        raise AttributeError(
            f'Class {cls} already has attribute {meth_name}. '
            f'store_class_defaults decorator would overwrite that. Exiting.'
        )
    setattr(cls, meth_name, reset_class_vars)
    return cls

add_docstring(func)

Add the docstring from another function/class to the decorated function/class.

Ported from htools to avoid an extra dependency.

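Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `func` | function or class | Object whose docstring is appended to the decorated object's docstring. | required |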

Examples:

```
@add_docstring(nn.Conv2d)
class ReflectionPaddedConv2d(nn.Module):
    ...  # class body omitted
```
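A minimal sketch of the same idea (the `append_docstring` name and the toy `Base`, `Child`, and `helper` objects are hypothetical). Two choices are worth calling out: missing docstrings are skipped so the result never contains the literal text `None`, and the decorated object itself is returned instead of a function wrapper, so a decorated class stays a class and `isinstance` checks and subclassing keep working.

```python
def append_docstring(source):
    """Hypothetical sketch: append source's docstring to the decorated object."""
    def decorator(target):
        # Drop empty/None docstrings so the joined text never says "None".
        docs = [doc for doc in (target.__doc__, source.__doc__) if doc]
        target.__doc__ = '\n\n'.join(docs)
        return target  # returning the object itself keeps a class a class
    return decorator


class Base:
    """Base documentation."""


@append_docstring(Base)
class Child(Base):
    """Child documentation."""


@append_docstring(Base)
def helper():
    return 42


print(repr(Child.__doc__))   # 'Child documentation.\n\nBase documentation.'
print(repr(helper.__doc__))  # 'Base documentation.'
```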
Source code in lib/roboduck/decorators.py
def add_docstring(func):
    """Add the docstring from another function/class to the decorated
    function/class.

    Ported from [htools](https://github.com/hdmamin/htools) to avoid extra
    dependency.

    Parameters
    ----------
    func : function or class
        Object whose docstring will be appended to the decorated object's
        docstring.

    Examples
    --------
    ```
    @add_docstring(nn.Conv2d)
    class ReflectionPaddedConv2d(nn.Module):
        ...  # class body omitted
    ```
    """
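    def decorator(new_func):
        # Skip missing docstrings so the result never contains the text "None".
        docs = [doc for doc in (new_func.__doc__, func.__doc__) if doc]
        new_func.__doc__ = '\n\n'.join(docs)
        # Return the object itself rather than a wrapper function so that
        # decorated classes remain classes (isinstance/subclassing still work).
        return new_func
    return decorator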