GRU

gated recurrent unit.

重置门(reset gate) 更新门(update gate) $\in(0,1)$

广播机制: 对于不同形状的张量, 通过复制扩展使得维度相同.

  • 输入 $\mathbf{X}_{t}\in\mathbb{R}^{n\times d}$, 即 batch size(批量维度) 为 $n$, feature dimension(特征维度) 为 $d$

  • 隐状态 $\mathbf{H}_{t-1}\in\mathbb{R}^{n\times h}$, 即 batch size 为 $n$, 隐藏单元数 为 $h$

重置门 $\mathbf{R}_{t}\in\mathbb{R}^{n\times h}$ 和更新门 $\mathbf{Z}_{t}\in\mathbb{R}^{n\times h}$ 的计算公式: $$ \begin{aligned} \overset{n\times h}{\mathbf{R}_{t}} &= \sigma(\overset{n\times d}{\mathbf{X}_{t}}\cdot\overset{d\times h}{\mathbf{W}_{xr}} + \overset{n\times h}{\mathbf{H}_{t-1}}\cdot\overset{h\times h}{\mathbf{W}_{hr}} + \overset{1\times h}{\mathbf{b}_{r}}),\\ \overset{n\times h}{\mathbf{Z}_{t}} &= \sigma(\overset{n\times d}{\mathbf{X}_{t}}\cdot\overset{d\times h}{\mathbf{W}_{xz}} + \overset{n\times h}{\mathbf{H}_{t-1}}\cdot\overset{h\times h}{\mathbf{W}_{hz}} + \overset{1\times h}{\mathbf{b}_{z}})\\ \sigma(x) &= \frac{1}{1+e^{-x}}, \quad \text{sigmoid function} \end{aligned} $$

其中, $\mathbf{b}$ 会通过广播扩展为 $n\times h$ 的矩阵.

常规的隐状态更新: $$ \mathbf{H}_{t} = \phi(\mathbf{X}_{t}\cdot\mathbf{W}_{xh} + \mathbf{H}_{t-1}\cdot \mathbf{W}_{hh} + \mathbf{b}_{h}) $$

候选隐状态 $\widetilde{\mathbf{H}}_{t}\in\mathbb{R}^{n\times h}$ 计算:

$$ \overset{n\times h}{\widetilde{\mathbf{H}}_{t}} = \tanh{(\mathbf{X}_{t}\mathbf{W}_{xh} + (\mathbf{R}_{t}\odot \mathbf{H}_{t-1})\cdot \mathbf{W}_{hh} + \mathbf{b}_{h})} $$

Hadamard 积 $\odot$:

$$ \begin{bmatrix} a_{11} & {\color{red}{a_{12}}} & a_{13}\\ {\color{blue}{a_{21}}} & a_{22} & a_{23}\\ a_{31} & a_{32} & a_{33} \end{bmatrix}\odot\begin{bmatrix} b_{11} & {\color{red}{b_{12}}} & b_{13}\\ {\color{blue}{b_{21}}} & b_{22} & b_{23}\\ b_{31} & b_{32} & b_{33} \end{bmatrix} = \begin{bmatrix} a_{11}b_{11} & {\color{red}{a_{12}b_{12}}} & a_{13}b_{13}\\ {\color{blue}{a_{21}b_{21}}} & a_{22}b_{22} & a_{23}b_{23}\\ a_{31}b_{31} & a_{32}b_{32} & a_{33}b_{33} \end{bmatrix} $$

隐状态 $\mathbf{H}_{t}$ 计算:

$$ \overset{n\times h}{\mathbf{H}_{t}} = \overset{n\times h}{\mathbf{Z}_{t}} \odot \overset{n\times h}{\mathbf{H}_{t-1}} + (1-\overset{n\times h}{\mathbf{Z}_{t}})\odot \overset{n\times h}{\widetilde{\mathbf{H}}_{t}} $$

注意力机制

  • 查询(Query): 自主性提示
  • 键(Key): 非自主性提示
  • 值(Value): 感官输入

对于数据集 $\{(x_{i},y_{i})\}$, 如何学习其规律 $f$ 以预测任意输入 $x$ 的输出 $\hat{y} = f(x)$?

Nadaraya-Watson 核:

$$ f(x) = \sum_{i=1}^{n}\frac{K(x-x_{i})}{\sum_{j=1}^{n}K(x-x_{j})}y_{i} $$

或写作 注意力汇聚 形式:

$$ f(x) = \sum_{i=1}^{n}\alpha(x,x_{i})y_{i} $$

$\alpha(x,x_{i})$ 被称作 注意力权重, 那么 $f(x)$ 是对所有 $y_{i}$ 的加权平均.

比如考虑 Gaussian 核 $\begin{aligned}K(u) = \frac{1}{\sqrt{2\pi}}\exp{(-\frac{u^2}{2})}\end{aligned}$, 学习的 $f$:

