4 Commits

Author SHA1 Message Date
  fangwei123456 ec09d1874c add the SimpleBaseNode 2 years ago
  fangwei123456 73e67fb350 add the SimpleBaseNode 2 years ago
  fangwei123456 024edfb196
Merge pull request #470 from pkuxmq/master 2 years ago
  Mingqing Xiao 4b12a388fc add DSR, OTTT, SLTT training modules 2 years ago
4 changed files with 1308 additions and 0 deletions
Split View
  1. +80
    -0
      spikingjelly/activation_based/functional.py
  2. +227
    -0
      spikingjelly/activation_based/layer.py
  3. +251
    -0
      spikingjelly/activation_based/model/spiking_vggws_ottt.py
  4. +750
    -0
      spikingjelly/activation_based/neuron.py

+ 80
- 0
spikingjelly/activation_based/functional.py View File

@@ -1277,4 +1277,84 @@ def fptt_online_training(model: nn.Module, optimizer: torch.optim.Optimizer, x_s



def ottt_online_training(model: nn.Module, optimizer: torch.optim.Optimizer, x_seq: torch.Tensor, target_seq: torch.Tensor, f_loss_t: Callable, online: bool):
    """
    :param model: the neural network
    :type model: nn.Module
    :param optimizer: the optimizer for the network
    :type optimizer: torch.optim.Optimizer
    :param x_seq: the input sequence with ``shape = [B, T, ...]``
    :type x_seq: torch.Tensor
    :param target_seq: the target sequence with ``shape = [B, T, ...]``
    :type target_seq: torch.Tensor
    :param f_loss_t: the loss function, which should have the formulation of ``def f_loss_t(x_t, y_t) -> torch.Tensor``
    :type f_loss_t: Callable
    :param online: whether to update the parameters online at each time-step, or accumulate
        gradients over all time-steps and update once
    :type online: bool
    :return: ``(batch_loss, y_all)`` — the loss summed over time-steps, and the detached
        outputs stacked to ``shape = [B, T, ...]``
    :rtype: tuple

    The OTTT online training method is proposed by `Online Training Through Time for Spiking Neural Networks <https://openreview.net/forum?id=Siv3nHYHheI>`_.
    This function can also be used for the SLTT training method proposed by `Towards Memory- and Time-Efficient Backpropagation for Training Spiking Neural Networks <https://openaccess.thecvf.com/content/ICCV2023/html/Meng_Towards_Memory-_and_Time-Efficient_Backpropagation_for_Training_Spiking_Neural_Networks_ICCV_2023_paper.html>`_ .

    Example:

    .. code-block:: python

        from spikingjelly.activation_based import neuron, layer, functional

        net = layer.OTTTSequential(
            nn.Linear(8, 4),
            neuron.OTTTLIFNode(),
            nn.Linear(4, 2),
            neuron.LIFNode()
        )

        optimizer = torch.optim.SGD(net.parameters(), lr=0.1)

        T = 4
        N = 2
        online = True
        for epoch in range(2):

            x_seq = torch.rand([N, T, 8])
            target_seq = torch.rand([N, T, 2])

            functional.ottt_online_training(model=net, optimizer=optimizer, x_seq=x_seq, target_seq=target_seq, f_loss_t=F.mse_loss, online=online)
            functional.reset_net(net)

    """

    # input x_seq/target_seq: [B, T, ...] -> transpose to [T, B, ...] so we can step over time
    x_seq = x_seq.transpose(0, 1)
    target_seq = target_seq.transpose(0, 1)
    T = x_seq.shape[0]

    batch_loss = 0.
    y_all = []
    if not online:
        # one parameter update for the whole sequence: clear gradients once, then accumulate
        optimizer.zero_grad()
    for t in range(T):
        if online:
            optimizer.zero_grad()

        y_t = model(x_seq[t])
        loss = f_loss_t(y_t, target_seq[t].contiguous())

        loss.backward()

        # update params at each time-step when training online
        if online:
            optimizer.step()

        # detach() (instead of the legacy .data) keeps the running loss out of the graph
        batch_loss += loss.detach()
        y_all.append(y_t.detach())

    if not online:
        optimizer.step()

    # y_all: [B, T, ...]
    y_all = torch.stack(y_all, dim=1)

    return batch_loss, y_all


+ 227
- 0
spikingjelly/activation_based/layer.py View File

@@ -10,6 +10,7 @@ from torch.nn.common_types import _size_any_t, _size_1_t, _size_2_t, _size_3_t,
from typing import Optional, List, Tuple, Union
from typing import Callable
from torch.nn.modules.batchnorm import _BatchNorm
import numpy as np


class MultiStepContainer(nn.Sequential, base.MultiStepModule):
@@ -2534,3 +2535,229 @@ class TemporalEffectiveBatchNorm3d(TemporalEffectiveBatchNormNd):
def multi_step_forward(self, x_seq: torch.Tensor):
    # x.shape = [T, N, C, H, W, D]
    # Apply batch-norm to the whole sequence, then rescale each time-step by its
    # learnable temporal weight; the view broadcasts scale over [N, C, H, W, D].
    return self.bn(x_seq) * self.scale.view(-1, 1, 1, 1, 1, 1)


# OTTT modules

class ReplaceforGrad(torch.autograd.Function):
    """Identity-like autograd op: the forward pass returns ``x_r``, while the backward
    pass routes the incoming gradient unchanged to both ``x`` and ``x_r``."""

    @staticmethod
    def forward(ctx, x, x_r):
        # the value comes from x_r; x only takes part in the graph to receive a gradient
        return x_r

    @staticmethod
    def backward(ctx, grad):
        # same gradient flows to both inputs
        return grad, grad


class GradwithTrace(nn.Module):
    def __init__(self, module):
        """
        :param module: the module that requires wrapping
        :type module: nn.Module

        Used for online training through time: the wrapped module's output value is
        computed from the spike, while its gradient is computed through the neuron trace.
        Reference: 'Online Training Through Time for Spiking Neural Networks <https://openreview.net/forum?id=Siv3nHYHheI>'
        """
        super().__init__()
        self.module = module

    def forward(self, x: Tensor):
        # x packs [spike, trace]; produced by OTTTLIFNode in neuron.py
        spike, trace = x[0], x[1]

        # value path: evaluate the module on the spike, detached from the autograd graph
        with torch.no_grad():
            value_out = self.module(spike).detach()

        # gradient path: evaluate the module on the trace instead of the spike
        grad_in = ReplaceforGrad.apply(spike, trace)
        grad_out = self.module(grad_in)

        # combine: forward value taken from value_out, backward gradient from grad_out
        return ReplaceforGrad.apply(grad_out, value_out)


class SpikeTraceOp(nn.Module):
    def __init__(self, module):
        """
        :param module: the module that requires wrapping
        :type module: nn.Module

        Perform the same (parameter-free) operation on both the spike and the trace,
        such as Dropout, AvgPool, etc.
        """
        super().__init__()
        self.module = module

    def forward(self, x: Tensor):
        # x: [spike, trace], defined in OTTTLIFNode in neuron.py
        spike, trace = x[0], x[1]
        spike = self.module(spike)
        # the trace is only a statistic used for gradient computation; no graph is needed
        with torch.no_grad():
            trace = self.module(trace)

        x = [spike, trace]

        return x


class OTTTSequential(nn.Sequential):
    """Sequential container aware of the [spike, trace] pairs used in OTTT training.

    Plain tensors pass straight through each submodule; a [spike, trace] list causes
    parametric modules to be wrapped in ``GradwithTrace`` and parameter-free modules
    in ``SpikeTraceOp``.
    """

    def __init__(self, *args):
        super().__init__(*args)

    def forward(self, input):
        for module in self:
            if isinstance(input, list):
                # choose the wrapper by whether the module owns parameters
                has_params = len(list(module.parameters())) > 0
                wrapper = GradwithTrace(module) if has_params else SpikeTraceOp(module)
                input = wrapper(input)
            else:
                input = module(input)
        return input


# weight standardization modules

class WSConv2d(Conv2d):
    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size: _size_2_t,
            stride: _size_2_t = 1,
            padding: Union[str, _size_2_t] = 0,
            dilation: _size_2_t = 1,
            groups: int = 1,
            bias: bool = True,
            padding_mode: str = 'zeros',
            step_mode: str = 's',
            gain: bool = True,
            eps: float = 1e-4
    ) -> None:
        """
        2D convolution with weight standardization: at every forward pass the weights
        are normalized per output channel before the convolution.

        :param gain: whether to introduce a learnable per-channel scale factor for the weights
        :type gain: bool

        :param eps: a small number to prevent numerical problems
        :type eps: float

        Refer to :class:`Conv2d` for other parameters' API.
        """
        super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode, step_mode)
        if gain:
            self.gain = nn.Parameter(torch.ones(self.out_channels, 1, 1, 1))
        else:
            self.gain = None
        self.eps = eps

    def get_weight(self):
        """Return the standardized (and optionally gained) convolution weight."""
        # fan_in = in_channels/groups * kH * kW, the size each output channel is normalized over
        fan_in = np.prod(self.weight.shape[1:])
        # canonical torch keywords (dim/keepdim) instead of the numpy-compat aliases
        mean = torch.mean(self.weight, dim=[1, 2, 3], keepdim=True)
        var = torch.var(self.weight, dim=[1, 2, 3], keepdim=True)
        weight = (self.weight - mean) / ((var * fan_in + self.eps) ** 0.5)
        if self.gain is not None:
            weight = weight * self.gain
        return weight

    def _forward(self, x: Tensor):
        # same as Conv2d.forward, but with the standardized weight
        return F.conv2d(x, self.get_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups)

    def forward(self, x: Tensor):
        if self.step_mode == 's':
            x = self._forward(x)

        elif self.step_mode == 'm':
            if x.dim() != 5:
                raise ValueError(f'expected x with shape [T, N, C, H, W], but got x with shape {x.shape}!')
            # flatten [T, N, ...] to [T*N, ...], convolve, and restore the time dimension
            x = functional.seq_to_ann_forward(x, self._forward)

        return x


