tiny_diff.modules.attention

Classes

AttnFlatten([start_dim, end_dim])

Flattens a 4D BCHW tensor for visual attention.
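
As an illustration of the shape arithmetic, the sketch below shows how a BCHW shape changes when a range of dimensions is merged. It is not the class's implementation: `flattened_shape` is a hypothetical helper, and it assumes `start_dim` and `end_dim` follow the usual flatten convention (an inclusive range, negative indices allowed). The defaults of `AttnFlatten` are not stated here, so the indices are passed explicitly.

```python
from math import prod


def flattened_shape(shape, start_dim, end_dim):
    """Return the shape after merging dims start_dim..end_dim (inclusive).

    Hypothetical helper for illustration only; assumes the common
    flatten convention with negative indices counted from the end.
    """
    ndim = len(shape)
    for dim in (start_dim, end_dim):
        if not -ndim <= dim < ndim:
            raise IndexError(f"dimension {dim} out of range for {ndim}D shape")
    start = start_dim % ndim
    end = end_dim % ndim
    if start > end:
        raise ValueError("start_dim must not come after end_dim")
    merged = prod(shape[start:end + 1])
    return tuple(shape[:start]) + (merged,) + tuple(shape[end + 1:])


# A batch of 2 feature maps with 64 channels at 16x16 resolution (BCHW).
bchw = (2, 64, 16, 16)

# Merging the spatial dims H and W gives a single axis of H * W positions.
tokens_shape = flattened_shape(bchw, 2, 3)
```

Merging H and W this way turns each spatial position into one entry along a single axis, which is the layout attention layers typically operate on.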

CrossVisualAttention(context_dim, **kwargs)

Cross Visual Attention.

SelfVisualAttention(*args, **kwargs)

Self Visual Attention.

VisualAttention(channels[, kv_channels, ...])

Multi-head attention for visual transformers.