$$ \begin{aligned} f(x) &= \sum_{i=1}^{n}\alpha(x,x_{i})y_{i}\\ &= \sum_{i=1}^{n}\frac{\exp{[-(x-x_{i})^{2}/2}]}{\sum_{j=1}^{n}\exp{[-(x-x_{j})^{2}/2]}}y_{i}\\ &= \sum_{i=1}\text{softmax}\left(-\frac{1}{2}(x-x_{i})^{2}\right)y_{i} \end{aligned} $$

也可引入参数 $w$ 调整核的宽度:

$$ \begin{aligned} f(x) &= \sum_{i=1}^{n}\alpha(x,x_{i})y_{i}\\ &= \sum_{i=1}^{n}\frac{\exp{[-w^{2}(x-x_{i})^{2}/2}]}{\sum_{j=1}^{n}\exp{[-w^{2}(x-x_{j})^{2}/2]}}y_{i}\\ &= \sum_{i=1}\text{softmax}\left(-\frac{1}{2}w^{2}(x-x_{i})^{2}\right)y_{i} \end{aligned} $$

Gaussian 核指数部分被称作 attention scoring function.


假设有 $m$ 个 key-value 对 $\{(\mathbf{k}_{i},\mathbf{v}_{i})\}$, 其中 $\mathbf{k}_{i}\in\mathbb{R}^{k}$, $\mathbf{v}_{i}\in\mathbb{R}^{v}$.

给定查询 $\mathbf{q}\in\mathbb{R}^{q}$, 则注意力汇聚函数为

$$ \begin{aligned} f(\mathbf{q},(\mathbf{k}_{1},\mathbf{v}_{1}),\cdots,(\mathbf{k}_{m},\mathbf{v}_{m})) &= \sum_{i=1}^{m}\alpha(\mathbf{q},\mathbf{k_{i}})\mathbf{v}_{i} \in \mathbb{R}^{v}\\ \alpha(\mathbf{q},\mathbf{k}_{i}) = \text{softmax}[a(\mathbf{q},\mathbf{k}_{i})] &= \frac{\exp{[a(\mathbf{q},\mathbf{k}_{i})]}}{\sum_{j=1}^{m}\exp{[a(\mathbf{q},\mathbf{k}_{j})]}}\in\mathbb{R} \end{aligned} $$

$a$ 即为注意力评分函数.


Masked softmax operation(掩码 softmax 操作): 仅将有意义的键值对纳入到计算中.

比如在语句处理中, 指定最长序列长度, 从而在计算 softmax 时过滤超过范围的部分(超出指定长度的位置均设为 $0$).


加性注意力: Query 和 Key 长度不同($q\neq k$). 注意力评分函数定义为

$$ a(\mathbf{q},\mathbf{k}) = \overset{1\times h}{\mathbf{w}_{v}^{\top}}\tanh{(\overset{h\times q}{\mathbf{W}_{q}}\cdot\overset{q\times 1}{\mathbf{q}} + \overset{h\times k}{\mathbf{W}_{k}}\cdot\overset{k\times 1}{\mathbf{k}})} $$


缩放点积注意力: Query 和 Key 长度都为 $d$.

若 Q 和 K 的元素均为独立随机变量, 且都是零均值, 单位方差. 那么 Q 与 K 的点积将是零均值, 方差为 $d$. 为了方差仍未 1, 将点积归一化, 则形成 scaled dot-product attention:

$$ a(\mathbf{q},\mathbf{k}) = \frac{\mathbf{q}^{\top}\mathbf{k}}{\sqrt{d}} $$

在批量处理时, 则学习到的注意力汇聚函数为

$$ \text{softmax}\left(\frac{\overset{n\times d}{\mathbf{Q}}\cdot\overset{d\times m}{\mathbf{K}^{\top}}}{\sqrt{d}}\right)\overset{m\times v}{\mathbf{V}} $$


自注意力: Value = Key. 其注意力汇聚函数会输出等长序列:

$$ \mathbf{y}_{i} = f(\mathbf{x}_{i},(\mathbf{x}_{1},\mathbf{x}_{1}),\cdots, (\mathbf{x}_{n},\mathbf{x}_{n}))\in \mathbb{R}^{d} $$


位置编码: 自注意力是并行计算的, 因此对顺序是无知的.

$$ a_{m,n} = \frac{\exp{\left(\frac{q_{m}^{\top}k_{n}}{\sqrt{d}}\right)}}{\sum_{j=1}^{N}\exp{\left(\frac{q_{m}^{\top}k_{j}}{\sqrt{d}}\right)}} $$

如果任意地打乱 $x$ 顺序, 对应的 $q$ 和 $k$ 也会对应地打乱, 但是 $q^{\top}k$ 是仍然不变的, 这就是对 token 顺序和间距不敏感的表现.

因此, 我们希望对词语引入有关位置的信息, 即 $q_{m}\rightarrow f(q,m)$, $k_{n}\rightarrow f(k,n)$, 于是注意力评分函数变为