class WSLinear(Linear):
    def __init__(self, in_features: int, out_features: int, bias: bool = True, step_mode='s', gain=True, eps=1e-4) -> None:
        """
        Linear layer with weight standardization: at every forward pass the weights
        are normalized per output feature before the matrix multiplication.

        :param gain: whether to introduce a learnable per-output scale factor for the weights
        :type gain: bool

        :param eps: a small number to prevent numerical problems
        :type eps: float

        Refer to :class:`Linear` for other parameters' API.
        """
        super().__init__(in_features, out_features, bias, step_mode)
        if gain:
            # BUGFIX: Linear exposes `out_features`; `self.out_channels` raised AttributeError
            self.gain = nn.Parameter(torch.ones(self.out_features, 1))
        else:
            self.gain = None
        self.eps = eps

    def get_weight(self):
        """Return the standardized (and optionally gained) weight matrix."""
        # fan_in = in_features, the size each output row is normalized over
        fan_in = np.prod(self.weight.shape[1:])
        mean = torch.mean(self.weight, dim=[1], keepdim=True)
        var = torch.var(self.weight, dim=[1], keepdim=True)
        weight = (self.weight - mean) / ((var * fan_in + self.eps) ** 0.5)
        if self.gain is not None:
            weight = weight * self.gain
        return weight

    def forward(self, x: Tensor):
        # F.linear broadcasts over any leading dims, so single- and multi-step inputs both work
        return F.linear(x, self.get_weight(), self.bias)

+ 251
- 0
spikingjelly/activation_based/model/spiking_vggws_ottt.py View File

@@ -0,0 +1,251 @@
import torch
import torch.nn as nn
from copy import deepcopy
from .. import functional, neuron, layer

__all__ = [
'OTTTSpikingVGG',
'ottt_spiking_vggws',
'ottt_spiking_vgg11','ottt_spiking_vgg11_ws',
'ottt_spiking_vgg13','ottt_spiking_vgg13_ws',
'ottt_spiking_vgg16','ottt_spiking_vgg16_ws',
'ottt_spiking_vgg19','ottt_spiking_vgg19_ws',
]

# modified by https://github.com/pytorch/vision/blob/main/torchvision/models/vgg.py


class Scale(nn.Module):
    """Multiply the input by a fixed scalar ``scale``."""

    def __init__(self, scale):
        super().__init__()
        self.scale = scale

    def forward(self, x):
        return self.scale * x


