Main class for multi-turn conversational SNN processing.
class TemporalSpikeProcessor(nn.Module):
    def __init__(self, snn_model, T=16, max_context_length=512):
        """
        Initialize the temporal spike processor.
        Args:
            snn_model: The converted SNN model
            T: Number of timesteps for spike processing
            max_context_length: Maximum sequence length
        """
forward(input_ids, attention_mask=None, use_cache=True, **kwargs)
Process input through the SNN with temporal dynamics.
Parameters:
input_ids (torch.Tensor): Input token IDs
attention_mask (torch.Tensor, optional): Attention mask
use_cache (bool): Whether to use the KV cache for multi-turn conversations
**kwargs: Additional model arguments
Returns:
reset_cache(batch_id=None)
Reset the KV cache for new conversations.
Parameters:
batch_id (int, optional): Specific batch to reset
get_position_ids()
Get current position IDs for the conversation.
Returns:
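How position IDs might continue across conversation turns can be sketched without the model itself. The helper below is hypothetical (`next_position_ids` is not part of the documented API) and assumes that positions simply continue from the number of tokens already held in the KV cache and must stay inside `max_context_length`.

```python
def next_position_ids(cached_len, new_tokens, max_context_length=512):
    """Return position IDs for `new_tokens` tokens appended after `cached_len` cached tokens.

    Hypothetical helper: assumes positions continue from the cache length
    and raises if the turn would run past the context window.
    """
    end = cached_len + new_tokens
    if end > max_context_length:
        raise ValueError(
            f"sequence length {end} exceeds max_context_length {max_context_length}"
        )
    return list(range(cached_len, end))


# Turn 1 puts 5 tokens into an empty cache; turn 2 appends 3 more.
turn1 = next_position_ids(0, 5)
turn2 = next_position_ids(len(turn1), 3)
print(turn1, turn2)  # [0, 1, 2, 3, 4] [5, 6, 7]
```

After a cache reset, the same helper would be called with `cached_len=0` again, so positions restart from zero for the new conversation.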
Spiking-compatible attention mechanism.
class SpikeAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, T=16, causal=True):
        """
        Initialize spike-based attention.
        Args:
            embed_dim: Embedding dimension
            num_heads: Number of attention heads
            T: Number of timesteps
            causal: Whether to use causal attention
        """
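The `causal` flag refers to the usual autoregressive constraint: position i may attend only to positions j <= i. A minimal plain-Python sketch of such a mask follows. The name `causal_mask` is illustrative, and how SpikeAttention applies the mask internally is not documented here.

```python
def causal_mask(seq_len):
    """Lower-triangular mask: mask[i][j] is True when query i may attend to key j."""
    return [[j <= i for j in range(seq_len)] for i in range(seq_len)]


mask = causal_mask(3)
for row in mask:
    # "x" marks an allowed query/key pair, "." a masked one.
    print(" ".join("x" if allowed else "." for allowed in row))
```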
Spiking-compatible layer normalization.
class SpikeLayerNorm(nn.Module):
    def __init__(self, normalized_shape, eps=1e-5):
        """
        Initialize spike-compatible layer normalization.
        Args:
            normalized_shape: Input shape to normalize
            eps: Small constant for numerical stability
        """
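For reference, standard layer normalization subtracts the mean and divides by the standard deviation of the normalized dimensions, with `eps` added to the variance so the division stays finite. The plain-Python sketch below shows that computation for a single vector. It illustrates the role of `eps` only and is not the SpikeLayerNorm implementation, whose spike-specific handling is not described here.

```python
import math


def layer_norm(x, eps=1e-5):
    """Normalize a list of floats to zero mean and (approximately) unit variance."""
    mean = sum(x) / len(x)
    # Biased variance (divide by N), as standard layer normalization uses.
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


print(layer_norm([1.0, 2.0, 3.0]))
# With zero variance, eps keeps the denominator non-zero.
print(layer_norm([2.0, 2.0, 2.0]))
```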
replace_gelu_with_relu(model)
Replace GELU activations with ReLU for SNN compatibility.
Parameters:
model (torch.nn.Module): Model to modify
Returns:
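A common motivation for this swap can be seen numerically: a rate-coded spiking neuron emits a non-negative number of spikes, so it can represent ReLU's output range but not the small negative values that GELU produces. The sketch below compares the two functions using only the standard library. It does not reproduce `replace_gelu_with_relu` itself.

```python
import math


def gelu(x):
    """Exact GELU: x * Phi(x), where Phi is the standard normal CDF."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def relu(x):
    return max(0.0, x)


# GELU dips below zero for negative inputs; ReLU never does.
for x in (-2.0, -1.0, 0.0, 1.0, 2.0):
    print(f"x={x:+.1f}  gelu={gelu(x):+.4f}  relu={relu(x):+.4f}")
```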
simplified_conversion(model, timesteps=32)
Perform simplified ANN→SNN conversion.
Parameters:
model (torch.nn.Module): Source model
timesteps (int): Number of SNN timesteps
Returns:
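In rate-based ANN-to-SNN conversion, an activation is typically represented by how often a neuron fires over the simulation window. The sketch below simulates a single integrate-and-fire neuron driven by a constant input for `timesteps` steps. The threshold of 1.0, the reset-by-subtraction rule, and the name `if_neuron_rate` are illustrative assumptions, not details taken from `simplified_conversion`.

```python
def if_neuron_rate(x, timesteps=32, threshold=1.0):
    """Firing rate of an integrate-and-fire neuron under constant input `x`."""
    v = 0.0       # membrane potential
    spikes = 0
    for _ in range(timesteps):
        v += x                # integrate the input current
        if v >= threshold:    # fire once the threshold is reached
            spikes += 1
            v -= threshold    # soft reset: subtract the threshold
    return spikes / timesteps


for x in (-0.5, 0.0, 0.25, 0.5, 1.0):
    print(f"input={x:+.2f}  rate={if_neuron_rate(x):.4f}")
```

Under these assumptions the firing rate tracks max(0, x) for inputs up to the threshold and saturates at 1 beyond it, which is why more timesteps give a finer-grained approximation of the original activation.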
replace_layernorm_with_spikelayernorm(model)
Replace LayerNorm with SpikeLayerNorm.
Parameters:
model (torch.nn.Module): Model to modify
Returns:
replace_attention_with_spikeattention(model)
Replace standard attention with SpikeAttention.
Parameters:
model (torch.nn.Module): Model to modify
Returns:
apply_surrogate_gradients(model, alpha=4.0)
Apply surrogate gradient functions for SNN training.
Parameters:
model (torch.nn.Module): SNN model
alpha (float): Surrogate gradient scaling factor
Returns:
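Surrogate gradients replace the derivative of the non-differentiable spike (Heaviside step) function with a smooth stand-in during backpropagation. The sketch below uses the derivative of a scaled sigmoid, one common choice. Which surrogate `apply_surrogate_gradients` installs is not stated here, so the sigmoid form, the threshold of 1.0, and the helper names are assumptions.

```python
import math


def spike(v, threshold=1.0):
    """Forward pass: Heaviside step on the membrane potential."""
    return 1.0 if v >= threshold else 0.0


def sigmoid_surrogate_grad(v, threshold=1.0, alpha=4.0):
    """Backward pass stand-in: derivative of sigmoid(alpha * (v - threshold))."""
    s = 1.0 / (1.0 + math.exp(-alpha * (v - threshold)))
    return alpha * s * (1.0 - s)


# The surrogate peaks at the threshold (value alpha / 4) and decays on both sides.
for v in (0.0, 0.75, 1.0, 1.25, 2.0):
    print(f"v={v:.2f}  spike={spike(v):.0f}  grad={sigmoid_surrogate_grad(v):.4f}")
```

A larger `alpha` makes the surrogate narrower and taller, so it follows the true step function more closely at the cost of smaller gradients away from the threshold.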
calibrate_timesteps(model, original_T, target_T)
Calibrate spike timing for different timestep counts.
Parameters:
model (torch.nn.Module): SNN model
original_T (int): Original timestep count
target_T (int): Target timestep count
Returns:
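One way to reason about changing the timestep count under rate coding: a neuron that fired n times in `original_T` steps has rate n / original_T, and the closest representation at `target_T` steps is that rate times `target_T`, rounded to a whole number of spikes. The sketch below shows that arithmetic and the resulting quantization error. It is an assumption about what calibration has to account for, not a description of what `calibrate_timesteps` does internally, and `rescale_spike_count` is a hypothetical helper.

```python
def rescale_spike_count(count, original_T, target_T):
    """Map a spike count over original_T steps to the nearest count over target_T steps.

    Returns (new_count, rate_error), where rate_error is the absolute
    difference between the original and the re-quantized firing rate.
    """
    if not 0 <= count <= original_T:
        raise ValueError("count must lie between 0 and original_T")
    rate = count / original_T
    new_count = int(rate * target_T + 0.5)  # round half up
    return new_count, abs(new_count / target_T - rate)


print(rescale_spike_count(8, 32, 16))  # rate 0.25 is representable at T=16
print(rescale_spike_count(5, 32, 16))  # rate 5/32 is not, so some error remains
```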
save_snn_model(model, tokenizer, path)
Save the converted SNN model with metadata.
Parameters:
model (torch.nn.Module): SNN model to save
tokenizer: Associated tokenizer
path (str): Save path
Returns:
create_calibration_data(tokenizer, num_samples=10, max_length=128)
Create calibration data for SNN conversion.
Parameters:
tokenizer: HuggingFace tokenizer
num_samples (int): Number of calibration samples
max_length (int): Maximum sequence length
Returns:
test_position_id_boundaries(model, tokenizer, args)
Test position ID handling at sequence boundaries.
Parameters:
model: SNN model to test
tokenizer: Associated tokenizer
args: Test configuration
Returns:
test_attention_mask_continuity(model, tokenizer, args)
Test attention mask continuity across conversation turns.
Parameters:
model: SNN model to test
tokenizer: Associated tokenizer
args: Test configuration
Returns:
test_multi_turn_coherence(model, tokenizer, args)
Test multi-turn conversation coherence.
Parameters:
model: SNN model to test
tokenizer: Associated tokenizer
args: Test configuration
Returns:
simulate_conversation(model, tokenizer, turns=3, device="cpu")
Simulate a multi-turn conversation for testing.
Parameters:
model: SNN model
tokenizer: Associated tokenizer
turns (int): Number of conversation turns
device (str): Computing device
Returns:
run_conversion.py
Main CLI tool for model conversion.
Usage:
python run_conversion.py [OPTIONS]
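A command-line interface with the flags documented for this tool can be sketched with `argparse`. The flag names come from this reference. The defaults, types, and help strings below are assumptions for illustration and may differ from the actual script.

```python
import argparse


def build_parser():
    """Illustrative parser for the documented run_conversion.py flags."""
    parser = argparse.ArgumentParser(
        prog="run_conversion.py",
        description="Convert a model to an SNN (illustrative sketch).",
    )
    # Defaults below are assumed, not taken from the real tool.
    parser.add_argument("--model_name", default="distilgpt2", help="Model to convert")
    parser.add_argument("--output_dir", default="./snn_model", help="Output directory")
    parser.add_argument("--timesteps", type=int, default=16, help="Number of SNN timesteps")
    parser.add_argument("--simplified", action="store_true", help="Use simplified conversion")
    parser.add_argument("--verify", action="store_true", help="Run post-conversion verification")
    return parser


# Parse an explicit argument list so the sketch runs without a real command line.
args = build_parser().parse_args(
    ["--model_name", "distilgpt2", "--timesteps", "32", "--simplified"]
)
print(args.model_name, args.timesteps, args.simplified, args.verify)
```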
Options:
--model_name: Model to convert (distilgpt2, SmolLM2-1.7B-Instruct)
--output_dir: Output directory
--timesteps: Number of SNN timesteps
--simplified: Use simplified conversion
--verify: Run post-conversion verification
test_conversational_snn.py
Testing and validation tool.
Usage:
python test_conversational_snn.py [OPTIONS]
Options:
--test_all: Run all tests
--test_position_boundaries: Test position ID boundaries
--test_attention_mask: Test attention mask continuity
--test_multi_turn: Test multi-turn capabilities
--test_energy: Test energy consumption
Supported Models:
distilgpt2: DistilGPT-2 (82M parameters)
SmolLM2-1.7B-Instruct: SmolLM2 1.7B Instruct (1.7B parameters)
Conversion Parameters:
timesteps: 8-64 (recommended: 16)
max_context_length: 512-2048 (recommended: 512)
surrogate_function: atan, sigmoid, stbif_plus
GPU Memory Requirements:
CPU Requirements:
ImportError: SpikingJelly version compatibility
# Ensure SpikingJelly >= 0.0.0.0.14
pip install spikingjelly[cuda] -U --pre
CUDA Out of Memory: Insufficient GPU memory
# Reduce batch size or use CPU
device = 'cpu'
Position ID Errors: Sequence length exceeds model limits
# Reduce max_context_length
max_context_length = 512
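If a conversation history can grow past the limit, one option is to keep only the most recent tokens before calling the model. The helper below is a hypothetical sketch of that idea on a plain list of token IDs. It is not part of the documented API, and dropping the oldest tokens first is an assumption about which context matters most.

```python
def truncate_to_context(token_ids, max_context_length=512):
    """Keep at most the last `max_context_length` token IDs (hypothetical helper)."""
    if max_context_length <= 0:
        raise ValueError("max_context_length must be positive")
    return token_ids[-max_context_length:]


history = list(range(600))              # stand-in for 600 accumulated token IDs
kept = truncate_to_context(history, 512)
print(len(kept), kept[0], kept[-1])     # 512 88 599
```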
from transformers import AutoModelForCausalLM, AutoTokenizer
from smollm2_converter import *
# Load model
model = AutoModelForCausalLM.from_pretrained("distilgpt2")
tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
# Convert to SNN
snn_model = simplified_conversion(model, timesteps=16)
# Wrap with temporal processor
processor = TemporalSpikeProcessor(snn_model, T=16)
# Test conversation
result = simulate_conversation(processor, tokenizer, turns=3)
# Full pipeline conversion
from convert import convert_model_to_spiking, create_calibration_data
# Create calibration data
calib_data = create_calibration_data(tokenizer, num_samples=10)
# Convert with calibration
snn_model = convert_model_to_spiking(model, calib_data, timesteps=32)
# Apply surrogate gradients
snn_model = apply_surrogate_gradients(snn_model, alpha=4.0)
# Save model
save_snn_model(snn_model, tokenizer, "./my_snn_model")