$$ a_{m,n} = \frac{\exp{\left(\frac{f(q,m)^{\top}f(k,n)}{\sqrt{d}}\right)}}{\sum_{j=1}^{N}\exp{\left(\frac{f(q,m)^{\top}f(k,j)}{\sqrt{d}}\right)}} $$

  1. 基于三角函数的 固定位置编码

若有一个输入为 $\mathbf{X}\in \mathbb{R}^{n\times d}$, 其含有 $n$ 个词元, 每个词元被嵌入表达为一个 $d$ 维的向量. 那么定义位置嵌入矩阵 $\mathbf{P}\in\mathbb{R}^{n\times d}$:

$$ \begin{aligned} p_{i,2j} &= \sin{\left(\frac{i}{10000^{2j/d}}\right)}\\ p_{i,2j+1} &= \cos{\left(\frac{i}{10000^{2j/d}}\right)} \end{aligned} $$

然后将输出 $\mathbf{X} + \mathbf{P}$ 作为后续计算的输入.

  1. 旋转位置编码 RoPE(Rotary Position Embedding): 将一个向量旋转一个和位置有关的角度

RoPE 希望将 $q_{m}^{\top}k_{n}$ 携带上相对位置 $(m-n)$ 的信息, 即寻找函数 $f$ 使得

$$ f(q,m)\cdot f(k,n) = g(q,k,m-n) $$

而旋转矩阵满足这个性质.

假设 $q\in\mathbb{R}^{2}$, 则

$$ f(q,m) = R_{m}q = \begin{bmatrix} \cos{m\theta} & -\sin{m\theta}\\ \sin{m\theta} & \cos{m\theta} \end{bmatrix}\begin{bmatrix} q_{1}\\ q_{2} \end{bmatrix} $$

注意到

$$ \begin{aligned} q_{m}^{\top}k_{n} &= f(q,m)^{\top}f(k,n) = (R_{m}q)^{\top}\cdot (R_{n}k) = q^{T}R_{m}^{\top}R_{n}k\\ &= q^{\top}\begin{bmatrix} \cos{m\theta} & -\sin{m\theta}\\ \sin{m\theta} & \cos{m\theta} \end{bmatrix}^{\top}\begin{bmatrix} \cos{n\theta} & -\sin{n\theta}\\ \sin{n\theta} & \cos{n\theta} \end{bmatrix}k\\ &= q^{\top}\begin{bmatrix} \cos{(n-m)\theta} & -\sin{(n-m)\theta}\\ \sin{(n-m)\theta} & \cos{(n-m)\theta} \end{bmatrix}k\\ &= q^{\top}R_{n-m}k \end{aligned} $$

若要推广至更高维度, 可将 $d$ 维向量元素两两一组, 从而获得高维向量的旋转:

$$ \begin{aligned} \begin{bmatrix} \cos{m\theta} & -\sin{m\theta} & 0 & 0 & \cdots & 0 & 0\\ \sin{m\theta} & \cos{m\theta} & 0 & 0 & \cdots & 0 & 0\\ 0 & 0 & \cos{m\theta} & -\sin{m\theta} & \cdots & 0 & 0\\ 0 & 0 & \sin{m\theta} & \cos{m\theta} & \cdots & 0 & 0\\ \vdots & \vdots & \vdots & \vdots & \ddots & \vdots & \vdots\\ 0 & 0 & 0 & 0 & \cdots & \cos{m\theta} & -\sin{m\theta}\\ 0 & 0 & 0 & 0 & \cdots & \sin{m\theta} & \cos{m\theta} \end{bmatrix}\begin{bmatrix} q_{1}\\ q_{2}\\ q_{3}\\ q_{4}\\ \vdots\\ q_{d-1}\\ q_{d} \end{bmatrix} \end{aligned} $$

若将 $\theta$ 替换为与维度有关的 $\theta_{j} = 10000^{-2j/d}$, 则可得到最终的 RoPE 位置编码方式.

数据读取

如何将 50ms 的 spikes 数据转为一个 $D$ 维的 latent 向量?

$$ \mathbf{x}_{n} = \text{RoPE}[\text{UnitEmb}(u_{n}),\Delta t_{n}] $$

其中 $u_{n}$ 为第 $n$ 个通道(channel)的编号, $\Delta t_{n}$ 为该通道内 spikes 的相对时间(相对于各 bin 起始时间).

# Model.py
self.emb = nn.Embedding(
    num_embeddings=meta_data["num_channel"],
    embedding_dim=config.embed_dim
)

解码 & 输出

最近 $k$ 个隐状态(State-Space Model 或者 Gated Recurrent Unit) 作为元素, 形成张量

$$ \mathbf{H} = [\mathbf{h}_{t-k+1}, \mathbf{h}_{t-k+2}, \cdots, \mathbf{h}_{t}]\in\mathbb{R}^{k\times h} $$

取一个时间块(bin) 内的 $T_{c}$ 个点.

Key & Value:

$$ \mathbf{K} = \mathbf{H}\cdot\mathbf{W}_{k},\quad \mathbf{V} = \mathbf{H}\cdot\mathbf{W}_{v} $$