class OTTTSpikingVGG(nn.Module):
    """Spiking VGG backbone for online training through time (OTTT).

    Reference: 'Online Training Through Time for Spiking Neural Networks
    <https://openreview.net/forum?id=Siv3nHYHheI>'

    :param cfg: layer configuration list (ints = conv output channels, 'M' = max-pool, 'A' = avg-pool)
    :param weight_standardization: use weight-standardized conv/linear layers
    :param num_classes: number of output classes
    :param init_weights: whether to apply the standard VGG weight initialization
    :param spiking_neuron: callable that builds a spiking neuron layer
    :param light_classifier: use a light (pool + single linear) classifier head instead of
        the full VGG classifier
    :param drop_rate: dropout rate inserted after each conv/neuron group in the features
    :param kwargs: forwarded to ``spiking_neuron``; may also contain ``fc_hw``
    """

    def __init__(self, cfg, weight_standardization=True, num_classes=1000, init_weights=True,
                 spiking_neuron: callable = None, light_classifier=True, drop_rate=0., **kwargs):
        super(OTTTSpikingVGG, self).__init__()
        # NOTE(review): 'fc_hw' stays in kwargs and is forwarded to spiking_neuron below —
        # confirm the neuron constructors tolerate it
        self.fc_hw = kwargs.get('fc_hw', 1)
        # spike outputs are rescaled by 2.74 after each neuron when weight standardization is used
        if weight_standardization:
            ws_scale = 2.74
        else:
            ws_scale = 1.
        self.neuron = spiking_neuron
        # BUGFIX: pass the user-supplied drop_rate through; it was hard-coded to 0., which
        # made the constructor's drop_rate argument dead
        self.features = self.make_layers(cfg=cfg, weight_standardization=weight_standardization,
                                         neuron=spiking_neuron, drop_rate=drop_rate, **kwargs)
        if light_classifier:
            self.classifier = layer.OTTTSequential(
                layer.AdaptiveAvgPool2d((self.fc_hw, self.fc_hw)),
                layer.Flatten(1),
                layer.Linear(512*(self.fc_hw**2), num_classes),
            )
        else:
            Linear = layer.WSLinear if weight_standardization else layer.Linear
            self.classifier = layer.OTTTSequential(
                layer.AdaptiveAvgPool2d((7, 7)),
                layer.Flatten(1),
                Linear(512 * 7 * 7, 4096),
                spiking_neuron(**deepcopy(kwargs)),
                Scale(ws_scale),
                layer.Dropout(),
                Linear(4096, 4096),
                spiking_neuron(**deepcopy(kwargs)),
                Scale(ws_scale),
                layer.Dropout(),
                layer.Linear(4096, num_classes),
            )
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)

        return x

    def _initialize_weights(self):
        # standard torchvision-style VGG initialization
        for m in self.modules():
            if isinstance(m, layer.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, layer.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, layer.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    @staticmethod
    def make_layers(cfg, weight_standardization=True, neuron: callable = None, drop_rate=0., **kwargs):
        """Build the convolutional feature extractor described by ``cfg``."""
        layers = []
        in_channels = 3
        Conv2d = layer.WSConv2d if weight_standardization else layer.Conv2d
        for v in cfg:
            if v == 'M':
                layers += [layer.MaxPool2d(kernel_size=2, stride=2)]
            elif v == 'A':
                layers += [layer.AvgPool2d(kernel_size=2, stride=2)]
            else:
                conv2d = Conv2d(in_channels, v, kernel_size=3, padding=1)
                layers += [conv2d, neuron(**deepcopy(kwargs))]
                if weight_standardization:
                    layers += [Scale(2.74)]
                in_channels = v
                if drop_rate > 0.:
                    layers += [layer.Dropout(drop_rate)]
        return layer.OTTTSequential(*layers)




# Convolutional configurations (torchvision VGG style):
# integers are conv output channels, 'M' inserts a MaxPool2d, 'A' inserts an AvgPool2d.
cfgs = {
    'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],  # VGG-11
    'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],  # VGG-13
    'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],  # VGG-16
    'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],  # VGG-19

    'S': [64, 128, 'A', 256, 256, 'A', 512, 512, 'A', 512, 512],  # small VGG (sWS) used by ottt_spiking_vggws
}


def _spiking_vgg(arch, cfg, weight_standardization, spiking_neuron: callable = None, **kwargs):
    """Build an :class:`OTTTSpikingVGG` from a named configuration in ``cfgs``."""
    # `arch` is kept for API symmetry with torchvision; it is not used to build the model
    return OTTTSpikingVGG(cfg=cfgs[cfg], weight_standardization=weight_standardization,
                          spiking_neuron=spiking_neuron, **kwargs)



def ottt_spiking_vggws(spiking_neuron: callable = neuron.OTTTLIFNode, **kwargs):
    """Build the Spiking VGG (sWS) model used in 'Online Training Through Time for Spiking
    Neural Networks <https://openreview.net/forum?id=Siv3nHYHheI>'.

    :param spiking_neuron: a spiking neuron layer
    :type spiking_neuron: callable
    :param kwargs: kwargs for `spiking_neuron`
    :type kwargs: dict
    :return: the model
    :rtype: torch.nn.Module
    """
    return _spiking_vgg('vggws', 'S', True, spiking_neuron, light_classifier=True, **kwargs)




def ottt_spiking_vgg11(spiking_neuron: callable = neuron.OTTTLIFNode, **kwargs):
    """Build a Spiking VGG-11 (configuration 'A', no weight standardization) for OTTT.

    :param spiking_neuron: a spiking neuron layer
    :type spiking_neuron: callable
    :param kwargs: kwargs for `spiking_neuron`
    :type kwargs: dict
    :return: Spiking VGG-11
    :rtype: torch.nn.Module
    """
    return _spiking_vgg('vgg11', 'A', False, spiking_neuron, light_classifier=False, **kwargs)




def ottt_spiking_vgg11_ws(spiking_neuron: callable = neuron.OTTTLIFNode, **kwargs):
    """Build a Spiking VGG-11 with weight standardization (configuration 'A') for OTTT.

    :param spiking_neuron: a spiking neuron layer
    :type spiking_neuron: callable
    :param kwargs: kwargs for `spiking_neuron`
    :type kwargs: dict
    :return: Spiking VGG-11 with weight standardization
    :rtype: torch.nn.Module
    """
    return _spiking_vgg('vgg11_ws', 'A', True, spiking_neuron, light_classifier=False, **kwargs)



def ottt_spiking_vgg13(spiking_neuron: callable = neuron.OTTTLIFNode, **kwargs):
    """Build a Spiking VGG-13 (configuration 'B', no weight standardization) for OTTT.

    :param spiking_neuron: a spiking neuron layer
    :type spiking_neuron: callable
    :param kwargs: kwargs for `spiking_neuron`
    :type kwargs: dict
    :return: Spiking VGG-13
    :rtype: torch.nn.Module
    """
    return _spiking_vgg('vgg13', 'B', False, spiking_neuron, light_classifier=False, **kwargs)




def ottt_spiking_vgg13_ws(spiking_neuron: callable = neuron.OTTTLIFNode, **kwargs):
    """Build a Spiking VGG-13 with weight standardization (configuration 'B') for OTTT.

    :param spiking_neuron: a spiking neuron layer
    :type spiking_neuron: callable
    :param kwargs: kwargs for `spiking_neuron`
    :type kwargs: dict
    :return: Spiking VGG-13 with weight standardization
    :rtype: torch.nn.Module
    """
    # docstring previously said "VGG-11"; this builder uses configuration 'B' (VGG-13)
    return _spiking_vgg('vgg13_ws', 'B', True, spiking_neuron, light_classifier=False, **kwargs)




def ottt_spiking_vgg16(spiking_neuron: callable = neuron.OTTTLIFNode, **kwargs):
    """Build a Spiking VGG-16 (configuration 'D', no weight standardization) for OTTT.

    :param spiking_neuron: a spiking neuron layer
    :type spiking_neuron: callable
    :param kwargs: kwargs for `spiking_neuron`
    :type kwargs: dict
    :return: Spiking VGG-16
    :rtype: torch.nn.Module
    """
    return _spiking_vgg('vgg16', 'D', False, spiking_neuron, light_classifier=False, **kwargs)



def ottt_spiking_vgg16_ws(spiking_neuron: callable = neuron.OTTTLIFNode, **kwargs):
    """Build a Spiking VGG-16 with weight standardization (configuration 'D') for OTTT.

    :param spiking_neuron: a spiking neuron layer
    :type spiking_neuron: callable
    :param kwargs: kwargs for `spiking_neuron`
    :type kwargs: dict
    :return: Spiking VGG-16 with weight standardization
    :rtype: torch.nn.Module
    """
    return _spiking_vgg('vgg16_ws', 'D', True, spiking_neuron, light_classifier=False, **kwargs)



def ottt_spiking_vgg19(spiking_neuron: callable = neuron.OTTTLIFNode, **kwargs):
    """Build a Spiking VGG-19 (configuration 'E', no weight standardization) for OTTT.

    :param spiking_neuron: a spiking neuron layer
    :type spiking_neuron: callable
    :param kwargs: kwargs for `spiking_neuron`
    :type kwargs: dict
    :return: Spiking VGG-19
    :rtype: torch.nn.Module
    """
    return _spiking_vgg('vgg19', 'E', False, spiking_neuron, light_classifier=False, **kwargs)



def ottt_spiking_vgg19_ws(spiking_neuron: callable = neuron.OTTTLIFNode, **kwargs):
    """Build a Spiking VGG-19 with weight standardization (configuration 'E') for OTTT.

    :param spiking_neuron: a spiking neuron layer
    :type spiking_neuron: callable
    :param kwargs: kwargs for `spiking_neuron`
    :type kwargs: dict
    :return: Spiking VGG-19 with weight standardization
    :rtype: torch.nn.Module
    """
    return _spiking_vgg('vgg19_ws', 'E', True, spiking_neuron, light_classifier=False, **kwargs)



+ 750
- 0
spikingjelly/activation_based/neuron.py View File

@@ -21,6 +21,50 @@ except BaseException as e:
cuda_utils = None


class SimpleBaseNode(base.MemoryModule):
    def __init__(self, v_threshold: float = 1., v_reset: float = 0.,
                 surrogate_function: Callable = surrogate.Sigmoid(), detach_reset: bool = False,
                 step_mode='s'):
        """
        A simple version of ``BaseNode``. The user can modify this neuron easily.

        :param v_threshold: firing threshold
        :param v_reset: reset voltage; ``None`` selects soft reset (subtract threshold),
            otherwise the voltage jumps back to ``v_reset`` after a spike
        :param surrogate_function: surrogate gradient for the Heaviside step in backward
        :param detach_reset: whether to detach the spike used in the reset from the graph
        :param step_mode: the step mode
        """
        super().__init__()
        self.v_threshold = v_threshold
        self.v_reset = v_reset
        self.surrogate_function = surrogate_function
        self.detach_reset = detach_reset
        self.step_mode = step_mode
        # membrane potential; registered as memory so reset_net() can restore it
        self.register_memory(name='v', value=0.)

    def single_step_forward(self, x: torch.Tensor):
        # charge -> fire -> reset, the canonical spiking-neuron update
        self.neuronal_charge(x)
        spike = self.neuronal_fire()
        self.neuronal_reset(spike)
        return spike

    def neuronal_charge(self, x: torch.Tensor):
        # subclasses define how input charges the membrane potential
        raise NotImplementedError

    def neuronal_fire(self):
        return self.surrogate_function(self.v - self.v_threshold)

    def neuronal_reset(self, spike):
        if self.detach_reset:
            spike_d = spike.detach()
        else:
            spike_d = spike

        # BUGFIX: the previous code called self.jit_soft_reset/jit_hard_reset, which are
        # BaseNode helpers and are not defined on this class or (from what is visible)
        # on base.MemoryModule; inline the equivalent reset formulas instead.
        if self.v_reset is None:
            # soft reset: subtract the threshold after firing
            self.v = self.v - spike_d * self.v_threshold
        else:
            # hard reset: jump back to v_reset after firing
            self.v = (1. - spike_d) * self.v + spike_d * self.v_reset



class BaseNode(base.MemoryModule):
def __init__(self, v_threshold: float = 1., v_reset: float = 0.,
surrogate_function: Callable = surrogate.Sigmoid(), detach_reset: bool = False,
@@ -2373,3 +2417,709 @@ class GatedLIFNode(base.MemoryModule):
self.v = self.u
y_seq.append(spike)
return torch.stack(y_seq)


##########################################################################################################
# DSR modules
##########################################################################################################

import torch.distributed as dist

class DSRIFNode(base.MemoryModule):
    def __init__(self, T: int = 20, v_threshold: float = 6., alpha: float = 0.5, v_threshold_training: bool = True,
                 v_threshold_grad_scaling: float = 1.0, v_threshold_lower_bound: float = 0.01, step_mode='m',
                 backend='torch', **kwargs):
        """
        The DSR IF neuron proposed in `Training High-Performance Low-Latency Spiking Neural Networks
        by Differentiation on Spike Representation <https://arxiv.org/pdf/2205.00459.pdf>`_ .

        :param T: number of time-steps
        :type T: int

        :param v_threshold: initial membrane potential threshold
        :type v_threshold: float

        :param alpha: scaling factor for the firing threshold
        :type alpha: float

        :param v_threshold_training: whether the membrane potential threshold is trainable, default ``True``
        :type v_threshold_training: bool

        :param v_threshold_grad_scaling: scaling factor for the gradient of the membrane potential threshold
        :type v_threshold_grad_scaling: float

        :param v_threshold_lower_bound: the minimum value the threshold may take during training
        :type v_threshold_lower_bound: float

        :param step_mode: the step mode; only ``'m'`` (multi-step) is supported
        :type step_mode: str

        :param backend: backend for this neuron layer; DSR-IF only supports ``'torch'``
        :type backend: str
        """

        assert isinstance(T, int) and T is not None
        assert isinstance(v_threshold, float) and v_threshold >= v_threshold_lower_bound
        assert isinstance(alpha, float) and alpha > 0.0 and alpha <= 1.0
        assert isinstance(v_threshold_lower_bound, float) and v_threshold_lower_bound > 0.0
        assert step_mode == 'm'

        super().__init__()
        self.backend = backend
        self.step_mode = step_mode
        self.T = T
        # the threshold is a learnable parameter when v_threshold_training, a plain tensor otherwise
        if v_threshold_training:
            self.v_threshold = nn.Parameter(torch.tensor(v_threshold))
        else:
            self.v_threshold = torch.tensor(v_threshold)
        self.alpha = alpha
        self.v_threshold_lower_bound = v_threshold_lower_bound
        self.v_threshold_grad_scaling = v_threshold_grad_scaling

    @property
    def supported_backends(self):
        return 'torch'

    def extra_repr(self):
        with torch.no_grad():
            T = self.T
            v_threshold = self.v_threshold
            alpha = self.alpha
            v_threshold_lower_bound = self.v_threshold_lower_bound
            v_threshold_grad_scaling = self.v_threshold_grad_scaling
            return f', T={T}' + f', init_vth={v_threshold}' + f', alpha={alpha}' + f', vth_bound={v_threshold_lower_bound}' + f', vth_g_scale={v_threshold_grad_scaling}'

    def multi_step_forward(self, x_seq: torch.Tensor):
        # clamp the (possibly trained) threshold to its lower bound, in place and outside the graph
        with torch.no_grad():
            self.v_threshold.copy_(
                F.relu(self.v_threshold - self.v_threshold_lower_bound) + self.v_threshold_lower_bound)
        iffunc = self.DSRIFFunction.apply
        y_seq = iffunc(x_seq, self.T, self.v_threshold, self.alpha, self.v_threshold_grad_scaling)
        return y_seq

    class DSRIFFunction(torch.autograd.Function):
        # Custom autograd: forward simulates the IF neuron over time; backward differentiates
        # through the spike-representation (rate) of the inputs/outputs.

        @staticmethod
        def forward(ctx, inp, T=10, v_threshold=1.0, alpha=0.5, v_threshold_grad_scaling=1.0):
            # inp: [T, N, ...]
            ctx.save_for_backward(inp)

            mem_potential = torch.zeros_like(inp[0]).to(inp.device)
            spikes = []

            for t in range(inp.size(0)):
                mem_potential = mem_potential + inp[t]
                # fire a v_threshold-valued spike once the potential reaches alpha * v_threshold
                spike = ((mem_potential >= alpha * v_threshold).float() * v_threshold).float()
                # soft reset: subtract the emitted spike
                mem_potential = mem_potential - spike
                spikes.append(spike)
            output = torch.stack(spikes)

            ctx.T = T
            ctx.v_threshold = v_threshold
            ctx.v_threshold_grad_scaling = v_threshold_grad_scaling
            return output

        @staticmethod
        def backward(ctx, grad_output):
            with torch.no_grad():
                inp = ctx.saved_tensors[0]
                T = ctx.T
                v_threshold = ctx.v_threshold
                v_threshold_grad_scaling = ctx.v_threshold_grad_scaling

                # rate coding: average over the time dimension (gradients rescaled by T)
                input_rate_coding = torch.mean(inp, 0)
                grad_output_coding = torch.mean(grad_output, 0) * T

                # straight-through on the rate, zeroed outside [0, v_threshold]
                input_grad = grad_output_coding.clone()
                input_grad[(input_rate_coding < 0) | (input_rate_coding > v_threshold)] = 0
                input_grad = torch.stack([input_grad for _ in range(T)]) / T

                # threshold gradient only where the input rate exceeds the threshold
                v_threshold_grad = grad_output_coding.clone()
                v_threshold_grad[input_rate_coding <= v_threshold] = 0
                v_threshold_grad = torch.sum(v_threshold_grad) * v_threshold_grad_scaling
                # sum the threshold gradient across GPUs under distributed training
                if v_threshold_grad.is_cuda and torch.cuda.device_count() != 1:
                    try:
                        dist.all_reduce(v_threshold_grad, op=dist.ReduceOp.SUM)
                    except:
                        # NOTE(review): raising (rather than warnings.warn) aborts the backward
                        # pass here — confirm this is intended
                        raise RuntimeWarning(
                            'Something wrong with the `all_reduce` operation when summing up the gradient of v_threshold from multiple gpus. Better check the gpu status and try DistributedDataParallel.')

            return input_grad, None, v_threshold_grad, None, None


class DSRLIFNode(base.MemoryModule):
    def __init__(self, T: int = 20, v_threshold: float = 1., tau: float = 2.0, delta_t: float = 0.05,
                 alpha: float = 0.3, v_threshold_training: bool = True,
                 v_threshold_grad_scaling: float = 1.0, v_threshold_lower_bound: float = 0.1, step_mode='m',
                 backend='torch', **kwargs):
        """
        The DSR LIF neuron proposed in `Training High-Performance Low-Latency Spiking Neural Networks
        by Differentiation on Spike Representation <https://arxiv.org/pdf/2205.00459.pdf>`_ .

        :param T: number of time-steps
        :type T: int

        :param v_threshold: initial membrane potential threshold
        :type v_threshold: float

        :param tau: membrane time constant
        :type tau: float

        :param delta_t: discretization step used to discretize the ODE form of the LIF model
        :type delta_t: float

        :param alpha: scaling factor for the firing threshold
        :type alpha: float

        :param v_threshold_training: whether the membrane potential threshold is trainable, default ``True``
        :type v_threshold_training: bool

        :param v_threshold_grad_scaling: scaling factor for the gradient of the membrane potential threshold
        :type v_threshold_grad_scaling: float

        :param v_threshold_lower_bound: the minimum value the threshold may take during training
        :type v_threshold_lower_bound: float

        :param step_mode: the step mode; only ``'m'`` (multi-step) is supported
        :type step_mode: str

        :param backend: backend for this neuron layer; DSR-LIF only supports ``'torch'``
        :type backend: str
        """

        assert isinstance(T, int) and T is not None
        assert isinstance(v_threshold, float) and v_threshold >= v_threshold_lower_bound
        assert isinstance(alpha, float) and alpha > 0.0 and alpha <= 1.0
        assert isinstance(v_threshold_lower_bound, float) and v_threshold_lower_bound > 0.0
        assert step_mode == 'm'

        super().__init__()
        self.backend = backend
        self.step_mode = step_mode
        self.T = T
        # the threshold is a learnable parameter when v_threshold_training, a plain tensor otherwise
        if v_threshold_training:
            self.v_threshold = nn.Parameter(torch.tensor(v_threshold))
        else:
            self.v_threshold = torch.tensor(v_threshold)
        self.tau = tau
        self.delta_t = delta_t
        self.alpha = alpha
        self.v_threshold_lower_bound = v_threshold_lower_bound
        self.v_threshold_grad_scaling = v_threshold_grad_scaling

    @property
    def supported_backends(self):
        return 'torch'

    def extra_repr(self):
        with torch.no_grad():
            T = self.T
            v_threshold = self.v_threshold
            tau = self.tau
            delta_t = self.delta_t
            alpha = self.alpha
            v_threshold_lower_bound = self.v_threshold_lower_bound
            v_threshold_grad_scaling = self.v_threshold_grad_scaling
            return f', T={T}' + f', init_vth={v_threshold}' + f', tau={tau}' + f', dt={delta_t}' + f', alpha={alpha}' + \
                   f', vth_bound={v_threshold_lower_bound}' + f', vth_g_scale={v_threshold_grad_scaling}'

    def multi_step_forward(self, x_seq: torch.Tensor):
        # clamp the (possibly trained) threshold to its lower bound, in place and outside the graph
        with torch.no_grad():
            self.v_threshold.copy_(
                F.relu(self.v_threshold - self.v_threshold_lower_bound) + self.v_threshold_lower_bound)
        liffunc = self.DSRLIFFunction.apply
        y_seq = liffunc(x_seq, self.T, self.v_threshold, self.tau, self.delta_t, self.alpha,
                        self.v_threshold_grad_scaling)
        return y_seq

    @classmethod
    def weight_rate_spikes(cls, data, tau, delta_t):
        # Weighted rate coding: time-steps are averaged with exponentially increasing
        # weights exp(-(T*dt - ii*dt)/tau), so later time-steps contribute more.
        T = data.shape[0]
        chw = data.size()[2:]
        # move the time dimension to the end: [T, N, *chw] -> [N, *chw, T]
        data_reshape = data.permute(list(range(1, len(chw) + 2)) + [0])
        weight = torch.tensor([math.exp(-1 / tau * (delta_t * T - ii * delta_t)) for ii in range(1, T + 1)]).to(
            data_reshape.device)
        return (weight * data_reshape).sum(dim=len(chw) + 1) / weight.sum()

    class DSRLIFFunction(torch.autograd.Function):
        # Custom autograd: forward simulates the discretized LIF neuron; backward
        # differentiates through the weighted-rate spike representation.

        @staticmethod
        def forward(ctx, inp, T, v_threshold, tau, delta_t=0.05, alpha=0.3, v_threshold_grad_scaling=1.0):
            # inp: [T, N, ...]
            ctx.save_for_backward(inp)

            mem_potential = torch.zeros_like(inp[0]).to(inp.device)
            # leak factor of the discretized LIF dynamics
            beta = math.exp(-delta_t / tau)

            spikes = []
            for t in range(inp.size(0)):
                mem_potential = beta * mem_potential + (1 - beta) * inp[t]
                # fire a v_threshold-valued spike once the potential reaches alpha * v_threshold
                spike = ((mem_potential >= alpha * v_threshold).float() * v_threshold).float()
                # soft reset: subtract the emitted spike
                mem_potential = mem_potential - spike
                # outputs are scaled spikes (spike / delta_t), matching the DSR representation
                spikes.append(spike / delta_t)
            output = torch.stack(spikes)

            ctx.T = T
            ctx.v_threshold = v_threshold
            ctx.tau = tau
            ctx.delta_t = delta_t
            ctx.v_threshold_grad_scaling = v_threshold_grad_scaling
            return output

        @staticmethod
        def backward(ctx, grad_output):
            inp = ctx.saved_tensors[0]
            T = ctx.T
            v_threshold = ctx.v_threshold
            delta_t = ctx.delta_t
            tau = ctx.tau
            v_threshold_grad_scaling = ctx.v_threshold_grad_scaling

            # weighted-rate representations of the input and of the incoming gradient
            input_rate_coding = DSRLIFNode.weight_rate_spikes(inp, tau, delta_t)
            grad_output_coding = DSRLIFNode.weight_rate_spikes(grad_output, tau, delta_t) * T

            # pass gradients only where the rate lies strictly inside (0, v_threshold * tau / delta_t)
            indexes = (input_rate_coding > 0) & (input_rate_coding < v_threshold / delta_t * tau)
            input_grad = torch.zeros_like(grad_output_coding)
            input_grad[indexes] = grad_output_coding[indexes].clone() / tau
            input_grad = torch.stack([input_grad for _ in range(T)]) / T

            # threshold gradient only where the input rate exceeds the (scaled) threshold
            v_threshold_grad = grad_output_coding.clone()
            v_threshold_grad[input_rate_coding <= v_threshold / delta_t * tau] = 0
            v_threshold_grad = torch.sum(v_threshold_grad) * delta_t * v_threshold_grad_scaling
            # sum the threshold gradient across GPUs under distributed training
            if v_threshold_grad.is_cuda and torch.cuda.device_count() != 1:
                try:
                    dist.all_reduce(v_threshold_grad, op=dist.ReduceOp.SUM)
                except:
                    # NOTE(review): raising (rather than warnings.warn) aborts the backward
                    # pass here — confirm this is intended
                    raise RuntimeWarning('Something wrong with the `all_reduce` operation when summing up the gradient of v_threshold from multiple gpus. Better check the gpu status and try DistributedDataParallel.')

            return input_grad, None, v_threshold_grad, None, None, None, None


##########################################################################################################
# OTTT modules
##########################################################################################################

class OTTTLIFNode(LIFNode):
def __init__(self, tau: float = 2., decay_input: bool = False, v_threshold: float = 1.,
             v_reset: float = None, surrogate_function: Callable = surrogate.Sigmoid(),
             detach_reset: bool = True, step_mode='s', backend='torch', store_v_seq: bool = False):
    """
    The OTTT LIF neuron proposed in `Online Training Through Time for Spiking Neural Networks
    <https://arxiv.org/pdf/2210.04195.pdf>`_ . Forward propagation is identical to the Leaky
    Integrate-and-Fire neuron; this neuron is used for online training through time.

    :param tau: membrane time constant
    :type tau: float

    :param decay_input: whether the input will also decay
    :type decay_input: bool

    :param v_threshold: threshold of this neuron layer
    :type v_threshold: float

    :param v_reset: reset voltage. If not ``None``, the voltage is set to ``v_reset`` after a
        spike; if ``None``, ``v_threshold`` is subtracted from the voltage after a spike
    :type v_reset: float

    :param surrogate_function: surrogate function for computing the gradient of the spike
        (Heaviside) function in backward
    :type surrogate_function: Callable

    :param detach_reset: whether to detach the computation graph of the reset in backward;
        this parameter does not take any effect in this module and is retained solely for
        code consistency
    :type detach_reset: bool

    :param step_mode: the step mode, which can solely be `'s'` (single-step) to guarantee
        memory-efficient computation
    :type step_mode: str

    :param backend: backend for this neuron layer
    :type backend: str

    :param store_v_seq: whether to store the voltage at each time-step to ``self.v_seq``;
        only meaningful for ``step_mode = 'm'``, which this neuron does not support
    :type store_v_seq: bool
    """

    super().__init__(tau, decay_input, v_threshold, v_reset, surrogate_function, detach_reset, step_mode, backend, store_v_seq)
    assert step_mode == 's', "Please use single-step mode to enable memory-efficient training."
    # The membrane potential will be re-registered as a buffer during forward, to support
    # multiple backward passes over all time-steps with reserved information under
    # distributed training on multiple GPUs; hence `v` is removed from the memories dict.
    self._memories.pop('v')

def reset(self):
super().reset()
if hasattr(self, 'v'):
del self.v
if hasattr(self, 'trace'):
del self.trace

@property
def supported_backends(self):
if self.step_mode == 's':
return ('torch')
else:
raise ValueError(self.step_mode)

def neuronal_charge(self, x: torch.Tensor):
self.v = self.v.detach()

if self.decay_input:
if self.v_reset is None or self.v_reset == 0.:
self.v = self.neuronal_charge_decay_input_reset0(x, self.v, self.tau)
else:
self.v = self.neuronal_charge_decay_input(x, self.v, self.v_reset, self.tau)

else:
if self.v_reset is None or self.v_reset == 0.:
self.v = self.neuronal_charge_no_decay_input_reset0(x, self.v, self.tau)
else:
self.v = self.neuronal_charge_no_decay_input(x, self.v, self.v_reset, self.tau)

@staticmethod
@torch.jit.script
def track_trace(spike: torch.Tensor, trace: torch.Tensor, tau: float):
with torch.no_grad():
trace = trace * (1. - 1. / tau) + spike
return trace


def single_step_forward(self, x: torch.Tensor):
"""
训练时,输出脉冲和迹;推理时,输出脉冲
训练时需要将后续参数模块用layer.py中定义的GradwithTrace进行包装,根据迹计算梯度
output spike and trace during training; output spike during inference
during training, successive parametric modules shoule be wrapped by GradwithTrace defined in layer.py, to calculate gradients with traces
"""

if not hasattr(self, 'v'):
if self.v_reset is None:
self.register_buffer('v', torch.zeros_like(x))
else:
self.register_buffer('v', torch.ones_like(x) * self.v_reset)

if self.training:
if not hasattr(self, 'trace'):
self.register_buffer('trace', torch.zeros_like(x))
if self.backend == 'torch':
self.neuronal_charge(x)
spike = self.neuronal_fire()
self.neuronal_reset(spike)

self.trace = self.track_trace(spike, self.trace, self.tau)

return [spike, self.trace]
else:
raise ValueError(self.backend)
else:
if self.v_reset is None:
if self.decay_input:
spike, self.v = self.jit_eval_single_step_forward_soft_reset_decay_input(x, self.v,
self.v_threshold, self.tau)
else:
spike, self.v = self.jit_eval_single_step_forward_soft_reset_no_decay_input(x, self.v,
self.v_threshold,
self.tau)
else:
if self.decay_input:
spike, self.v = self.jit_eval_single_step_forward_hard_reset_decay_input(x, self.v,
self.v_threshold,
self.v_reset, self.tau)
else:
spike, self.v = self.jit_eval_single_step_forward_hard_reset_no_decay_input(x, self.v,
self.v_threshold,
self.v_reset,
self.tau)
return spike


##########################################################################################################
# SLTT modules
##########################################################################################################

class SLTTLIFNode(LIFNode):
    def __init__(self, tau: float = 2., decay_input: bool = True, v_threshold: float = 1.,
                 v_reset: float = 0., surrogate_function: Callable = surrogate.Sigmoid(),
                 detach_reset: bool = True, step_mode='s', backend='torch', store_v_seq: bool = False):
        """
        * :ref:`API in English <SLTTLIFNode.__init__-en>`

        .. _SLTTLIFNode.__init__-cn:

        :param tau: 膜电位时间常数
        :type tau: float

        :param decay_input: 输入是否也会参与衰减
        :type decay_input: bool

        :param v_threshold: 神经元的阈值电压
        :type v_threshold: float

        :param v_reset: 神经元的重置电压。如果不为 ``None``,当神经元释放脉冲后,电压会被重置为 ``v_reset``;
            如果设置为 ``None``,当神经元释放脉冲后,电压会被减去 ``v_threshold``
        :type v_reset: float

        :param surrogate_function: 反向传播时用来计算脉冲函数梯度的替代函数
        :type surrogate_function: Callable

        :param detach_reset: 是否将reset过程的计算图分离。该参数在本模块中不起作用,仅为保持代码统一而保留
        :type detach_reset: bool

        :param step_mode: 步进模式,为了保证神经元的显存占用小,仅可以为 `'s'` (单步)
        :type step_mode: str

        :param backend: 使用那种后端。不同的 ``step_mode`` 可能会带有不同的后端。可以通过打印 ``self.supported_backends`` 查看当前
            使用的步进模式支持的后端。在支持的情况下,使用 ``'cupy'`` 后端是速度最快的
        :type backend: str

        :param store_v_seq: 在使用 ``step_mode = 'm'`` 时,给与 ``shape = [T, N, *]`` 的输入后,是否保存中间过程的 ``shape = [T, N, *]``
            的各个时间步的电压值 ``self.v_seq`` 。设置为 ``False`` 时计算完成后只保留最后一个时刻的电压,即 ``shape = [N, *]`` 的 ``self.v`` 。
            通常设置成 ``False`` ,可以节省内存
        :type store_v_seq: bool

        神经元模型出处:`Towards Memory- and Time-Efficient Backpropagation for Training Spiking Neural Networks
        <https://arxiv.org/pdf/2302.14311.pdf>`.模型正向传播和Leaky Integrate-and-Fire神经元相同.


        * :ref:`中文API <SLTTLIFNode.__init__-cn>`

        .. _SLTTLIFNode.__init__-en:

        :param tau: membrane time constant
        :type tau: float

        :param decay_input: whether the input will decay
        :type decay_input: bool

        :param v_threshold: threshold of this neurons layer
        :type v_threshold: float

        :param v_reset: reset voltage of this neurons layer. If not ``None``, the neuron's voltage will be set to ``v_reset``
            after firing a spike. If ``None``, the neuron's voltage will subtract ``v_threshold`` after firing a spike
        :type v_reset: float

        :param surrogate_function: the function for calculating surrogate gradients of the heaviside step function in backward
        :type surrogate_function: Callable

        :param detach_reset: whether detach the computation graph of reset in backward. this parameter does not take any effect in
            the module, and is retained solely for code consistency
        :type detach_reset: bool

        :param step_mode: the step mode, which can solely be `s` (single-step) to guarantee the memory-efficient computation
        :type step_mode: str

        :param backend: backend for this neurons layer. Different ``step_mode`` may support for different backends. The user can
            print ``self.supported_backends`` and check what backends are supported by the current ``step_mode``. If supported,
            using ``'cupy'`` backend will have the fastest training speed
        :type backend: str

        :param store_v_seq: when using ``step_mode = 'm'`` and given input with ``shape = [T, N, *]``, this option controls
            whether storing the voltage at each time-step to ``self.v_seq`` with ``shape = [T, N, *]``. If set to ``False``,
            only the voltage at last time-step will be stored to ``self.v`` with ``shape = [N, *]``, which can reduce the
            memory consumption
        :type store_v_seq: bool

        SLTT LIF neuron refers to `Towards Memory- and Time-Efficient Backpropagation for Training Spiking Neural Networks
        <https://arxiv.org/pdf/2302.14311.pdf>`. The forward propagation is the same as the Leaky Integrate-and-Fire neuron's.

        """

        super().__init__(tau, decay_input, v_threshold, v_reset, surrogate_function, detach_reset, step_mode, backend, store_v_seq)
        assert step_mode == 's', "Please use single-step mode to enable memory-efficient training."
        # 'v' is removed from the memory dict and re-registered as a buffer
        # during forward, so the state survives under distributed training.
        self._memories.pop('v')

    def reset(self):
        """Reset the neuron state and drop the lazily-created ``v`` buffer,
        which will be re-created with the correct shape on the next input."""
        super().reset()
        if hasattr(self, 'v'):
            del self.v

    @property
    def supported_backends(self):
        """Return the tuple of backends supported by the current step mode.

        Only single-step mode (``'s'``) is supported, with the ``'torch'``
        backend.

        :raises ValueError: if ``self.step_mode`` is not ``'s'``
        """
        if self.step_mode == 's':
            # NOTE: ('torch') without a trailing comma is just a parenthesized
            # string, not a tuple — the comma is required for a 1-tuple.
            return ('torch',)
        else:
            raise ValueError(self.step_mode)

    def neuronal_charge(self, x: torch.Tensor):
        """Integrate the input ``x`` into the membrane potential for one step.

        The previous potential is detached first, so gradients never flow
        across time steps (the key ingredient of SLTT training).
        """
        self.v = self.v.detach()

        if self.decay_input:
            if self.v_reset is None or self.v_reset == 0.:
                self.v = self.neuronal_charge_decay_input_reset0(x, self.v, self.tau)
            else:
                self.v = self.neuronal_charge_decay_input(x, self.v, self.v_reset, self.tau)

        else:
            if self.v_reset is None or self.v_reset == 0.:
                self.v = self.neuronal_charge_no_decay_input_reset0(x, self.v, self.tau)
            else:
                self.v = self.neuronal_charge_no_decay_input(x, self.v, self.v_reset, self.tau)

    def single_step_forward(self, x: torch.Tensor):
        """Advance the neuron by one time step and return the output spike.

        During training, charge/fire/reset run through autograd (with ``v``
        detached across steps in ``neuronal_charge``); during inference the
        fused jit-compiled kernels inherited from ``LIFNode`` are used.
        """
        # lazily register the membrane potential as a buffer (see __init__)
        if not hasattr(self, 'v'):
            if self.v_reset is None:
                self.register_buffer('v', torch.zeros_like(x))
            else:
                self.register_buffer('v', torch.ones_like(x) * self.v_reset)

        if self.training:
            if self.backend == 'torch':
                self.neuronal_charge(x)
                spike = self.neuronal_fire()
                self.neuronal_reset(spike)
                return spike
            else:
                raise ValueError(self.backend)
        else:
            if self.v_reset is None:
                if self.decay_input:
                    spike, self.v = self.jit_eval_single_step_forward_soft_reset_decay_input(x, self.v,
                                                                                             self.v_threshold, self.tau)
                else:
                    spike, self.v = self.jit_eval_single_step_forward_soft_reset_no_decay_input(x, self.v,
                                                                                                self.v_threshold,
                                                                                                self.tau)
            else:
                if self.decay_input:
                    spike, self.v = self.jit_eval_single_step_forward_hard_reset_decay_input(x, self.v,
                                                                                             self.v_threshold,
                                                                                             self.v_reset, self.tau)
                else:
                    spike, self.v = self.jit_eval_single_step_forward_hard_reset_no_decay_input(x, self.v,
                                                                                                self.v_threshold,
                                                                                                self.v_reset,
                                                                                                self.tau)
            return spike

Loading…
Cancel
Save
Baidu
map