stac

API Reference

Core Classes

TemporalSpikeProcessor

Main class for multi-turn conversational SNN processing.

class TemporalSpikeProcessor(nn.Module):
    def __init__(self, snn_model, T=16, max_context_length=512):
        """
        Initialize the temporal spike processor.
        
        Args:
            snn_model: The converted SNN model
            T: Number of timesteps for spike processing
            max_context_length: Maximum sequence length
        """

Methods

forward(input_ids, attention_mask=None, use_cache=True, **kwargs)

Process input through the SNN with temporal dynamics.

Parameters:

Returns:

reset_cache(batch_id=None)

Reset the KV cache for new conversations.

Parameters:

get_position_ids()

Get current position IDs for the conversation.

Returns:
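
As an illustration only, the sketch below shows one way position IDs could carry over from one conversation turn to the next and restart after a reset. `PositionTracker` and its methods are hypothetical stand-ins written for this example; they are not part of the library, and `TemporalSpikeProcessor` may track positions differently.

```python
class PositionTracker:
    """Hypothetical stand-in: tracks position IDs across conversation turns."""

    def __init__(self, max_context_length=512):
        self.max_context_length = max_context_length
        self.next_position = 0  # position of the next token to be processed

    def advance(self, num_new_tokens):
        """Return position IDs for the new tokens and move the offset forward."""
        start = self.next_position
        end = start + num_new_tokens
        if end > self.max_context_length:
            raise ValueError(
                f"sequence length {end} exceeds max_context_length "
                f"{self.max_context_length}"
            )
        self.next_position = end
        return list(range(start, end))

    def reset(self):
        """Start a new conversation at position 0."""
        self.next_position = 0


tracker = PositionTracker(max_context_length=8)
turn_1 = tracker.advance(3)  # first turn starts at position 0
turn_2 = tracker.advance(2)  # second turn continues where the first ended
tracker.reset()              # new conversation
turn_3 = tracker.advance(2)  # positions start from 0 again
print(turn_1, turn_2, turn_3)
```

The point of the sketch is that a second turn should not restart at position 0 unless the cache has been reset, and that running past `max_context_length` is an error rather than a silent wrap-around.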

SpikeAttention

Spiking-compatible attention mechanism.

class SpikeAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, T=16, causal=True):
        """
        Initialize spike-based attention.
        
        Args:
            embed_dim: Embedding dimension
            num_heads: Number of attention heads
            T: Number of timesteps
            causal: Whether to use causal attention
        """

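To make the `causal` flag concrete, here is a minimal sketch of a causal attention mask in plain Python. It shows only the standard masking rule (a token may attend to itself and to earlier tokens); how `SpikeAttention` builds or applies its mask internally is not documented here, and `causal_mask` is a hypothetical helper.

```python
def causal_mask(seq_len):
    """Return a seq_len x seq_len mask; True means query i may attend to key j."""
    return [[j <= i for j in range(seq_len)] for i in range(seq_len)]


mask = causal_mask(3)
for row in mask:
    # "x" marks an allowed query/key pair, "." a masked one
    print(" ".join("x" if allowed else "." for allowed in row))
```

With `causal=False` the equivalent mask would allow every pair, which is appropriate for encoders but not for left-to-right text generation.
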
SpikeLayerNorm

Spiking-compatible layer normalization.

class SpikeLayerNorm(nn.Module):
    def __init__(self, normalized_shape, eps=1e-5):
        """
        Initialize spike-compatible layer normalization.
        
        Args:
            normalized_shape: Input shape to normalize
            eps: Small constant for numerical stability
        """

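For reference, the role of `eps` can be seen in a plain-Python sketch of standard layer normalization (without the learnable scale and shift). This is the conventional formula, not a description of how `SpikeLayerNorm` is implemented; `layer_norm` is a hypothetical helper.

```python
import math


def layer_norm(values, eps=1e-5):
    """Normalize a list of floats to zero mean and (approximately) unit variance."""
    mean = sum(values) / len(values)
    # Biased variance (divide by N), the estimator used by standard LayerNorm
    var = sum((v - mean) ** 2 for v in values) / len(values)
    return [(v - mean) / math.sqrt(var + eps) for v in values]


normalized = layer_norm([1.0, 2.0, 3.0, 4.0])
constant = layer_norm([5.0, 5.0, 5.0])  # zero variance: eps avoids division by zero
print(normalized)
print(constant)
```

Without `eps`, the constant input would divide by zero; with it, the output is simply all zeros.
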
Conversion Functions

replace_gelu_with_relu(model)

Replace GELU activations with ReLU for SNN compatibility.

Parameters:

Returns:
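
As background for this replacement, the sketch below compares the two activations numerically. A common motivation in ANN-to-SNN conversion is that a firing rate cannot be negative, which matches ReLU's output range but not GELU's small negative values. The helper functions are written for this example and are not part of the library.

```python
import math


def gelu(x):
    """Exact GELU: x * Phi(x), where Phi is the standard normal CDF."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def relu(x):
    return max(0.0, x)


for x in (-3.0, -1.0, 0.0, 1.0, 3.0):
    print(f"x={x:+.1f}  gelu={gelu(x):+.4f}  relu={relu(x):+.4f}")
```

The two functions agree closely for large positive and large negative inputs and differ most near zero, so swapping them changes the model's outputs and usually calls for some re-validation.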

simplified_conversion(model, timesteps=32)

Perform simplified ANN→SNN conversion.

Parameters:

Returns:
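
To illustrate what the `timesteps` argument trades off, here is a minimal sketch of rate coding with a single integrate-and-fire neuron. It assumes a constant input, a threshold of 1.0, and reset by subtraction; these choices, and the function `if_neuron_rate` itself, are assumptions for the example and are not taken from `simplified_conversion`.

```python
def if_neuron_rate(input_current, timesteps=32, threshold=1.0):
    """Simulate one integrate-and-fire neuron and return its firing rate."""
    membrane = 0.0
    spikes = 0
    for _ in range(timesteps):
        membrane += input_current       # integrate the input
        if membrane >= threshold:
            spikes += 1                 # emit a spike
            membrane -= threshold       # soft reset (reset by subtraction)
    return spikes / timesteps


for current in (-0.3, 0.25, 0.5, 1.5):
    print(f"input={current:+.2f}  rate={if_neuron_rate(current):.4f}")
```

Under these assumptions the firing rate behaves like a ReLU clipped at 1: negative inputs never fire, inputs between 0 and the threshold are reproduced with a resolution of `1 / timesteps`, and larger inputs saturate. More timesteps give finer resolution at the cost of more simulation steps.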

replace_layernorm_with_spikelayernorm(model)

Replace LayerNorm with SpikeLayerNorm.

Parameters:

Returns:

replace_attention_with_spikeattention(model)

Replace standard attention with SpikeAttention.

Parameters:

Returns:

apply_surrogate_gradients(model, alpha=4.0)

Apply surrogate gradient functions for SNN training.

Parameters:

Returns:
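
The role of `alpha` can be sketched as follows, assuming a sigmoid-shaped surrogate. The source does not state which surrogate function is used, so treat this as one common choice rather than the library's implementation; both helper functions are hypothetical.

```python
import math


def heaviside(x):
    """Forward pass of a spiking neuron: fire if the input is non-negative."""
    return 1.0 if x >= 0.0 else 0.0


def sigmoid_surrogate_grad(x, alpha=4.0):
    """Backward pass: derivative of sigmoid(alpha * x), used in place of the
    Heaviside derivative, which is zero almost everywhere."""
    s = 1.0 / (1.0 + math.exp(-alpha * x))
    return alpha * s * (1.0 - s)


for x in (-1.0, -0.25, 0.0, 0.25, 1.0):
    print(f"x={x:+.2f}  spike={heaviside(x):.0f}  grad={sigmoid_surrogate_grad(x):.4f}")
```

A larger `alpha` concentrates the gradient more tightly around the firing threshold; a smaller one spreads it out, which tends to give smoother but less precise gradients.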

calibrate_timesteps(model, original_T, target_T)

Calibrate spike timing for different timestep counts.

Parameters:

Returns:

save_snn_model(model, tokenizer, path)

Save the converted SNN model with metadata.

Parameters:

Returns:

Utility Functions

create_calibration_data(tokenizer, num_samples=10, max_length=128)

Create calibration data for SNN conversion.

Parameters:

Returns:

Testing Functions

test_position_id_boundaries(model, tokenizer, args)

Test position ID handling at sequence boundaries.

Parameters:

Returns:

test_attention_mask_continuity(model, tokenizer, args)

Test attention mask continuity across conversation turns.

Parameters:

Returns:

test_multi_turn_coherence(model, tokenizer, args)

Test multi-turn conversation coherence.

Parameters:

Returns:

simulate_conversation(model, tokenizer, turns=3, device="cpu")

Simulate a multi-turn conversation for testing.

Parameters:

Returns:

Command Line Interface

run_conversion.py

Main CLI tool for model conversion.

Usage:

python run_conversion.py [OPTIONS]

Options:

test_conversational_snn.py

Testing and validation tool.

Usage:

python test_conversational_snn.py [OPTIONS]

Options:

Configuration

Model Parameters

Supported Models:

Conversion Parameters:

Hardware Configuration

GPU Memory Requirements:

CPU Requirements:

Error Handling

Common Exceptions

ImportError: SpikingJelly version compatibility

# Ensure SpikingJelly >= 0.0.0.0.14
# (quotes keep shells such as zsh from interpreting the square brackets)
pip install "spikingjelly[cuda]" -U --pre

CUDA Out of Memory: Insufficient GPU memory

# Reduce batch size or use CPU
device = 'cpu'

Position ID Errors: Sequence length exceeds model limits

# Keep max_context_length within the model's position limit (512 is the default)
max_context_length = 512
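
One way to stay under the limit on the input side is to truncate the token sequence before it reaches the model. The sketch below keeps the most recent tokens; `truncate_to_context` is a hypothetical helper, and whether dropping the oldest tokens is acceptable depends on the application.

```python
def truncate_to_context(token_ids, max_context_length=512):
    """Keep only the most recent tokens so the sequence fits the context window."""
    if max_context_length <= 0:
        raise ValueError("max_context_length must be positive")
    return token_ids[-max_context_length:]


history = list(range(10))                # stand-in for a list of token IDs
print(truncate_to_context(history, 4))   # keeps the last four IDs
print(truncate_to_context(history, 20))  # already short enough: unchanged
```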

Examples

Basic Conversion

from transformers import AutoModelForCausalLM, AutoTokenizer
from smollm2_converter import *

# Load model
model = AutoModelForCausalLM.from_pretrained("distilgpt2")
tokenizer = AutoTokenizer.from_pretrained("distilgpt2")

# Convert to SNN
snn_model = simplified_conversion(model, timesteps=16)

# Wrap with temporal processor
processor = TemporalSpikeProcessor(snn_model, T=16)

# Test conversation
result = simulate_conversation(processor, tokenizer, turns=3)

Advanced Usage

# Full pipeline conversion
# Continues from Basic Conversion: reuses `model`, `tokenizer`, and the names imported there
from convert import convert_model_to_spiking, create_calibration_data

# Create calibration data
calib_data = create_calibration_data(tokenizer, num_samples=10)

# Convert with calibration
snn_model = convert_model_to_spiking(model, calib_data, timesteps=32)

# Apply surrogate gradients
snn_model = apply_surrogate_gradients(snn_model, alpha=4.0)

# Save model
save_snn_model(snn_model, tokenizer, "./my_snn_model")