Data-flow walkthrough.

  1. After preprocessing, the dataset has the structure self.dataset = [trial_0, trial_1, ...], where each element takes the form
trial = {
     'trial_id': trial_id,    # ID of this trial
     'spikes': active_spike,  # spike (action potential) data
     'vel': vel               # velocity
 }
  • active_spike = List[bin], where bin = List[(channel_id, offset), ...] or 0.

  • vel = List[vel_bin], where vel_bin = List[(v_x, v_y), ...].

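As a concrete illustration, a minimal trial could look like the following (toy values; only the field names and nesting follow the structure above, the numbers are made up):

```python
# A toy trial dict mirroring the structure described above.
# Two bins: the first holds two spike tokens (channel_id, offset),
# the second contains no spikes and is therefore stored as 0.
trial = {
    'trial_id': 0,
    'spikes': [[(3, 0.012), (17, 0.031)], 0],
    'vel': [[(0.1, -0.2)], [(0.15, -0.1)]],
}
```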

  2. Dataloader.py converts dataset into tensors. pad_collate_fn outputs padded_bin, bin_mask, spike_mask, vel.
  • padded_bin: a (batch, max_bin, max_token, 2) tensor; the trailing 2 comes from the token (channel_id, offset).

batch: the number of trials per training batch; max_bin: the largest number of bins in any trial of the batch; max_token: the largest number of tokens in any bin.

  • bin_mask: a (batch, max_bin) bool tensor marking which bins are real data rather than zero padding.

  • spike_mask: a (batch, max_bin, max_token) bool tensor marking which tokens are real data rather than zero padding.

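A minimal sketch of a pad_collate_fn with this contract (the real implementation is not shown in these notes and may differ; vel handling is omitted here for brevity):

```python
import torch

def pad_collate_fn(batch):
    """Pad a list of trials to (batch, max_bin, max_token, 2) plus masks.
    Sketch only: follows the shapes described above, vel is omitted."""
    max_bin = max(len(t['spikes']) for t in batch)
    max_token = max((len(b) for t in batch for b in t['spikes'] if b != 0), default=1)
    B = len(batch)
    padded_bin = torch.zeros(B, max_bin, max_token, 2)
    bin_mask = torch.zeros(B, max_bin, dtype=torch.bool)
    spike_mask = torch.zeros(B, max_bin, max_token, dtype=torch.bool)
    for i, t in enumerate(batch):
        for j, b in enumerate(t['spikes']):
            bin_mask[i, j] = True          # real bin, even if it has no spikes
            if b == 0:
                continue
            for k, (ch, off) in enumerate(b):
                padded_bin[i, j, k, 0] = ch
                padded_bin[i, j, k, 1] = off
                spike_mask[i, j, k] = True  # real token, not padding
    return padded_bin, bin_mask, spike_mask

batch = [
    {'trial_id': 0, 'spikes': [[(3, 0.01), (17, 0.03)], 0]},
    {'trial_id': 1, 'spikes': [[(5, 0.02)]]},
]
padded_bin, bin_mask, spike_mask = pad_collate_fn(batch)
```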

  3. possm_model.py performs the Embedding step, lifting tokens into a semantic space.

The spike output by Dataloader.py (i.e. padded_bin) has shape (batch, max_bin, max_token, 2); the split

channels, offsets = spike[..., 0], spike[..., 1]

yields channels and offsets, each of shape (batch, max_bin, max_token).

self.emb = nn.Embedding(
    num_embeddings = config.num_channel, 
    embedding_dim = config.embed_dim
    )

emb = self.emb(channels) maps the (batch, max_bin, max_token) channels to a (batch, max_bin, max_token, embed_dim) tensor emb.

embed_dim: the dimension of the vector that each scalar channel_id is lifted into.

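The lookup can be checked on dummy shapes. A sketch with made-up sizes (num_channel=96, embed_dim=32 are illustrative; the .long() cast is needed here because nn.Embedding expects integer indices):

```python
import torch
import torch.nn as nn

batch, max_bin, max_token = 4, 10, 7
num_channel, embed_dim = 96, 32

# padded_bin as produced by the collate function: last dim = (channel_id, offset)
spike = torch.zeros(batch, max_bin, max_token, 2)
spike[..., 0] = torch.randint(0, num_channel, (batch, max_bin, max_token)).float()

channels, offsets = spike[..., 0].long(), spike[..., 1]

emb_layer = nn.Embedding(num_embeddings=num_channel, embedding_dim=embed_dim)
emb = emb_layer(channels)  # each channel_id becomes an embed_dim-dim vector
```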

  4. cross_attention.py compresses emb into num_latents latent vectors.
  • emb: (batch_size, max_bin, max_token, embed_dim)
  • offsets: a (batch_size, max_bin, max_token) tensor holding each token's relative time within its bin.
  • spike_mask: a (batch_size, max_bin, max_token) bool tensor marking which tokens are non-padding.
z = self.cross_attention(emb, offsets, spike_mask, self.freqs_cos, self.freqs_sin)
  • z: a (batch_size, max_bin, num_latents, embed_dim) tensor
latents = self.latent_query
xq = self.q_proj(latents).unsqueeze(0).unsqueeze(0).expand(batch_size, max_bin, -1, -1)
  • latents: a randomly initialized (num_latents, embed_dim) tensor
  • xq: (batch_size, max_bin, num_latents, num_attention_heads * head_dim)

self.q_proj(latents) projects latents through a linear layer into Query vectors of shape (num_latents, num_attention_heads * head_dim),

which .unsqueeze(0).unsqueeze(0).expand(batch_size, max_bin, -1, -1) then broadcasts into xq.
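The query construction can be reproduced in isolation (a sketch with made-up sizes; the broadcast copies are views of the same underlying data):

```python
import torch
import torch.nn as nn

batch_size, max_bin = 4, 10
num_latents, embed_dim = 8, 32
num_attention_heads, head_dim = 4, 16

latents = nn.Parameter(torch.randn(num_latents, embed_dim))      # latent queries
q_proj = nn.Linear(embed_dim, num_attention_heads * head_dim, bias=False)

# (num_latents, H*D) -> (1, 1, num_latents, H*D) -> expand over batch and bins
xq = q_proj(latents).unsqueeze(0).unsqueeze(0).expand(batch_size, max_bin, -1, -1)
```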

xk = self.k_proj(spike)
xv = self.v_proj(spike)
  • xk: (batch_size, max_bin, max_token, num_key_value_heads * head_dim): spike projected through a linear layer into Key vectors
  • xv: (batch_size, max_bin, max_token, num_key_value_heads * head_dim): spike projected through a linear layer into Value vectors

Note that the Queries come from the learned latents while the Keys and Values both come from the spike tokens, so this is Perceiver-style cross-attention rather than self-attention.

xq = xq.view(batch_size, max_bin, self.num_latents, self.n_local_heads, self.head_dim)
xk = xk.view(batch_size, max_bin, max_token, self.n_local_kv_heads, self.head_dim)
xv = xv.view(batch_size, max_bin, max_token, self.n_local_kv_heads, self.head_dim)
# self.n_local_heads = num_attention_heads
# self.n_local_kv_heads = num_key_value_heads

xq, xk, xv are reshaped into per-head form, where

xk = RotaryEmbedding.apply_rotary_pos_emb(xq, xk, cos, sin, offsets, is_decoder=False)

applies Rotary Positional Embedding to xk; offsets is each token's relative time within its bin.
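The notes don't show apply_rotary_pos_emb itself; the following is a minimal rotary sketch, assuming integer offsets index precomputed cos/sin tables (function name, table layout, and sizes are all illustrative). A rotation preserves vector norms, which the test below checks:

```python
import torch

def apply_rotary(xk, cos, sin, offsets):
    """Rotate key vectors by their per-token offsets.
    xk: (B, M, T, H, D) with D even; offsets: (B, M, T) integer positions;
    cos/sin: (max_pos, D/2) precomputed tables."""
    c = cos[offsets].unsqueeze(3)        # (B, M, T, 1, D/2), broadcast over heads
    s = sin[offsets].unsqueeze(3)
    x1, x2 = xk[..., ::2], xk[..., 1::2]  # split feature dim into rotation pairs
    rot = torch.stack([x1 * c - x2 * s, x1 * s + x2 * c], dim=-1)
    return rot.flatten(-2)                # interleave pairs back to (..., D)

B, M, T, H, D = 2, 3, 5, 4, 8
max_pos = 16
pos = torch.arange(max_pos).float().unsqueeze(1)                # (max_pos, 1)
freqs = 1.0 / (10000 ** (torch.arange(0, D, 2).float() / D))    # (D/2,)
cos, sin = torch.cos(pos * freqs), torch.sin(pos * freqs)
offsets = torch.randint(0, max_pos, (B, M, T))
xk = torch.randn(B, M, T, H, D)
xk_rot = apply_rotary(xk, cos, sin, offsets)
```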

# 4. Transpose for Attention
xq = xq.transpose(2, 3) # swap the 3rd and 4th dimensions
xk = repeat_kv(xk, self.n_rep).transpose(2, 3) 
xv = repeat_kv(xv, self.n_rep).transpose(2, 3)

This prepares the tensors for the attention computation.

  • xq: (batch_size, max_bin, num_latents, num_attention_heads, head_dim) -> (batch_size, max_bin, num_attention_heads, num_latents, head_dim)

repeat_kv() expands xk and xv along the head dimension from num_key_value_heads to num_attention_heads: (batch_size, max_bin, max_token, num_attention_heads, head_dim),

and .transpose(2, 3) then gives: (batch_size, max_bin, num_attention_heads, max_token, head_dim)
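repeat_kv itself is a small tensor expansion. A sketch consistent with the shapes above, assuming the usual grouped-query-attention semantics (each KV head is shared by n_rep query heads; the actual implementation is not shown in the notes):

```python
import torch

def repeat_kv(x, n_rep):
    """Repeat each KV head n_rep times:
    (B, M, T, n_kv_heads, D) -> (B, M, T, n_kv_heads * n_rep, D)."""
    if n_rep == 1:
        return x
    B, M, T, n_kv, D = x.shape
    return (
        x.unsqueeze(4)                       # (B, M, T, n_kv, 1, D)
         .expand(B, M, T, n_kv, n_rep, D)    # copy each head n_rep times
         .reshape(B, M, T, n_kv * n_rep, D)
    )

xk = torch.randn(2, 3, 5, 2, 16)             # num_key_value_heads = 2
xk_full = repeat_kv(xk, n_rep=4).transpose(2, 3)  # heads before tokens
```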

The attention mechanism formula: $\text{Attention}(Q,K,V) = \text{softmax}\left(\frac{QK^{\top}}{\sqrt{d_{k}}}\right)V$

# 5. Calculate Attention
scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(self.head_dim) 

xq @ xk.transpose(-2, -1) multiplies over the last two dimensions, i.e.

scores: (..., num_latents, head_dim) @ (..., head_dim, max_token) -> (batch_size, max_bin, num_attention_heads, num_latents, max_token)

# Apply mask if provided
mask = mask_spike.unsqueeze(2).unsqueeze(2) 
scores = scores.masked_fill(mask == False, float('-inf'))

mask: mask_spike of shape (batch_size, max_bin, max_token) is expanded by .unsqueeze(2).unsqueeze(2) into (batch_size, max_bin, 1, 1, max_token)

The padded positions in scores are filled with float('-inf') so that after softmax their attention weights are exactly 0.

scores = F.softmax(scores.float(), dim=-1).type_as(xq)

scores: (batch_size, max_bin, num_attention_heads, num_latents, max_token)

output = scores @ xv

output: (..., num_latents, max_token) @ (..., max_token, head_dim) -> (batch_size, max_bin, num_attention_heads, num_latents, head_dim); this eliminates the data-dependent max_token dimension and yields a fixed-length representation.
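Putting the last steps together, a sketch (made-up sizes) showing that padded tokens receive zero attention weight and that the output length is fixed at num_latents regardless of max_token:

```python
import math
import torch
import torch.nn.functional as F

B, M, H, L, T, D = 2, 3, 4, 8, 5, 16   # batch, bins, heads, latents, tokens, head_dim
xq = torch.randn(B, M, H, L, D)
xk = torch.randn(B, M, H, T, D)
xv = torch.randn(B, M, H, T, D)

mask_spike = torch.ones(B, M, T, dtype=torch.bool)
mask_spike[..., -2:] = False            # pretend the last two tokens are padding

scores = (xq @ xk.transpose(-2, -1)) / math.sqrt(D)
mask = mask_spike.unsqueeze(2).unsqueeze(2)           # (B, M, 1, 1, T)
scores = scores.masked_fill(mask == False, float('-inf'))
weights = F.softmax(scores.float(), dim=-1).type_as(xq)
output = weights @ xv                   # (B, M, H, L, D): token dim eliminated
```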