tiny_diff.modules.layer_factory.LayerFactory

class tiny_diff.modules.layer_factory.LayerFactory(cls: type[torch.nn.modules.module.Module], args: list[Any] | None = None, kwargs: dict[str, Any] | None = None)

Bases: object

Factory class for nonlinearity layers.

Parameters:
  • cls – nonlinearity class to instantiate.

  • args – positional arguments for cls.

  • kwargs – keyword arguments for cls.

__init__(cls: type[torch.nn.modules.module.Module], args: list[Any] | None = None, kwargs: dict[str, Any] | None = None) → None

Methods

__init__(cls[, args, kwargs])

layer(**kwargs)

Return the instantiated object.
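
Example

The following is a minimal sketch of the pattern this page describes, not the tiny_diff implementation. SketchLayerFactory is a hypothetical stand-in that mirrors the documented signature, and DummyActivation stands in for a torch.nn.Module subclass such as torch.nn.LeakyReLU so that the snippet runs without PyTorch. How layer(**kwargs) combines its keyword arguments with the stored ones is not stated above; the sketch assumes call-time values override stored values.

```python
from __future__ import annotations

from typing import Any


class SketchLayerFactory:
    """Hypothetical stand-in mirroring the documented signature."""

    def __init__(
        self,
        cls: type,
        args: list[Any] | None = None,
        kwargs: dict[str, Any] | None = None,
    ) -> None:
        self.cls = cls
        # Copy the containers so later changes by the caller have no effect.
        self.args = list(args) if args is not None else []
        self.kwargs = dict(kwargs) if kwargs is not None else {}

    def layer(self, **kwargs: Any) -> Any:
        # Assumption: keyword arguments given here override the stored ones.
        merged = {**self.kwargs, **kwargs}
        return self.cls(*self.args, **merged)


class DummyActivation:
    """Stand-in for a torch.nn.Module subclass (for illustration only)."""

    def __init__(self, negative_slope: float = 0.01, inplace: bool = False) -> None:
        self.negative_slope = negative_slope
        self.inplace = inplace


factory = SketchLayerFactory(DummyActivation, args=[0.2], kwargs={"inplace": False})
first = factory.layer()               # uses the stored arguments
second = factory.layer(inplace=True)  # overrides one stored keyword argument
```

Each call to layer() builds a new object, so one factory can be handed to several blocks of a network without them sharing a single module instance.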