|
|
|
@@ -21,6 +21,50 @@ except BaseException as e: |
|
|
|
cuda_utils = None |
|
|
|
|
|
|
|
|
|
|
|
class SimpleBaseNode(base.MemoryModule): |
|
|
|
def __init__(self, v_threshold: float = 1., v_reset: float = 0., |
|
|
|
surrogate_function: Callable = surrogate.Sigmoid(), detach_reset: bool = False, |
|
|
|
step_mode='s'): |
|
|
|
""" |
|
|
|
A simple version of ``BaseNode``. The user can modify this neuron easily. |
|
|
|
""" |
|
|
|
super().__init__() |
|
|
|
self.v_threshold = v_threshold |
|
|
|
self.v_reset = v_reset |
|
|
|
self.surrogate_function = surrogate_function |
|
|
|
self.detach_reset = detach_reset |
|
|
|
self.step_mode = step_mode |
|
|
|
self.register_memory(name='v', value=0.) |
|
|
|
|
|
|
|
def single_step_forward(self, x: torch.Tensor): |
|
|
|
|
|
|
|
self.neuronal_charge(x) |
|
|
|
spike = self.neuronal_fire() |
|
|
|
self.neuronal_reset(spike) |
|
|
|
return spike |
|
|
|
|
|
|
|
def neuronal_charge(self, x: torch.Tensor): |
|
|
|
raise NotImplementedError |
|
|
|
|
|
|
|
def neuronal_fire(self): |
|
|
|
return self.surrogate_function(self.v - self.v_threshold) |
|
|
|
|
|
|
|
def neuronal_reset(self, spike): |
|
|
|
if self.detach_reset: |
|
|
|
spike_d = spike.detach() |
|
|
|
else: |
|
|
|
spike_d = spike |
|
|
|
|
|
|
|
if self.v_reset is None: |
|
|
|
# soft reset |
|
|
|
self.v = self.v - spike_d * self.v_threshold
|
|
|
|
|
|
|
else: |
|
|
|
# hard reset |
|
|
|
self.v = (1. - spike_d) * self.v + spike_d * self.v_reset
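

# A minimal sketch of how ``SimpleBaseNode`` can be specialized: only ``neuronal_charge`` needs
# to be implemented, since firing and reset are already provided above. The subclass name below
# is hypothetical and purely illustrative.
class _SimpleIFNodeExample(SimpleBaseNode):
    def neuronal_charge(self, x: torch.Tensor):
        # IF dynamics: V[t] = V[t-1] + X[t]
        self.v = self.v + x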
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BaseNode(base.MemoryModule): |
|
|
|
def __init__(self, v_threshold: float = 1., v_reset: float = 0., |
|
|
|
surrogate_function: Callable = surrogate.Sigmoid(), detach_reset: bool = False, |
|
|
|
@@ -2373,3 +2417,709 @@ class GatedLIFNode(base.MemoryModule): |
|
|
|
self.v = self.u |
|
|
|
y_seq.append(spike) |
|
|
|
return torch.stack(y_seq) |
|
|
|
|
|
|
|
|
|
|
|
########################################################################################################## |
|
|
|
# DSR modules |
|
|
|
########################################################################################################## |
|
|
|
|
|
|
|
import torch.distributed as dist |
|
|
|
|
|
|
|
class DSRIFNode(base.MemoryModule): |
|
|
|
def __init__(self, T: int = 20, v_threshold: float = 6., alpha: float = 0.5, v_threshold_training: bool = True, |
|
|
|
v_threshold_grad_scaling: float = 1.0, v_threshold_lower_bound: float = 0.01, step_mode='m', |
|
|
|
backend='torch', **kwargs): |
|
|
|
|
|
|
|
""" |
|
|
|
* :ref:`中文API <DSRIFNode.__init__-cn>` |
|
|
|
|
|
|
|
.. _DSRIFNode.__init__-cn: |
|
|
|
|
|
|
|
:param T: 时间步长 |
|
|
|
:type T: int |
|
|
|
|
|
|
|
:param v_threshold: 神经元的阈值电压初始值 |
|
|
|
:type v_threshold: float |
|
|
|
|
|
|
|
:param alpha: 放电阈值的缩放因子 |
|
|
|
:type alpha: float |
|
|
|
|
|
|
|
:param v_threshold_training: 是否将阈值电压设置为可学习参数,默认为`'True'` |
|
|
|
:type v_threshold_training: bool |
|
|
|
|
|
|
|
:param v_threshold_grad_scaling: 对放电阈值的梯度进行缩放的缩放因子 |
|
|
|
:type v_threshold_grad_scaling: float |
|
|
|
|
|
|
|
:param v_threshold_lower_bound: 训练过程中,阈值电压能取到的最小值 |
|
|
|
:type v_threshold_lower_bound: float |
|
|
|
|
|
|
|
:param step_mode: 步进模式,只支持 `'m'` (多步) |
|
|
|
:type step_mode: str |
|
|
|
|
|
|
|
:param backend: 使用哪种后端。不同的 ``step_mode`` 可能会带有不同的后端。可以通过打印 ``self.supported_backends`` 查看当前 |
|
|
|
使用的步进模式支持的后端。在支持的情况下,使用 ``'cupy'`` 后端是速度最快的。DSR-IF只支持torch |
|
|
|
:type backend: str |
|
|
|
|
|
|
|
模型出处:`Training High-Performance Low-Latency Spiking Neural Networks by Differentiation on Spike Representation |
|
|
|
<https://arxiv.org/pdf/2205.00459.pdf>`_.
|
|
|
|
|
|
|
|
|
|
|
* :ref:`API in English <DSRIFNode.__init__-en>` |
|
|
|
|
|
|
|
.. _DSRIFNode.__init__-en: |
|
|
|
|
|
|
|
:param T: time-step |
|
|
|
:type T: int |
|
|
|
|
|
|
|
:param v_threshold: initial membrane potential threshold
|
|
|
:type v_threshold: float |
|
|
|
|
|
|
|
:param alpha: the scaling factor for the membrane potential threshold
|
|
|
:type alpha: float |
|
|
|
|
|
|
|
:param v_threshold_training: whether the membrane potential threshold is trained, default: ``True``
|
|
|
:type v_threshold_training: bool |
|
|
|
|
|
|
|
:param v_threshold_grad_scaling: the scaling factor for the gradient of the membrane potential threshold
|
|
|
:type v_threshold_grad_scaling: float |
|
|
|
|
|
|
|
:param v_threshold_lower_bound: the minimum of the membrane potential threshold during training
|
|
|
:type v_threshold_lower_bound: float |
|
|
|
|
|
|
|
:param step_mode: step mode, only support `'m'` (multi-step) |
|
|
|
:type step_mode: str |
|
|
|
|
|
|
|
:param backend: backend for this neuron layer. DSR-IF only supports the ``'torch'`` backend
|
|
|
:type backend: str |
|
|
|
|
|
|
|
|
|
|
|
DSR IF neuron refers to `Training High-Performance Low-Latency Spiking Neural Networks by Differentiation on Spike Representation |
|
|
|
<https://arxiv.org/pdf/2205.00459.pdf>`_.
|
|
|
""" |
|
|
|
|
|
|
|
assert isinstance(T, int) and T is not None |
|
|
|
assert isinstance(v_threshold, float) and v_threshold >= v_threshold_lower_bound |
|
|
|
assert isinstance(alpha, float) and alpha > 0.0 and alpha <= 1.0 |
|
|
|
assert isinstance(v_threshold_lower_bound, float) and v_threshold_lower_bound > 0.0 |
|
|
|
assert step_mode == 'm' |
|
|
|
|
|
|
|
super().__init__() |
|
|
|
self.backend = backend |
|
|
|
self.step_mode = step_mode |
|
|
|
self.T = T |
|
|
|
if v_threshold_training: |
|
|
|
self.v_threshold = nn.Parameter(torch.tensor(v_threshold)) |
|
|
|
else: |
|
|
|
self.v_threshold = torch.tensor(v_threshold) |
|
|
|
self.alpha = alpha |
|
|
|
self.v_threshold_lower_bound = v_threshold_lower_bound |
|
|
|
self.v_threshold_grad_scaling = v_threshold_grad_scaling |
|
|
|
|
|
|
|
@property |
|
|
|
def supported_backends(self): |
|
|
|
return 'torch' |
|
|
|
|
|
|
|
def extra_repr(self): |
|
|
|
with torch.no_grad(): |
|
|
|
T = self.T |
|
|
|
v_threshold = self.v_threshold |
|
|
|
alpha = self.alpha |
|
|
|
v_threshold_lower_bound = self.v_threshold_lower_bound |
|
|
|
v_threshold_grad_scaling = self.v_threshold_grad_scaling |
|
|
|
return f', T={T}' + f', init_vth={v_threshold}' + f', alpha={alpha}' + f', vth_bound={v_threshold_lower_bound}' + f', vth_g_scale={v_threshold_grad_scaling}' |
|
|
|
|
|
|
|
def multi_step_forward(self, x_seq: torch.Tensor): |
|
|
|
with torch.no_grad(): |
|
|
|
self.v_threshold.copy_( |
|
|
|
F.relu(self.v_threshold - self.v_threshold_lower_bound) + self.v_threshold_lower_bound) |
|
|
|
iffunc = self.DSRIFFunction.apply |
|
|
|
y_seq = iffunc(x_seq, self.T, self.v_threshold, self.alpha, self.v_threshold_grad_scaling) |
|
|
|
return y_seq |
|
|
|
|
|
|
|
|
|
|
|
class DSRIFFunction(torch.autograd.Function): |
|
|
|
@staticmethod |
|
|
|
def forward(ctx, inp, T=10, v_threshold=1.0, alpha=0.5, v_threshold_grad_scaling=1.0): |
|
|
|
ctx.save_for_backward(inp) |
|
|
|
|
|
|
|
mem_potential = torch.zeros_like(inp[0]).to(inp.device) |
|
|
|
spikes = [] |
|
|
|
|
|
|
|
for t in range(inp.size(0)): |
|
|
|
mem_potential = mem_potential + inp[t] |
|
|
|
spike = ((mem_potential >= alpha * v_threshold).float() * v_threshold).float() |
|
|
|
mem_potential = mem_potential - spike |
|
|
|
spikes.append(spike) |
|
|
|
output = torch.stack(spikes) |
|
|
|
|
|
|
|
ctx.T = T |
|
|
|
ctx.v_threshold = v_threshold |
|
|
|
ctx.v_threshold_grad_scaling = v_threshold_grad_scaling |
|
|
|
return output |
|
|
|
|
|
|
|
@staticmethod |
|
|
|
def backward(ctx, grad_output): |
|
|
|
with torch.no_grad(): |
|
|
|
inp = ctx.saved_tensors[0] |
|
|
|
T = ctx.T |
|
|
|
v_threshold = ctx.v_threshold |
|
|
|
v_threshold_grad_scaling = ctx.v_threshold_grad_scaling |
|
|
|
|
|
|
|
input_rate_coding = torch.mean(inp, 0) |
|
|
|
grad_output_coding = torch.mean(grad_output, 0) * T |
|
|
|
|
|
|
|
input_grad = grad_output_coding.clone() |
|
|
|
input_grad[(input_rate_coding < 0) | (input_rate_coding > v_threshold)] = 0 |
|
|
|
input_grad = torch.stack([input_grad for _ in range(T)]) / T |
|
|
|
|
|
|
|
v_threshold_grad = grad_output_coding.clone() |
|
|
|
v_threshold_grad[input_rate_coding <= v_threshold] = 0 |
|
|
|
v_threshold_grad = torch.sum(v_threshold_grad) * v_threshold_grad_scaling |
|
|
|
if v_threshold_grad.is_cuda and torch.cuda.device_count() != 1: |
|
|
|
try: |
|
|
|
dist.all_reduce(v_threshold_grad, op=dist.ReduceOp.SUM) |
|
|
|
except: |
|
|
|
raise RuntimeWarning( |
|
|
|
'Something wrong with the `all_reduce` operation when summing up the gradient of v_threshold from multiple gpus. Better check the gpu status and try DistributedDataParallel.') |
|
|
|
|
|
|
|
return input_grad, None, v_threshold_grad, None, None |
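

# A minimal usage sketch (illustrative; the helper name and tensor shapes below are assumptions,
# not part of the original code): DSRIFNode works in multi-step mode only, so its forward consumes
# a whole sequence of shape [T, N, *] and returns a spike-representation sequence of the same shape.
def _dsr_if_usage_example():
    T, N, C = 4, 2, 8
    node = DSRIFNode(T=T, v_threshold=6., alpha=0.5)
    x_seq = torch.rand(T, N, C)
    y_seq = node(x_seq)  # dispatched to multi_step_forward, which calls DSRIFFunction
    assert y_seq.shape == x_seq.shape
    return y_seq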
|
|
|
|
|
|
|
|
|
|
|
class DSRLIFNode(base.MemoryModule): |
|
|
|
def __init__(self, T: int = 20, v_threshold: float = 1., tau: float = 2.0, delta_t: float = 0.05, |
|
|
|
alpha: float = 0.3, v_threshold_training: bool = True, |
|
|
|
v_threshold_grad_scaling: float = 1.0, v_threshold_lower_bound: float = 0.1, step_mode='m', |
|
|
|
backend='torch', **kwargs): |
|
|
|
|
|
|
|
""" |
|
|
|
* :ref:`中文API <DSRLIFNode.__init__-cn>` |
|
|
|
|
|
|
|
.. _DSRLIFNode.__init__-cn: |
|
|
|
|
|
|
|
:param T: 时间步长 |
|
|
|
:type T: int |
|
|
|
|
|
|
|
:param v_threshold: 神经元的阈值电压初始值 |
|
|
|
:type v_threshold: float |
|
|
|
|
|
|
|
:param tau: 膜电位时间常数 |
|
|
|
:type tau: float |
|
|
|
|
|
|
|
:param delta_t: 对微分方程形式的LIF模型进行离散化的步长 |
|
|
|
:type delta_t: float |
|
|
|
|
|
|
|
:param alpha: 放电阈值的缩放因子 |
|
|
|
:type alpha: float |
|
|
|
|
|
|
|
:param v_threshold_training: 是否将阈值电压设置为可学习参数,默认为`'True'` |
|
|
|
:type v_threshold_training: bool |
|
|
|
|
|
|
|
:param v_threshold_grad_scaling: 对放电阈值的梯度进行缩放的缩放因子 |
|
|
|
:type v_threshold_grad_scaling: float |
|
|
|
|
|
|
|
:param v_threshold_lower_bound: 训练过程中,阈值电压能取到的最小值 |
|
|
|
:type v_threshold_lower_bound: float |
|
|
|
|
|
|
|
:param step_mode: 步进模式,只支持 `'m'` (多步) |
|
|
|
:type step_mode: str |
|
|
|
|
|
|
|
:param backend: 使用哪种后端。不同的 ``step_mode`` 可能会带有不同的后端。可以通过打印 ``self.supported_backends`` 查看当前 |
|
|
|
使用的步进模式支持的后端。在支持的情况下,使用 ``'cupy'`` 后端是速度最快的。DSR-LIF只支持torch
|
|
|
:type backend: str |
|
|
|
|
|
|
|
模型出处:`Training High-Performance Low-Latency Spiking Neural Networks by Differentiation on Spike Representation |
|
|
|
<https://arxiv.org/pdf/2205.00459.pdf>`_.
|
|
|
|
|
|
|
|
|
|
|
* :ref:`API in English <DSRLIFNode.__init__-en>` |
|
|
|
|
|
|
|
.. _DSRLIFNode.__init__-en: |
|
|
|
|
|
|
|
:param T: time-step |
|
|
|
:type T: int |
|
|
|
|
|
|
|
:param v_threshold: initial membrane potential threshold
|
|
|
:type v_threshold: float |
|
|
|
|
|
|
|
:param tau: membrane time constant |
|
|
|
:type tau: float |
|
|
|
|
|
|
|
:param delta_t: discretization step for discretizing the ODE version of the LIF model |
|
|
|
:type delta_t: float |
|
|
|
|
|
|
|
:param alpha: the scaling factor for the membrane potential threshold
|
|
|
:type alpha: float |
|
|
|
|
|
|
|
:param v_threshold_training: whether the membrane potential threshold is trained, default: ``True``
|
|
|
:type v_threshold_training: bool |
|
|
|
|
|
|
|
:param v_threshold_grad_scaling: the scaling factor for the gradient of the membrane potential threshold
|
|
|
:type v_threshold_grad_scaling: float |
|
|
|
|
|
|
|
:param v_threshold_lower_bound: the minimum of the membrane potential threshold during training
|
|
|
:type v_threshold_lower_bound: float |
|
|
|
|
|
|
|
:param step_mode: step mode, only support `'m'` (multi-step) |
|
|
|
:type step_mode: str |
|
|
|
|
|
|
|
:param backend: backend for this neuron layer. DSR-LIF only supports the ``'torch'`` backend
|
|
|
:type backend: str |
|
|
|
|
|
|
|
|
|
|
|
DSR LIF neuron refers to `Training High-Performance Low-Latency Spiking Neural Networks by Differentiation on Spike Representation |
|
|
|
<https://arxiv.org/pdf/2205.00459.pdf>`_.
|
|
|
""" |
|
|
|
|
|
|
|
assert isinstance(T, int) and T is not None |
|
|
|
assert isinstance(v_threshold, float) and v_threshold >= v_threshold_lower_bound |
|
|
|
assert isinstance(alpha, float) and alpha > 0.0 and alpha <= 1.0 |
|
|
|
assert isinstance(v_threshold_lower_bound, float) and v_threshold_lower_bound > 0.0 |
|
|
|
assert step_mode == 'm' |
|
|
|
|
|
|
|
super().__init__() |
|
|
|
self.backend = backend |
|
|
|
self.step_mode = step_mode |
|
|
|
self.T = T |
|
|
|
if v_threshold_training: |
|
|
|
self.v_threshold = nn.Parameter(torch.tensor(v_threshold)) |
|
|
|
else: |
|
|
|
self.v_threshold = torch.tensor(v_threshold) |
|
|
|
self.tau = tau |
|
|
|
self.delta_t = delta_t |
|
|
|
self.alpha = alpha |
|
|
|
self.v_threshold_lower_bound = v_threshold_lower_bound |
|
|
|
self.v_threshold_grad_scaling = v_threshold_grad_scaling |
|
|
|
|
|
|
|
@property |
|
|
|
def supported_backends(self): |
|
|
|
return 'torch' |
|
|
|
|
|
|
|
def extra_repr(self): |
|
|
|
with torch.no_grad(): |
|
|
|
T = self.T |
|
|
|
v_threshold = self.v_threshold |
|
|
|
tau = self.tau |
|
|
|
delta_t = self.delta_t |
|
|
|
alpha = self.alpha |
|
|
|
v_threshold_lower_bound = self.v_threshold_lower_bound |
|
|
|
v_threshold_grad_scaling = self.v_threshold_grad_scaling |
|
|
|
return f', T={T}' + f', init_vth={v_threshold}' + f', tau={tau}' + f', dt={delta_t}' + f', alpha={alpha}' + \ |
|
|
|
f', vth_bound={v_threshold_lower_bound}' + f', vth_g_scale={v_threshold_grad_scaling}' |
|
|
|
|
|
|
|
def multi_step_forward(self, x_seq: torch.Tensor): |
|
|
|
with torch.no_grad(): |
|
|
|
self.v_threshold.copy_( |
|
|
|
F.relu(self.v_threshold - self.v_threshold_lower_bound) + self.v_threshold_lower_bound) |
|
|
|
liffunc = self.DSRLIFFunction.apply |
|
|
|
y_seq = liffunc(x_seq, self.T, self.v_threshold, self.tau, self.delta_t, self.alpha, |
|
|
|
self.v_threshold_grad_scaling) |
|
|
|
return y_seq |
|
|
|
|
|
|
|
@classmethod |
|
|
|
def weight_rate_spikes(cls, data, tau, delta_t): |
|
|
|
T = data.shape[0] |
|
|
|
chw = data.size()[2:] |
|
|
|
data_reshape = data.permute(list(range(1, len(chw) + 2)) + [0]) |
|
|
|
weight = torch.tensor([math.exp(-1 / tau * (delta_t * T - ii * delta_t)) for ii in range(1, T + 1)]).to( |
|
|
|
data_reshape.device) |
|
|
|
return (weight * data_reshape).sum(dim=len(chw) + 1) / weight.sum() |
|
|
|
|
|
|
|
class DSRLIFFunction(torch.autograd.Function): |
|
|
|
@staticmethod |
|
|
|
def forward(ctx, inp, T, v_threshold, tau, delta_t=0.05, alpha=0.3, v_threshold_grad_scaling=1.0): |
|
|
|
ctx.save_for_backward(inp) |
|
|
|
|
|
|
|
mem_potential = torch.zeros_like(inp[0]).to(inp.device) |
|
|
|
beta = math.exp(-delta_t / tau) |
|
|
|
|
|
|
|
spikes = [] |
|
|
|
for t in range(inp.size(0)): |
|
|
|
mem_potential = beta * mem_potential + (1 - beta) * inp[t] |
|
|
|
spike = ((mem_potential >= alpha * v_threshold).float() * v_threshold).float() |
|
|
|
mem_potential = mem_potential - spike |
|
|
|
spikes.append(spike / delta_t) |
|
|
|
output = torch.stack(spikes) |
|
|
|
|
|
|
|
ctx.T = T |
|
|
|
ctx.v_threshold = v_threshold |
|
|
|
ctx.tau = tau |
|
|
|
ctx.delta_t = delta_t |
|
|
|
ctx.v_threshold_grad_scaling = v_threshold_grad_scaling |
|
|
|
return output |
|
|
|
|
|
|
|
@staticmethod |
|
|
|
def backward(ctx, grad_output): |
|
|
|
inp = ctx.saved_tensors[0] |
|
|
|
T = ctx.T |
|
|
|
v_threshold = ctx.v_threshold |
|
|
|
delta_t = ctx.delta_t |
|
|
|
tau = ctx.tau |
|
|
|
v_threshold_grad_scaling = ctx.v_threshold_grad_scaling |
|
|
|
|
|
|
|
input_rate_coding = DSRLIFNode.weight_rate_spikes(inp, tau, delta_t) |
|
|
|
grad_output_coding = DSRLIFNode.weight_rate_spikes(grad_output, tau, delta_t) * T |
|
|
|
|
|
|
|
indexes = (input_rate_coding > 0) & (input_rate_coding < v_threshold / delta_t * tau) |
|
|
|
input_grad = torch.zeros_like(grad_output_coding) |
|
|
|
input_grad[indexes] = grad_output_coding[indexes].clone() / tau |
|
|
|
input_grad = torch.stack([input_grad for _ in range(T)]) / T |
|
|
|
|
|
|
|
v_threshold_grad = grad_output_coding.clone() |
|
|
|
v_threshold_grad[input_rate_coding <= v_threshold / delta_t * tau] = 0 |
|
|
|
v_threshold_grad = torch.sum(v_threshold_grad) * delta_t * v_threshold_grad_scaling |
|
|
|
if v_threshold_grad.is_cuda and torch.cuda.device_count() != 1: |
|
|
|
try: |
|
|
|
dist.all_reduce(v_threshold_grad, op=dist.ReduceOp.SUM) |
|
|
|
except: |
|
|
|
raise RuntimeWarning('Something wrong with the `all_reduce` operation when summing up the gradient of v_threshold from multiple gpus. Better check the gpu status and try DistributedDataParallel.') |
|
|
|
|
|
|
|
return input_grad, None, v_threshold_grad, None, None, None, None |
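

# A minimal usage sketch (illustrative; the helper name and tensor shapes below are assumptions,
# not part of the original code): DSRLIFNode is also multi-step only; it integrates the input with
# LIF dynamics discretized by ``delta_t`` and returns the spike sequence scaled by 1/delta_t, as
# produced by DSRLIFFunction above.
def _dsr_lif_usage_example():
    T, N, C = 20, 2, 8
    node = DSRLIFNode(T=T, v_threshold=1., tau=2.0, delta_t=0.05, alpha=0.3)
    x_seq = torch.rand(T, N, C)
    y_seq = node(x_seq)
    assert y_seq.shape == x_seq.shape
    return y_seq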
|
|
|
|
|
|
|
|
|
|
|
########################################################################################################## |
|
|
|
# OTTT modules |
|
|
|
########################################################################################################## |
|
|
|
|
|
|
|
class OTTTLIFNode(LIFNode): |
|
|
|
def __init__(self, tau: float = 2., decay_input: bool = False, v_threshold: float = 1., |
|
|
|
v_reset: float = None, surrogate_function: Callable = surrogate.Sigmoid(), |
|
|
|
detach_reset: bool = True, step_mode='s', backend='torch', store_v_seq: bool = False): |
|
|
|
""" |
|
|
|
* :ref:`API in English <OTTTLIFNode.__init__-en>` |
|
|
|
|
|
|
|
.. _OTTTLIFNode.__init__-cn: |
|
|
|
|
|
|
|
:param tau: 膜电位时间常数 |
|
|
|
:type tau: float |
|
|
|
|
|
|
|
:param decay_input: 输入是否也会参与衰减 |
|
|
|
:type decay_input: bool |
|
|
|
|
|
|
|
:param v_threshold: 神经元的阈值电压 |
|
|
|
:type v_threshold: float |
|
|
|
|
|
|
|
:param v_reset: 神经元的重置电压。如果不为 ``None``,当神经元释放脉冲后,电压会被重置为 ``v_reset``; |
|
|
|
如果设置为 ``None``,当神经元释放脉冲后,电压会被减去 ``v_threshold`` |
|
|
|
:type v_reset: float |
|
|
|
|
|
|
|
:param surrogate_function: 反向传播时用来计算脉冲函数梯度的替代函数 |
|
|
|
:type surrogate_function: Callable |
|
|
|
|
|
|
|
:param detach_reset: 是否将reset过程的计算图分离。该参数在本模块中不起作用,仅为保持代码统一而保留 |
|
|
|
:type detach_reset: bool |
|
|
|
|
|
|
|
:param step_mode: 步进模式,为了保证神经元的显存占用小,仅可以为 `'s'` (单步) |
|
|
|
:type step_mode: str |
|
|
|
|
|
|
|
:param backend: 使用哪种后端。不同的 ``step_mode`` 可能会带有不同的后端。可以通过打印 ``self.supported_backends`` 查看当前
|
|
|
使用的步进模式支持的后端。在支持的情况下,使用 ``'cupy'`` 后端是速度最快的 |
|
|
|
:type backend: str |
|
|
|
|
|
|
|
:param store_v_seq: 在使用 ``step_mode = 'm'`` 时,给与 ``shape = [T, N, *]`` 的输入后,是否保存中间过程的 ``shape = [T, N, *]`` |
|
|
|
的各个时间步的电压值 ``self.v_seq`` 。设置为 ``False`` 时计算完成后只保留最后一个时刻的电压,即 ``shape = [N, *]`` 的 ``self.v`` 。 |
|
|
|
通常设置成 ``False`` ,可以节省内存 |
|
|
|
:type store_v_seq: bool |
|
|
|
|
|
|
|
神经元模型出处:`Online Training Through Time for Spiking Neural Networks <https://arxiv.org/pdf/2210.04195.pdf>`_
|
|
|
模型正向传播和Leaky Integrate-and-Fire神经元相同;用于随时间在线训练 |
|
|
|
|
|
|
|
|
|
|
|
* :ref:`中文API <OTTTLIFNode.__init__-cn>` |
|
|
|
|
|
|
|
.. _OTTTLIFNode.__init__-en: |
|
|
|
|
|
|
|
:param tau: membrane time constant |
|
|
|
:type tau: float |
|
|
|
|
|
|
|
:param decay_input: whether the input will decay |
|
|
|
:type decay_input: bool |
|
|
|
|
|
|
|
:param v_threshold: threshold of this neurons layer |
|
|
|
:type v_threshold: float |
|
|
|
|
|
|
|
:param v_reset: reset voltage of this neurons layer. If not ``None``, the neuron's voltage will be set to ``v_reset`` |
|
|
|
after firing a spike. If ``None``, the neuron's voltage will subtract ``v_threshold`` after firing a spike |
|
|
|
:type v_reset: float |
|
|
|
|
|
|
|
:param surrogate_function: the function for calculating surrogate gradients of the heaviside step function in backward |
|
|
|
:type surrogate_function: Callable |
|
|
|
|
|
|
|
:param detach_reset: whether to detach the computation graph of reset in backward. This parameter does not take effect in
|
|
|
the module, and is retained solely for code consistency |
|
|
|
:type detach_reset: bool |
|
|
|
|
|
|
|
:param step_mode: the step mode, which can only be `'s'` (single-step) to guarantee memory-efficient computation
|
|
|
:type step_mode: str |
|
|
|
|
|
|
|
:param backend: backend for this neurons layer. Different ``step_mode`` may support different backends. The user can
|
|
|
print ``self.supported_backends`` and check what backends are supported by the current ``step_mode``. If supported, |
|
|
|
using ``'cupy'`` backend will have the fastest training speed |
|
|
|
:type backend: str |
|
|
|
|
|
|
|
:param store_v_seq: when using ``step_mode = 'm'`` and given input with ``shape = [T, N, *]``, this option controls |
|
|
|
whether storing the voltage at each time-step to ``self.v_seq`` with ``shape = [T, N, *]``. If set to ``False``, |
|
|
|
only the voltage at last time-step will be stored to ``self.v`` with ``shape = [N, *]``, which can reduce the |
|
|
|
memory consumption |
|
|
|
:type store_v_seq: bool |
|
|
|
|
|
|
|
OTTT LIF neuron refers to `Online Training Through Time for Spiking Neural Networks <https://arxiv.org/pdf/2210.04195.pdf>`_
|
|
|
The forward propagation is the same as the Leaky Integrate-and-Fire neuron's; it is used for online training through time
|
|
|
|
|
|
|
""" |
|
|
|
|
|
|
|
super().__init__(tau, decay_input, v_threshold, v_reset, surrogate_function, detach_reset, step_mode, backend, store_v_seq) |
|
|
|
assert step_mode == 's', "Please use single-step mode to enable memory-efficient training." |
|
|
|
""" |
|
|
|
膜电位将在前向传播过程中重新登记为缓存,以支持多卡分布式训练的情况下保留信息在各时刻进行多次反向传播 |
|
|
|
|
|
|
|
membrane potential will be registered as a buffer during forward, to support multiple backward passes over all time steps with
|
|
|
reserved information under distributed training on multiple GPUs
|
|
|
""" |
|
|
|
self._memories.pop('v') |
|
|
|
|
|
|
|
def reset(self): |
|
|
|
super().reset() |
|
|
|
if hasattr(self, 'v'): |
|
|
|
del self.v |
|
|
|
if hasattr(self, 'trace'): |
|
|
|
del self.trace |
|
|
|
|
|
|
|
@property |
|
|
|
def supported_backends(self): |
|
|
|
if self.step_mode == 's': |
|
|
|
return ('torch', )
|
|
|
else: |
|
|
|
raise ValueError(self.step_mode) |
|
|
|
|
|
|
|
def neuronal_charge(self, x: torch.Tensor): |
|
|
|
self.v = self.v.detach() |
|
|
|
|
|
|
|
if self.decay_input: |
|
|
|
if self.v_reset is None or self.v_reset == 0.: |
|
|
|
self.v = self.neuronal_charge_decay_input_reset0(x, self.v, self.tau) |
|
|
|
else: |
|
|
|
self.v = self.neuronal_charge_decay_input(x, self.v, self.v_reset, self.tau) |
|
|
|
|
|
|
|
else: |
|
|
|
if self.v_reset is None or self.v_reset == 0.: |
|
|
|
self.v = self.neuronal_charge_no_decay_input_reset0(x, self.v, self.tau) |
|
|
|
else: |
|
|
|
self.v = self.neuronal_charge_no_decay_input(x, self.v, self.v_reset, self.tau) |
|
|
|
|
|
|
|
@staticmethod |
|
|
|
@torch.jit.script |
|
|
|
def track_trace(spike: torch.Tensor, trace: torch.Tensor, tau: float): |
|
|
|
with torch.no_grad(): |
|
|
|
trace = trace * (1. - 1. / tau) + spike |
|
|
|
return trace |
|
|
|
|
|
|
|
|
|
|
|
def single_step_forward(self, x: torch.Tensor): |
|
|
|
""" |
|
|
|
训练时,输出脉冲和迹;推理时,输出脉冲 |
|
|
|
训练时需要将后续参数模块用layer.py中定义的GradwithTrace进行包装,根据迹计算梯度 |
|
|
|
|
|
|
|
output spike and trace during training; output spike during inference |
|
|
|
during training, successive parametric modules should be wrapped by GradwithTrace defined in layer.py, to calculate gradients with traces
|
|
|
""" |
|
|
|
|
|
|
|
if not hasattr(self, 'v'): |
|
|
|
if self.v_reset is None: |
|
|
|
self.register_buffer('v', torch.zeros_like(x)) |
|
|
|
else: |
|
|
|
self.register_buffer('v', torch.ones_like(x) * self.v_reset) |
|
|
|
|
|
|
|
if self.training: |
|
|
|
if not hasattr(self, 'trace'): |
|
|
|
self.register_buffer('trace', torch.zeros_like(x)) |
|
|
|
|
|
|
|
if self.backend == 'torch': |
|
|
|
self.neuronal_charge(x) |
|
|
|
spike = self.neuronal_fire() |
|
|
|
self.neuronal_reset(spike) |
|
|
|
|
|
|
|
self.trace = self.track_trace(spike, self.trace, self.tau) |
|
|
|
|
|
|
|
return [spike, self.trace] |
|
|
|
else: |
|
|
|
raise ValueError(self.backend) |
|
|
|
else: |
|
|
|
if self.v_reset is None: |
|
|
|
if self.decay_input: |
|
|
|
spike, self.v = self.jit_eval_single_step_forward_soft_reset_decay_input(x, self.v, |
|
|
|
self.v_threshold, self.tau) |
|
|
|
else: |
|
|
|
spike, self.v = self.jit_eval_single_step_forward_soft_reset_no_decay_input(x, self.v, |
|
|
|
self.v_threshold, |
|
|
|
self.tau) |
|
|
|
else: |
|
|
|
if self.decay_input: |
|
|
|
spike, self.v = self.jit_eval_single_step_forward_hard_reset_decay_input(x, self.v, |
|
|
|
self.v_threshold, |
|
|
|
self.v_reset, self.tau) |
|
|
|
else: |
|
|
|
spike, self.v = self.jit_eval_single_step_forward_hard_reset_no_decay_input(x, self.v, |
|
|
|
self.v_threshold, |
|
|
|
self.v_reset, |
|
|
|
self.tau) |
|
|
|
return spike |
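

# A minimal usage sketch (illustrative; the helper name and tensor shapes below are assumptions,
# not part of the original code): OTTTLIFNode runs in single-step mode. In training mode it
# returns ``[spike, trace]`` so that downstream modules wrapped by ``GradwithTrace`` can compute
# gradients from the trace; in eval mode it returns only the spike tensor.
def _ottt_lif_usage_example():
    node = OTTTLIFNode(tau=2.)
    x = torch.rand(2, 8)  # single-step input of shape [N, *]
    node.train()
    spike, trace = node(x)
    node.reset()  # clears the 'v' and 'trace' buffers between sequences
    node.eval()
    with torch.no_grad():
        spike = node(x)
    return spike, trace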
|
|
|
|
|
|
|
|
|
|
|
########################################################################################################## |
|
|
|
# SLTT modules |
|
|
|
########################################################################################################## |
|
|
|
|
|
|
|
class SLTTLIFNode(LIFNode): |
|
|
|
def __init__(self, tau: float = 2., decay_input: bool = True, v_threshold: float = 1., |
|
|
|
v_reset: float = 0., surrogate_function: Callable = surrogate.Sigmoid(), |
|
|
|
detach_reset: bool = True, step_mode='s', backend='torch', store_v_seq: bool = False): |
|
|
|
""" |
|
|
|
* :ref:`API in English <SLTTLIFNode.__init__-en>` |
|
|
|
|
|
|
|
.. _SLTTLIFNode.__init__-cn: |
|
|
|
|
|
|
|
:param tau: 膜电位时间常数 |
|
|
|
:type tau: float |
|
|
|
|
|
|
|
:param decay_input: 输入是否也会参与衰减 |
|
|
|
:type decay_input: bool |
|
|
|
|
|
|
|
:param v_threshold: 神经元的阈值电压 |
|
|
|
:type v_threshold: float |
|
|
|
|
|
|
|
:param v_reset: 神经元的重置电压。如果不为 ``None``,当神经元释放脉冲后,电压会被重置为 ``v_reset``; |
|
|
|
如果设置为 ``None``,当神经元释放脉冲后,电压会被减去 ``v_threshold`` |
|
|
|
:type v_reset: float |
|
|
|
|
|
|
|
:param surrogate_function: 反向传播时用来计算脉冲函数梯度的替代函数 |
|
|
|
:type surrogate_function: Callable |
|
|
|
|
|
|
|
:param detach_reset: 是否将reset过程的计算图分离。该参数在本模块中不起作用,仅为保持代码统一而保留 |
|
|
|
:type detach_reset: bool |
|
|
|
|
|
|
|
:param step_mode: 步进模式,为了保证神经元的显存占用小,仅可以为 `'s'` (单步) |
|
|
|
:type step_mode: str |
|
|
|
|
|
|
|
:param backend: 使用哪种后端。不同的 ``step_mode`` 可能会带有不同的后端。可以通过打印 ``self.supported_backends`` 查看当前
|
|
|
使用的步进模式支持的后端。在支持的情况下,使用 ``'cupy'`` 后端是速度最快的 |
|
|
|
:type backend: str |
|
|
|
|
|
|
|
:param store_v_seq: 在使用 ``step_mode = 'm'`` 时,给与 ``shape = [T, N, *]`` 的输入后,是否保存中间过程的 ``shape = [T, N, *]`` |
|
|
|
的各个时间步的电压值 ``self.v_seq`` 。设置为 ``False`` 时计算完成后只保留最后一个时刻的电压,即 ``shape = [N, *]`` 的 ``self.v`` 。 |
|
|
|
通常设置成 ``False`` ,可以节省内存 |
|
|
|
:type store_v_seq: bool |
|
|
|
|
|
|
|
神经元模型出处:`Towards Memory- and Time-Efficient Backpropagation for Training Spiking Neural Networks |
|
|
|
<https://arxiv.org/pdf/2302.14311.pdf>`_ 。模型正向传播和Leaky Integrate-and-Fire神经元相同。
|
|
|
|
|
|
|
|
|
|
|
* :ref:`中文API <SLTTLIFNode.__init__-cn>` |
|
|
|
|
|
|
|
.. _SLTTLIFNode.__init__-en: |
|
|
|
|
|
|
|
:param tau: membrane time constant |
|
|
|
:type tau: float |
|
|
|
|
|
|
|
:param decay_input: whether the input will decay |
|
|
|
:type decay_input: bool |
|
|
|
|
|
|
|
:param v_threshold: threshold of this neurons layer |
|
|
|
:type v_threshold: float |
|
|
|
|
|
|
|
:param v_reset: reset voltage of this neurons layer. If not ``None``, the neuron's voltage will be set to ``v_reset`` |
|
|
|
after firing a spike. If ``None``, the neuron's voltage will subtract ``v_threshold`` after firing a spike |
|
|
|
:type v_reset: float |
|
|
|
|
|
|
|
:param surrogate_function: the function for calculating surrogate gradients of the heaviside step function in backward |
|
|
|
:type surrogate_function: Callable |
|
|
|
|
|
|
|
:param detach_reset: whether to detach the computation graph of reset in backward. This parameter does not take effect in
|
|
|
the module, and is retained solely for code consistency |
|
|
|
:type detach_reset: bool |
|
|
|
|
|
|
|
:param step_mode: the step mode, which can only be `'s'` (single-step) to guarantee memory-efficient computation
|
|
|
:type step_mode: str |
|
|
|
|
|
|
|
:param backend: backend for this neurons layer. Different ``step_mode`` may support different backends. The user can
|
|
|
print ``self.supported_backends`` and check what backends are supported by the current ``step_mode``. If supported, |
|
|
|
using ``'cupy'`` backend will have the fastest training speed |
|
|
|
:type backend: str |
|
|
|
|
|
|
|
:param store_v_seq: when using ``step_mode = 'm'`` and given input with ``shape = [T, N, *]``, this option controls |
|
|
|
whether storing the voltage at each time-step to ``self.v_seq`` with ``shape = [T, N, *]``. If set to ``False``, |
|
|
|
only the voltage at last time-step will be stored to ``self.v`` with ``shape = [N, *]``, which can reduce the |
|
|
|
memory consumption |
|
|
|
:type store_v_seq: bool |
|
|
|
|
|
|
|
SLTT LIF neuron refers to `Towards Memory- and Time-Efficient Backpropagation for Training Spiking Neural Networks |
|
|
|
<https://arxiv.org/pdf/2302.14311.pdf>`_. The forward propagation is the same as the Leaky Integrate-and-Fire neuron's.
|
|
|
|
|
|
|
""" |
|
|
|
|
|
|
|
super().__init__(tau, decay_input, v_threshold, v_reset, surrogate_function, detach_reset, step_mode, backend, store_v_seq) |
|
|
|
assert step_mode == 's', "Please use single-step mode to enable memory-efficient training." |
|
|
|
self._memories.pop('v') |
|
|
|
|
|
|
|
def reset(self): |
|
|
|
super().reset() |
|
|
|
if hasattr(self, 'v'): |
|
|
|
del self.v |
|
|
|
|
|
|
|
@property |
|
|
|
def supported_backends(self): |
|
|
|
if self.step_mode == 's': |
|
|
|
return ('torch', )
|
|
|
else: |
|
|
|
raise ValueError(self.step_mode) |
|
|
|
|
|
|
|
def neuronal_charge(self, x: torch.Tensor): |
|
|
|
self.v = self.v.detach() |
|
|
|
|
|
|
|
if self.decay_input: |
|
|
|
if self.v_reset is None or self.v_reset == 0.: |
|
|
|
self.v = self.neuronal_charge_decay_input_reset0(x, self.v, self.tau) |
|
|
|
else: |
|
|
|
self.v = self.neuronal_charge_decay_input(x, self.v, self.v_reset, self.tau) |
|
|
|
|
|
|
|
else: |
|
|
|
if self.v_reset is None or self.v_reset == 0.: |
|
|
|
self.v = self.neuronal_charge_no_decay_input_reset0(x, self.v, self.tau) |
|
|
|
else: |
|
|
|
self.v = self.neuronal_charge_no_decay_input(x, self.v, self.v_reset, self.tau) |
|
|
|
|
|
|
|
def single_step_forward(self, x: torch.Tensor): |
|
|
|
|
|
|
|
if not hasattr(self, 'v'): |
|
|
|
if self.v_reset is None: |
|
|
|
self.register_buffer('v', torch.zeros_like(x)) |
|
|
|
else: |
|
|
|
self.register_buffer('v', torch.ones_like(x) * self.v_reset) |
|
|
|
|
|
|
|
if self.training: |
|
|
|
if self.backend == 'torch': |
|
|
|
self.neuronal_charge(x) |
|
|
|
spike = self.neuronal_fire() |
|
|
|
self.neuronal_reset(spike) |
|
|
|
return spike |
|
|
|
else: |
|
|
|
raise ValueError(self.backend) |
|
|
|
else: |
|
|
|
if self.v_reset is None: |
|
|
|
if self.decay_input: |
|
|
|
spike, self.v = self.jit_eval_single_step_forward_soft_reset_decay_input(x, self.v, |
|
|
|
self.v_threshold, self.tau) |
|
|
|
else: |
|
|
|
spike, self.v = self.jit_eval_single_step_forward_soft_reset_no_decay_input(x, self.v, |
|
|
|
self.v_threshold, |
|
|
|
self.tau) |
|
|
|
else: |
|
|
|
if self.decay_input: |
|
|
|
spike, self.v = self.jit_eval_single_step_forward_hard_reset_decay_input(x, self.v, |
|
|
|
self.v_threshold, |
|
|
|
self.v_reset, self.tau) |
|
|
|
else: |
|
|
|
spike, self.v = self.jit_eval_single_step_forward_hard_reset_no_decay_input(x, self.v, |
|
|
|
self.v_threshold, |
|
|
|
self.v_reset, |
|
|
|
self.tau) |
|
|
|
return spike |
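

# A minimal usage sketch (illustrative; the helper name and tensor shapes below are assumptions,
# not part of the original code): SLTTLIFNode is used like an ordinary single-step LIF neuron.
# Because ``v`` is detached at every step, backward only propagates through the current
# time-step, which is what makes SLTT training memory-efficient.
def _sltt_lif_usage_example():
    node = SLTTLIFNode(tau=2.)
    x_seq = torch.rand(4, 2, 8)  # [T, N, *]; the time loop stays outside the neuron
    spikes = [node(x_seq[t]) for t in range(x_seq.shape[0])]
    node.reset()
    return torch.stack(spikes)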