Install Packages

In [None]:
# --- Installation Cell ---
# Run this cell first to install required libraries in Google Colab
# The leading "!" is notebook shell syntax, so the command below only works
# when executed as a notebook cell (not as a plain Python script).
!pip install torch transformers datasets matplotlib torchinfo tqdm accelerate -U -q
# accelerate is included as it's often useful with Hugging Face libraries
# -U upgrades packages that are already installed
# -q makes the installation quieter
print("Required libraries installed/updated.")

Run Tests

In [None]:
# --- Test Cell ---
# Run this cell before the main script to check component integrity.

import torch
import torch.nn as nn
from torch.optim import AdamW  # <--- CORRECTED IMPORT
from torch.utils.data import DataLoader, Dataset
from transformers import GPT2Tokenizer, GPT2Model, get_linear_schedule_with_warmup  # <--- AdamW REMOVED
import numpy as np
import os
import tempfile  # For creating temporary directories for checkpoint testing
import shutil  # For cleaning up temporary directories
import copy  # For comparing states after loading checkpoint
import math  # Needed for surrogate spike test
import traceback  # For detailed error printing

# --- Components under test ---
# The tests below exercise HPARAMS plus the model classes and checkpoint helpers.
# When the main training cell has already been run, its definitions are used.
# Otherwise the fallback definitions in the `except NameError` branch below are
# created so that this test cell can run on its own.
try:
    # Probe every name the tests depend on. Evaluating the tuple raises
    # NameError at the first missing name, which switches the whole cell over
    # to the self-contained fallback definitions.
    (
        HPARAMS,
        SurrogateSpikeFunction,
        DLPFCAdExNeuron,
        DLPFCLayer,
        HyperdimensionalMemoryModule,
        DLPFCTransformer,
        save_checkpoint,
        load_checkpoint,
        initialize_history,
    )
    print("Using HPARAMS and classes/functions from main script environment.")
except NameError:
    print("HPARAMS or Classes/Functions not found, defining defaults for testing scope.")
    # Deliberately tiny hyperparameters: one epoch, short sequences and small
    # layer widths keep the test models cheap to build.
    HPARAMS = {
        # Backbone and optimisation
        'model_name': "gpt2",
        'learning_rate': 5e-5,
        'weight_decay': 0.01,
        'l1_lambda': 1e-5,
        'num_epochs': 1,
        'batch_size': 2,
        'seq_length': 16,
        # Spiking layer
        'num_recurrent_layers': 1,
        'dlpfc_output_size': 8,
        'adex_params': {
            'tau_m': 20.0, 'tau_w': 144.0, 'a': 4.0, 'b': 0.08,
            'V_th': -50.0, 'V_reset': -70.0, 'V_rest': -65.0, 'delta_T': 2.0,
        },
        'dropout_prob': 0.1,
        'warmup_steps': 10,
        # Memory module
        'hdm_dim': 16,
        'hdm_hidden_dim': 8,
        'log_interval': 10,
        # Output locations: everything goes under the system temp directory
        'output_dir': os.path.join(tempfile.gettempdir(), "test_dlpfc_output"),
        'checkpoint_filename': "checkpoint.pth",
        'best_model_filename': "best_model_state.pth",
        'final_model_filename': "final_model_state.pth",
        'history_filename': "training_history.json",
        'hparams_filename': "hparams.json",
        'seed': 42,
    }

    # --- Minimal Class Definitions Needed (CORRECT MULTI-LINE __INIT__ SYNTAX) ---

    class SurrogateSpikeFunction(torch.autograd.Function):
        """Step-function spike with a Gaussian pseudo-gradient.

        Forward emits 1.0 where the input is strictly positive and 0.0
        elsewhere. Backward replaces the step's derivative with the standard
        normal density evaluated at the saved input.
        """

        @staticmethod
        def forward(ctx, input_tensor):
            # The backward pass needs the raw input to evaluate the surrogate.
            ctx.save_for_backward(input_tensor)
            fired = input_tensor > 0
            return fired.float()

        @staticmethod
        def backward(ctx, grad_output):
            saved_input, = ctx.saved_tensors
            gaussian = torch.exp(-(saved_input ** 2) / 2.0)
            gaussian = gaussian / math.sqrt(2 * math.pi)
            return grad_output * gaussian

    # Functional alias used by the neuron model and by the tests.
    surrogate_spike = SurrogateSpikeFunction.apply

    class DLPFCAdExNeuron(nn.Module):
        """Minimal AdEx spiking neuron advanced by one explicit step per call.

        Keyword arguments (all optional, scalar): ``tau_m``, ``tau_w``, ``a``,
        ``b``, ``V_th``, ``V_reset``, ``V_rest``, ``delta_T``. ``V_th``,
        ``V_reset`` and ``V_rest`` are registered as parameters with
        ``requires_grad=False``; the remaining five are trainable.
        """

        def __init__(self, **adex_params):
            super().__init__()

            def scalar(name, default):
                # float() guards against integer inputs such as a=4: an int64
                # tensor cannot require gradients, so nn.Parameter would raise.
                return torch.tensor(float(adex_params.get(name, default)))

            # Registration order is kept stable so parameters() ordering and
            # state_dict keys do not change.
            self.tau_m    = nn.Parameter(scalar('tau_m', 20.0))
            self.tau_w    = nn.Parameter(scalar('tau_w', 144.0))
            self.a        = nn.Parameter(scalar('a', 4.0))
            self.b        = nn.Parameter(scalar('b', 0.08))
            self.V_th     = nn.Parameter(scalar('V_th', -50.0), requires_grad=False)
            self.V_reset  = nn.Parameter(scalar('V_reset', -70.0), requires_grad=False)
            self.V_rest   = nn.Parameter(scalar('V_rest', -65.0), requires_grad=False)
            self.delta_T  = nn.Parameter(scalar('delta_T', 2.0))

        def forward(self, input_current, V, w):
            """Advance the neuron state by one step of size ``dt = 1.0``.

            Args:
                input_current: Input drive for this step.
                V: Current membrane potential.
                w: Current adaptation variable.

            Returns:
                Tuple ``(spike, V_final, w_final)``: the 0/1 spike tensor, the
                membrane potential after reset, and the adaptation variable
                after the spike-triggered increment.
            """
            dt = 1.0
            # The exponential term is capped at 50 after exponentiation.
            exp_term = torch.exp((V - self.V_th) / self.delta_T).clamp(max=50.0)
            dV = (dt / self.tau_m) * (-(V - self.V_rest) + self.delta_T * exp_term - w + input_current)
            V_new = V + dV
            dw = (dt / self.tau_w) * (self.a * (V - self.V_rest) - w)
            w_new = w + dw
            spike = surrogate_spike(V_new - self.V_th)
            # Neurons that fired are reset; the others keep the integrated value.
            V_final = torch.where(spike > 0.5, self.V_reset, V_new)
            w_final = w_new + self.b * spike
            return spike, V_final, w_final

    class DLPFCLayer(nn.Module):
        """Spiking layer: a linear projection into an AdEx population, followed
        by ``num_recurrent_layers`` further (linear, AdEx) stages, unrolled over
        the sequence axis one step at a time.
        """

        def __init__(self, input_size, output_size, num_recurrent_layers=1, adex_params=None, dropout_prob=0.1):
            super().__init__()
            neuron_kwargs = {} if adex_params is None else adex_params
            self.output_size = output_size
            self.num_recurrent_layers = num_recurrent_layers
            self.projection = nn.Linear(input_size, output_size)
            self.adex0 = DLPFCAdExNeuron(**neuron_kwargs)
            stage_projections = [nn.Linear(output_size, output_size) for _ in range(num_recurrent_layers)]
            self.recurrent_projections = nn.ModuleList(stage_projections)
            stage_neurons = [DLPFCAdExNeuron(**neuron_kwargs) for _ in range(num_recurrent_layers)]
            self.recurrent_neurons = nn.ModuleList(stage_neurons)
            self.dropout = nn.Dropout(p=dropout_prob)

        def forward(self, hidden_states):
            """Run the spiking stack over a 3-D input indexed as (batch, step, feature).

            Returns the per-step output of the last stage, stacked along dim 1.
            """
            batch_size, seq_len, _ = hidden_states.size()
            device = hidden_states.device
            state_shape = (batch_size, self.output_size)

            # Membrane potentials start at each neuron's reset value, adaptation at zero.
            V0 = torch.full(state_shape, self.adex0.V_reset.item(), device=device)
            w0 = torch.zeros(state_shape, device=device)
            V_rec = [torch.full(state_shape, neuron.V_reset.item(), device=device)
                     for neuron in self.recurrent_neurons]
            w_rec = [torch.zeros(state_shape, device=device) for _ in self.recurrent_neurons]

            step_outputs = []
            for t in range(seq_len):
                drive = self.projection(hidden_states[:, t, :])
                spikes, V0, w0 = self.adex0(drive, V0, w0)
                spikes = self.dropout(spikes)
                stages = zip(self.recurrent_projections, self.recurrent_neurons)
                for idx, (stage_proj, stage_neuron) in enumerate(stages):
                    spikes, V_rec[idx], w_rec[idx] = stage_neuron(stage_proj(spikes), V_rec[idx], w_rec[idx])
                    spikes = self.dropout(spikes)
                step_outputs.append(spikes)
            return torch.stack(step_outputs, dim=1)

    class HyperdimensionalMemoryModule(nn.Module):
        """Turns a spike train into one bias vector per batch element.

        The spikes are averaged over dim 1, multiplied by a fixed random
        matrix of shape (input_dim, hdm_dim), and passed through a two-layer MLP.
        """

        def __init__(self, input_dim, hdm_dim, output_dim):
            super().__init__()
            # Registered as a buffer, not a parameter: it moves with the module
            # and is stored in state_dict, but is not returned by parameters().
            random_projection = torch.randn(input_dim, hdm_dim)
            self.register_buffer("proj_matrix", random_projection)
            bottleneck = hdm_dim // 2
            layers = [
                nn.Linear(hdm_dim, bottleneck),
                nn.ReLU(),
                nn.Linear(bottleneck, output_dim),
            ]
            self.mlp = nn.Sequential(*layers)

        def forward(self, spike_train):
            mean_activity = spike_train.mean(dim=1)
            projected = mean_activity @ self.proj_matrix
            return self.mlp(projected)

    class DLPFCTransformer(nn.Module):
        """GPT-2 backbone followed by the spiking layer, the memory module and a
        linear head that maps to vocabulary logits.

        ``hparams`` must provide ``model_name``, ``dlpfc_output_size``,
        ``num_recurrent_layers``, ``adex_params``, ``dropout_prob`` and ``hdm_dim``.
        """

        def __init__(self, hparams):
            super().__init__()
            self.hparams = hparams
            spiking_width = hparams['dlpfc_output_size']

            # Sub-modules are created in a fixed order so that parameter
            # ordering and state_dict keys stay stable.
            self.gpt2 = GPT2Model.from_pretrained(hparams['model_name'])
            self.dlpfc = DLPFCLayer(
                self.gpt2.config.hidden_size,
                spiking_width,
                num_recurrent_layers=hparams['num_recurrent_layers'],
                adex_params=hparams['adex_params'],
                dropout_prob=hparams['dropout_prob'],
            )
            self.memory_module = HyperdimensionalMemoryModule(spiking_width, hparams['hdm_dim'], spiking_width)
            self.dropout = nn.Dropout(p=hparams['dropout_prob'])
            self.layer_norm = nn.LayerNorm(spiking_width)
            self.lm_head = nn.Linear(spiking_width, self.gpt2.config.vocab_size)

        def forward(self, input_ids, attention_mask=None):
            """Return ``(logits, spk_trains)`` for a batch of token ids."""
            backbone_out = self.gpt2(input_ids=input_ids, attention_mask=attention_mask)
            spk_trains = self.dlpfc(backbone_out.last_hidden_state)
            # One bias vector per batch element, broadcast across dim 1.
            bias = self.memory_module(spk_trains).unsqueeze(1)
            fused = self.layer_norm(spk_trains + bias)
            logits = self.lm_head(self.dropout(fused))
            return logits, spk_trains

    # --- Utility Functions Needed for Checkpoint Test ---
    def save_checkpoint(state, filename):
        """Write ``state`` to ``filename`` through a temporary file.

        The state is first saved to ``filename + ".tmp"`` and then moved over
        the destination, so a failed write does not leave a truncated
        checkpoint at ``filename``. Errors are reported with a message rather
        than raised.

        Args:
            state: Object passed to ``torch.save``; its ``'epoch'`` entry, if
                any, is shown in the confirmation message.
            filename: Destination path of the checkpoint.
        """
        tmp_filename = filename + ".tmp"
        try:
            torch.save(state, tmp_filename)
            # os.replace overwrites an existing destination on every platform;
            # os.rename raises on Windows when the target already exists.
            os.replace(tmp_filename, filename)
            print(f"Checkpoint saved to '{filename}' (Epoch {state.get('epoch','N/A')})")
        except Exception as e:
            print(f"Error saving checkpoint: {e}")
            if os.path.exists(tmp_filename):
                try:
                    os.remove(tmp_filename)
                except OSError:
                    pass  # Best-effort cleanup of the temporary file.

    def load_checkpoint(checkpoint_path, model, optimizer, scheduler, device):
        """Restore model, optimizer and scheduler state from ``checkpoint_path``.

        Returns ``(start_epoch, best_val_loss, training_history)``. When the
        file is missing or loading fails, returns
        ``(0, float('inf'), initialize_history())`` instead.

        NOTE: if loading fails after ``model.load_state_dict`` has run, the
        model keeps the checkpoint weights even though "fresh" values are returned.
        """
        if not os.path.exists(checkpoint_path):
            print("No checkpoint found. Starting training from scratch.")
            return 0, float('inf'), initialize_history()

        print(f"Loading checkpoint from '{checkpoint_path}'")
        try:
            # Deserialize onto the CPU first, then move things to `device`.
            checkpoint = torch.load(checkpoint_path, map_location='cpu')
            model.load_state_dict(checkpoint['model_state_dict'])
            model.to(device)
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
            # Make sure every tensor held in the optimizer state sits on `device`.
            for param_state in optimizer.state.values():
                for key, value in param_state.items():
                    if isinstance(value, torch.Tensor):
                        param_state[key] = value.to(device)
            start_epoch = checkpoint['epoch'] + 1
            best_val_loss = checkpoint.get('best_val_loss', float('inf'))
            training_history = checkpoint.get('training_history', initialize_history())
            print(f"Resuming training from epoch {start_epoch}")
            return start_epoch, best_val_loss, training_history
        except Exception as e:
            import traceback
            print(f"Error loading checkpoint: {e}. Starting fresh.")
            traceback.print_exc()
        return 0, float('inf'), initialize_history()

    def initialize_history():
        """Return an empty training history: one new, empty list per tracked metric."""
        metric_names = (
            'epoch',
            'train_loss',
            'train_perplexity',
            'train_l1_loss',
            'val_loss',
            'val_perplexity',
            'val_l1_loss',
        )
        # A comprehension gives every metric its own list (no shared object).
        return {name: [] for name in metric_names}

print("--- Setting up Tests ---")

# Work on a deep copy so the overrides below never modify HPARAMS itself,
# then pin the sizes that keep the test models small.
TEST_HPARAMS = {
    **copy.deepcopy(HPARAMS),
    'batch_size': 2,
    'seq_length': 16,
    'dlpfc_output_size': 8,
    'hdm_dim': 16,
    'num_recurrent_layers': 1,
    'num_epochs': 1,
    'output_dir': os.path.join(tempfile.gettempdir(), "test_dlpfc_output"),
}

# Run on the GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    test_device = torch.device("cuda")
else:
    test_device = torch.device("cpu")
print(f"Testing on device: {test_device}")

# The checkpoint test writes its files into this directory.
os.makedirs(TEST_HPARAMS['output_dir'], exist_ok=True)

# --- Test Functions ---

def test_surrogate_spike():
    """Check the forward output and the backward gradient of ``surrogate_spike``.

    Forward must return a float tensor of 0/1 values with the input's shape.
    Backward must populate a finite gradient of matching shape and dtype on
    the input tensor.

    Raises:
        AssertionError: If any of the checks above fails.
    """
    print("Testing SurrogateSpikeFunction...")
    # Scale first and only then enable gradients. Multiplying a tensor that
    # already requires grad yields a non-leaf result whose .grad stays None,
    # which would silently skip the gradient checks below.
    input_tensor = (torch.randn(5, device=test_device) * 0.5).requires_grad_()
    spikes = surrogate_spike(input_tensor)
    assert spikes.shape == input_tensor.shape, "Forward shape mismatch"
    assert spikes.dtype == torch.float, "Forward output dtype mismatch"
    assert torch.all((spikes == 0) | (spikes == 1)), "Forward output not 0 or 1"
    print("  Forward pass OK.")
    dummy_grad = torch.ones_like(spikes)
    try:
        spikes.backward(dummy_grad)
        print("  Backward pass executed without error.")
    except Exception as e:
        raise AssertionError(f"Backward pass failed with error: {e}") from e
    # The input is a leaf, so a missing gradient is a genuine failure.
    grad = input_tensor.grad
    assert grad is not None, "No gradient reached the leaf input tensor after backward"
    assert grad.shape == input_tensor.shape, f"Backward grad shape mismatch: {grad.shape}"
    assert grad.dtype == input_tensor.dtype, f"Backward grad dtype mismatch: {grad.dtype}"
    assert torch.isfinite(grad).all(), "Backward produced non-finite gradients"
    print("  Gradient shape and type on leaf tensor OK.")
    print("SurrogateSpikeFunction Test PASSED.")

def test_adex_neuron():
    """Run one step of DLPFCAdExNeuron and check output shapes, dtypes and
    that every parameter was moved to the test device."""
    print("Testing DLPFCAdExNeuron...")
    state_shape = (TEST_HPARAMS['batch_size'], TEST_HPARAMS['dlpfc_output_size'])
    neuron = DLPFCAdExNeuron(**TEST_HPARAMS['adex_params']).to(test_device)
    input_current = torch.randn(*state_shape, device=test_device) * 10
    V_init = torch.full(state_shape, neuron.V_reset.item(), device=test_device)
    w_init = torch.zeros(state_shape, device=test_device)
    spike, V_next, w_next = neuron(input_current, V_init, w_init)

    outputs = (("Spike", spike), ("V_next", V_next), ("w_next", w_next))
    for label, tensor in outputs:
        assert tensor.shape == state_shape, f"{label} shape: {tensor.shape}"
    for _, tensor in outputs:
        assert tensor.dtype == torch.float
    print("  Output shapes and dtypes OK.")

    assert len(list(neuron.parameters())) > 0, "No parameters registered"
    print(f"  Expected device type: {test_device.type}, index: {test_device.index}")
    all_on_device = True
    for name, param in neuron.named_parameters():
        print(f"  Param '{name}' device: {param.device}")
        if param.device.type != test_device.type:
            all_on_device = False
            print(f"  !!! Type mismatch for '{name}': {param.device.type} != {test_device.type}")
            break
        if test_device.type != 'cuda':
            continue
        # A CUDA device without an explicit index is treated as index 0.
        expected_index = test_device.index or 0
        actual_index = param.device.index or 0
        if actual_index != expected_index:
            all_on_device = False
            print(f"  !!! Index mismatch for '{name}': {actual_index} != {expected_index}")
            break
    assert all_on_device, "One or more parameters were not moved to the correct device"
    print("  Parameters registered and on correct device.")
    print("DLPFCAdExNeuron Test PASSED.")

def test_dlpfc_layer():
    """Feed random hidden states through DLPFCLayer and check the output
    shape ``(batch, seq_len, dlpfc_output_size)`` and dtype.

    The input width is read from the GPT-2 configuration; if that cannot be
    loaded, 768 is used.

    Raises:
        AssertionError: If the output shape or dtype is not as expected.
    """
    print("Testing DLPFCLayer...")
    batch_size = TEST_HPARAMS['batch_size']
    seq_len = TEST_HPARAMS['seq_length']
    try:
        # Only the configuration is needed here; loading it avoids
        # instantiating the full pretrained model just to read one integer.
        from transformers import GPT2Config
        input_size = GPT2Config.from_pretrained(TEST_HPARAMS['model_name']).hidden_size
    except Exception as e:
        print(f"Warning: Could not load GPT2 config, using default size 768. Error: {e}")
        input_size = 768
    output_size = TEST_HPARAMS['dlpfc_output_size']
    layer = DLPFCLayer(input_size, output_size, TEST_HPARAMS['num_recurrent_layers'],
                        TEST_HPARAMS['adex_params'], TEST_HPARAMS['dropout_prob']).to(test_device)
    layer.eval()  # Disable dropout for the check.
    hidden_states = torch.randn(batch_size, seq_len, input_size, device=test_device)
    with torch.no_grad():
        spk_trains = layer(hidden_states)
    expected_shape = (batch_size, seq_len, output_size)
    assert spk_trains.shape == expected_shape, f"Output shape mismatch: {spk_trains.shape} vs {expected_shape}"
    assert spk_trains.dtype == torch.float, f"Output dtype mismatch: {spk_trains.dtype}"
    print("  Output shape and dtype OK.")
    print("DLPFCLayer Test PASSED.")

def test_memory_module():
    """Pass a random 0/1 spike train through HyperdimensionalMemoryModule and
    check that the result has shape ``(batch, dlpfc_output_size)`` and float dtype."""
    print("Testing HyperdimensionalMemoryModule...")
    batch_size = TEST_HPARAMS['batch_size']
    feature_dim = TEST_HPARAMS['dlpfc_output_size']  # used for both input and output width
    spike_shape = (batch_size, TEST_HPARAMS['seq_length'], feature_dim)

    module = HyperdimensionalMemoryModule(feature_dim, TEST_HPARAMS['hdm_dim'], feature_dim)
    module = module.to(test_device)
    module.eval()

    spike_train = torch.randint(0, 2, spike_shape, dtype=torch.float, device=test_device)
    with torch.no_grad():
        memory_bias = module(spike_train)

    expected_shape = (batch_size, feature_dim)
    assert memory_bias.shape == expected_shape, f"Output shape mismatch: {memory_bias.shape} vs {expected_shape}"
    assert memory_bias.dtype == torch.float, f"Output dtype mismatch: {memory_bias.dtype}"
    print("  Output shape and dtype OK.")
    print("HyperdimensionalMemoryModule Test PASSED.")

def test_dlpfc_transformer():
    """Build the full model, run one forward pass on random token ids, and
    check output shapes/dtypes plus compatibility with the loss computation."""
    print("Testing DLPFCTransformer (Full Model Forward Pass)...")
    batch_size = TEST_HPARAMS['batch_size']
    seq_len = TEST_HPARAMS['seq_length']
    try:
        model = DLPFCTransformer(TEST_HPARAMS).to(test_device)
        model.eval()
        vocab_size = model.gpt2.config.vocab_size
    except Exception as e:
        raise AssertionError(f"Failed to instantiate DLPFCTransformer for test: {e}")

    token_shape = (batch_size, seq_len)
    input_ids = torch.randint(0, vocab_size, token_shape, dtype=torch.long, device=test_device)
    attention_mask = torch.ones(token_shape, dtype=torch.long, device=test_device)
    with torch.no_grad():
        logits, spk_trains = model(input_ids, attention_mask=attention_mask)

    expected_logits_shape = token_shape + (vocab_size,)
    expected_spk_trains_shape = token_shape + (TEST_HPARAMS['dlpfc_output_size'],)
    assert logits.shape == expected_logits_shape, f"Logits shape mismatch: {logits.shape} vs {expected_logits_shape}"
    assert spk_trains.shape == expected_spk_trains_shape, f"Spike trains shape mismatch: {spk_trains.shape} vs {expected_spk_trains_shape}"
    assert logits.dtype == torch.float
    assert spk_trains.dtype == torch.float
    print("  Output shapes and dtypes OK.")

    try:
        # Next-token objective: position i is scored against token i + 1.
        next_token_logits = logits[..., :-1, :].contiguous()
        next_token_labels = input_ids[..., 1:].contiguous()
        flat_logits = next_token_logits.view(-1, next_token_logits.size(-1))
        loss_xent = nn.CrossEntropyLoss()(flat_logits, next_token_labels.view(-1))
        loss_l1 = TEST_HPARAMS['l1_lambda'] * torch.mean(torch.abs(spk_trains))
        # The value is not inspected; only that the sum can be formed.
        total_loss = loss_xent + loss_l1
        print("  Loss calculation structure compatible with output shapes.")
    except Exception as e:
        raise AssertionError(f"Failed during simulated loss calculation: {e}")
    print("DLPFCTransformer Test PASSED.")

def test_data_pipeline():
    """Tokenize a few dummy sentences, batch them with a DataLoader, and check
    the keys, shapes and dtypes of the first batch."""
    print("Testing Data Pipeline (Tokenization and DataLoader)...")
    dummy_texts = ["Sentence one.", "Sentence two is longer.", "Short.", "=Title="]
    # Drop entries that are empty after stripping whitespace.
    dummy_texts_filtered = [text for text in dummy_texts if text.strip()]

    class DummyTextDataset(Dataset):
        """Minimal in-memory dataset yielding ``{"text": ...}`` items."""

        def __init__(self, texts):
            self.texts = texts

        def __len__(self):
            return len(self.texts)

        def __getitem__(self, idx):
            return {"text": self.texts[idx]}

    dummy_dataset = DummyTextDataset(dummy_texts_filtered)
    try:
        tokenizer = GPT2Tokenizer.from_pretrained(TEST_HPARAMS['model_name'])
    except Exception as e:
        raise AssertionError(f"Failed to load tokenizer for test: {e}")
    # Fall back to the EOS token when the tokenizer defines no pad token.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    def tokenize_function_test(examples):
        return tokenizer(
            examples["text"],
            padding="max_length",
            truncation=True,
            max_length=TEST_HPARAMS['seq_length'],
        )

    tokenized_data = []
    for text in dummy_dataset.texts:
        encoded = tokenize_function_test({"text": text})
        for field in ('input_ids', 'attention_mask'):
            encoded[field] = torch.tensor(encoded[field], dtype=torch.long)
        tokenized_data.append(encoded)

    test_loader = DataLoader(tokenized_data, batch_size=TEST_HPARAMS['batch_size'])
    try:
        batch = next(iter(test_loader))
    except Exception as e:
        raise AssertionError(f"Failed to get batch from DataLoader: {e}")

    assert 'input_ids' in batch and 'attention_mask' in batch
    input_ids = batch['input_ids']
    attention_mask = batch['attention_mask']
    expected_batch_size = min(TEST_HPARAMS['batch_size'], len(tokenized_data))
    expected_shape = (expected_batch_size, TEST_HPARAMS['seq_length'])
    assert input_ids.shape == expected_shape, f"Batch input_ids shape: {input_ids.shape} vs {expected_shape}"
    assert attention_mask.shape == expected_shape, f"Batch attention_mask shape: {attention_mask.shape} vs {expected_shape}"
    assert input_ids.dtype == torch.long and attention_mask.dtype == torch.long
    print("  Tokenization and DataLoader batch structure OK.")
    print("Data Pipeline Test PASSED.")

def test_checkpointing():
    """Round-trip a checkpoint through ``save_checkpoint``/``load_checkpoint``.

    A model, optimizer and scheduler are saved together with bookkeeping
    values, loaded into freshly built objects, and compared: the resume epoch,
    best validation loss, history, every ``state_dict`` entry of the model,
    the optimizer's param-group count and the scheduler's ``last_epoch``.

    Raises:
        AssertionError: If saving, loading or any comparison fails.
    """
    print("Testing Checkpointing (Save/Load)...")
    test_dir = TEST_HPARAMS['output_dir']
    checkpoint_path = os.path.join(test_dir, TEST_HPARAMS['checkpoint_filename'])
    os.makedirs(test_dir, exist_ok=True)
    # save_checkpoint reports errors instead of raising, so a checkpoint left
    # over from an earlier run would make a failed save look successful.
    if os.path.exists(checkpoint_path):
        os.remove(checkpoint_path)
    try:
        model_orig = DLPFCTransformer(TEST_HPARAMS).to(test_device)
    except Exception as e:
        raise AssertionError(f"Failed to instantiate model for checkpoint test: {e}") from e
    optimizer_orig = AdamW(model_orig.parameters(), lr=TEST_HPARAMS['learning_rate'])
    scheduler_orig = get_linear_schedule_with_warmup(optimizer_orig, num_warmup_steps=10, num_training_steps=100)
    epoch_orig = 3
    best_val_loss_orig = 0.123
    history_orig = {
        'epoch': [1, 2, 3],
        'val_loss': [0.5, 0.3, 0.123],
        'train_loss': [],
        'train_perplexity': [],
        'train_l1_loss': [],
        'val_perplexity': [],
        'val_l1_loss': []
    }
    # Advance the scheduler twice so its saved state differs from a fresh one.
    for _ in range(2):
        optimizer_orig.step()
        scheduler_orig.step()
    state_orig = {
        'epoch': epoch_orig,
        'model_state_dict': model_orig.state_dict(),
        'optimizer_state_dict': optimizer_orig.state_dict(),
        'scheduler_state_dict': scheduler_orig.state_dict(),
        'best_val_loss': best_val_loss_orig,
        'training_history': history_orig,
        'hparams': TEST_HPARAMS
    }
    try:
        save_checkpoint(state_orig, checkpoint_path)
        assert os.path.exists(checkpoint_path), "Checkpoint file not created"
        print("  Save checkpoint OK.")
    except Exception as e:
        raise AssertionError(f"Failed to save checkpoint: {e}") from e
    model_new = DLPFCTransformer(TEST_HPARAMS).to(test_device)
    optimizer_new = AdamW(model_new.parameters(), lr=TEST_HPARAMS['learning_rate'])
    scheduler_new = get_linear_schedule_with_warmup(optimizer_new, num_warmup_steps=10, num_training_steps=100)
    try:
        start_epoch, best_val_loss_loaded, history_loaded = load_checkpoint(checkpoint_path, model_new, optimizer_new, scheduler_new, test_device)
        print("  Load checkpoint function ran without error.")
    except Exception as e:
        raise AssertionError(f"Failed to load checkpoint: {e}") from e
    assert start_epoch == epoch_orig + 1, f"Loaded start_epoch mismatch: {start_epoch} vs {epoch_orig + 1}"
    assert best_val_loss_loaded == best_val_loss_orig, f"Loaded best_val_loss mismatch: {best_val_loss_loaded} vs {best_val_loss_orig}"
    assert history_loaded == history_orig, "Loaded training_history mismatch"
    print("  Loaded epoch, best_val_loss, history OK.")
    # Compare every state_dict entry. Both models load the same pretrained
    # backbone, so checking only a backbone tensor would pass even if nothing
    # had been restored; the randomly initialised parts are what prove it.
    orig_state = model_orig.state_dict()
    new_state = model_new.state_dict()
    assert len(orig_state) > 0 and orig_state.keys() == new_state.keys(), "Model param list length mismatch or empty model"
    for name, tensor in orig_state.items():
        assert torch.equal(tensor, new_state[name]), f"Model state mismatch ({name})"
    print("  Model state loaded OK (checked params).")
    assert len(optimizer_new.param_groups) == len(optimizer_orig.param_groups), "Optimizer param_groups length mismatch"
    assert scheduler_new.state_dict()['last_epoch'] == scheduler_orig.state_dict()['last_epoch'], "Scheduler state mismatch (last_epoch)"
    print("  Optimizer/Scheduler states loaded OK.")
    print("Checkpointing Test PASSED.")

# --- Test Runner ---
def run_all_tests():
    """Run every test function in order, print a per-test result and a final
    summary. Failures are counted and reported, never re-raised."""
    print("\n--- Running All Tests ---")
    suite = (
        test_surrogate_spike,
        test_adex_neuron,
        test_dlpfc_layer,
        test_memory_module,
        test_dlpfc_transformer,
        test_data_pipeline,
        test_checkpointing,
    )
    # NOTE(review): the setup at the top of this cell defines these names when
    # they are missing, so this warning is not expected to fire after a normal
    # top-to-bottom run of the cell.
    try:
        HPARAMS
        DLPFCAdExNeuron
    except NameError:
        print("\nWARNING: Running tests using fallback definitions.\n")

    tests_passed = tests_failed = 0
    for test_func in suite:
        try:
            test_func()
        except AssertionError as e:
            # A check inside the test did not hold.
            print(f"!!! Test Failed: {test_func.__name__} !!!\n  Error: {e}")
            tests_failed += 1
        except Exception as e:
            # Anything else is unexpected, so include the traceback.
            import traceback
            print(f"!!! Test Errored: {test_func.__name__} !!!\n  Unexpected Error: {e}")
            traceback.print_exc()
            tests_failed += 1
        else:
            tests_passed += 1
        print("-" * 30)

    print("\n--- Test Summary ---")
    print(f"Tests Passed: {tests_passed}")
    print(f"Tests Failed: {tests_failed}")
    print("--- End of Tests ---")
    # Clean up test directory - uncomment if desired after successful runs
    # try:
    #     shutil.rmtree(TEST_HPARAMS['output_dir'], ignore_errors=True)
    #     print(f"Cleaned up test directory: {TEST_HPARAMS['output_dir']}")
    # except Exception as e:
    #     print(f"Could not clean up test directory: {e}")
    if tests_failed:
        print("Some tests failed. Please review the errors above.")
    else:
        print("All tests passed successfully!")

# --- Execute Tests ---
run_all_tests()


Run Training

In [None]:
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import GPT2Tokenizer, GPT2Model, get_linear_schedule_with_warmup
from datasets import load_dataset
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import os
import json
import math  # For perplexity calculation
import time  # For timing epochs
from torchinfo import summary  # For model summary
import shutil  # For potentially copying best model checkpoint
import copy  # For deepcopying HPARAMS
import traceback  # For detailed error printing

# --------------------------------------------------------------------------------
# HYPERPARAMETERS (Defaults) -
# --------------------------------------------------------------------------------
HPARAMS = {
    # --- Backbone and optimisation ---
    'model_name': "gpt2",         # GPT-2 base
    'learning_rate': 5e-5,
    'weight_decay': 0.01,
    'l1_lambda': 1e-5,            # Penalty on SNN spike activity
    'num_epochs': 3,              # Total number of epochs to train for
    'batch_size': 8,              # Adjust based on GPU memory (e.g., 4, 8, 16)
    'seq_length': 128,            # Max seq length

    # --- Spiking "DLPFC" layer ---
    'num_recurrent_layers': 1,    # Recurrent spiking layers in "DLPFC"
    'dlpfc_output_size': 512,     # Spiking neurons output dimension
    'adex_params': {
        'tau_m': 20.0,
        'tau_w': 144.0,
        'a': 4.0,
        'b': 0.08,
        'V_th': -50.0,
        'V_reset': -70.0,
        'V_rest': -65.0,
        'delta_T': 2.0,
    },
    'dropout_prob': 0.2,
    'warmup_steps': 500,

    # --- Memory module ---
    'hdm_dim': 1024,              # Dimension of the high-dimensional space
    'hdm_hidden_dim': 512,        # Not directly used in current simple MLP

    # --- Logging and output files ---
    'log_interval': 100,          # Log training progress every N steps
    # IMPORTANT: for persistence, mount Google Drive and point this path there.
    'output_dir': "dlpfc_spiking_gpt2_output",
    'checkpoint_filename': "checkpoint.pth",          # Resume checkpoint file
    'best_model_filename': "best_model_state.pth",    # Best model state file
    'final_model_filename': "final_model_state.pth",  # Final model state file
    'history_filename': "training_history.json",      # Training history file
    'hparams_filename': "hparams.json",               # Hyperparameters file

    'seed': 42,  # Random seed for reproducibility
}

# --------------------------------------------------------------------------------
# Utility Functions
# --------------------------------------------------------------------------------
def set_seed(seed_value):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and, if present, CUDA).

    Args:
        seed_value: Integer seed applied to every generator.
    """
    # Imported locally so this helper does not depend on a top-level import.
    import random

    # Without this, code relying on the built-in RNG stays non-reproducible.
    random.seed(seed_value)
    np.random.seed(seed_value)
    torch.manual_seed(seed_value)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed_value)
    print(f"Set random seed to {seed_value}")

def save_checkpoint(state, filename):
    """Save ``state`` to ``filename`` via a temporary file and an atomic replace.

    The state is written to ``filename + ".tmp"`` first and then moved over
    the destination, so an interrupted write does not leave a truncated
    checkpoint at ``filename``. Errors are printed, not raised.

    Args:
        state: Object passed to ``torch.save``.
        filename: Destination path of the checkpoint.
    """
    tmp_filename = filename + ".tmp"
    try:
        torch.save(state, tmp_filename)
        # os.replace overwrites an existing destination on every platform;
        # os.rename raises on Windows when the target already exists.
        os.replace(tmp_filename, filename)
    except Exception as e:
        print(f"Error saving checkpoint '{filename}': {e}")
        if os.path.exists(tmp_filename):
            try:
                os.remove(tmp_filename)  # Clean up temp file on error
            except OSError:
                pass  # Ignore error if removal fails

def load_checkpoint(checkpoint_path, model, optimizer, scheduler, device):
    """Restore training state from ``checkpoint_path`` if it exists.

    The checkpoint is deserialized onto the CPU, loaded into ``model``,
    ``optimizer`` and ``scheduler``, and the model and optimizer tensors are
    then moved to ``device``.

    Returns:
        ``(start_epoch, best_val_loss, training_history)``. When the file is
        missing or loading fails, ``(0, float('inf'), initialize_history())``.

    NOTE: if loading fails after ``model.load_state_dict`` has run, the model
    keeps the checkpoint weights even though "fresh" values are returned.
    """
    if not os.path.exists(checkpoint_path):
        print(f"No checkpoint found at '{checkpoint_path}'. Starting training from scratch.")
        return 0, float('inf'), initialize_history()

    print(f"Loading checkpoint from '{checkpoint_path}'")
    try:
        # Load first onto CPU to avoid GPU memory issues
        checkpoint = torch.load(checkpoint_path, map_location='cpu')

        # The state dict is loaded before the model is moved to the target device.
        model.load_state_dict(checkpoint['model_state_dict'])
        model.to(device)

        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

        # Make sure every tensor held in the optimizer state sits on `device`.
        for param_state in optimizer.state.values():
            for key, value in param_state.items():
                if isinstance(value, torch.Tensor):
                    param_state[key] = value.to(device)

        start_epoch = checkpoint['epoch'] + 1  # Continue with the next epoch
        best_val_loss = checkpoint.get('best_val_loss', float('inf'))
        training_history = checkpoint.get('training_history', initialize_history())

        print(f"Resuming training from epoch {start_epoch}")
        return start_epoch, best_val_loss, training_history
    except FileNotFoundError:
        print(f"Checkpoint file not found at '{checkpoint_path}'. Starting fresh.")
    except KeyError as e:
        print(f"Error loading state from checkpoint: Missing key {e}. Checkpoint might be incompatible. Starting fresh.")
    except Exception as e:
        print(f"Error loading checkpoint: {e}. Starting fresh.")
        traceback.print_exc()
    # Every failure path ends here with a fresh training state.
    return 0, float('inf'), initialize_history()

def initialize_history():
    """Create an empty training-history record.

    Returns:
        dict: One entry per tracked series (the epoch number plus the
        train/validation loss, perplexity and L1 loss), each mapped to its
        own fresh empty list.
    """
    tracked_series = (
        'epoch',
        'train_loss',
        'train_perplexity',
        'train_l1_loss',
        'val_loss',
        'val_perplexity',
        'val_l1_loss',
    )
    # A comprehension (not dict.fromkeys) so every key owns a distinct list.
    return {series_name: [] for series_name in tracked_series}

# --------------------------------------------------------------------------------
# 1) Surrogate Spike Function
# --------------------------------------------------------------------------------
class SurrogateSpikeFunction(torch.autograd.Function):
    """Hard-threshold spike with a Gaussian surrogate gradient.

    Forward: emits 1.0 where the input is strictly greater than zero and 0.0
    everywhere else (always as a float tensor).
    Backward: scales the incoming gradient by exp(-x^2 / 2) / sqrt(2 * pi)
    evaluated at the saved forward input, in place of the step function's
    true derivative.
    """

    @staticmethod
    def forward(ctx, input_tensor):
        # Keep the raw input around; backward evaluates the surrogate on it.
        ctx.save_for_backward(input_tensor)
        fired = input_tensor > 0
        return fired.float()

    @staticmethod
    def backward(ctx, grad_output):
        saved_input = ctx.saved_tensors[0]
        normalizer = math.sqrt(2 * math.pi)
        # Gaussian bump centred on the threshold (input == 0).
        surrogate_grad = torch.exp((saved_input * saved_input) / -2.0) / normalizer
        return grad_output * surrogate_grad


# Functional alias used by the neuron modules.
surrogate_spike = SurrogateSpikeFunction.apply

# --------------------------------------------------------------------------------
# 2) DLPFC AdEx Neuron
# --------------------------------------------------------------------------------
class DLPFCAdExNeuron(nn.Module):
    """Minimal adaptive-exponential (AdEx) spiking neuron.

    Holds the neuron constants as scalar parameters. ``tau_m``, ``tau_w``,
    ``a``, ``b`` and ``delta_T`` are trainable; ``V_th``, ``V_reset`` and
    ``V_rest`` are stored with ``requires_grad=False``.

    Keyword Args:
        tau_m, tau_w, a, b, V_th, V_reset, V_rest, delta_T: Optional
            overrides for the defaults below. Any real number is accepted;
            values are converted to ``float`` so integer settings (for
            example ``tau_m=20``) still yield floating-point parameters.
    """

    # Upper bound of the exponential spike-initiation term.
    _EXP_TERM_MAX = 50.0

    def __init__(self, **adex_params):
        super().__init__()

        def scalar(name, default):
            # float() guards against integer hyperparameters, which would
            # otherwise create an int64 tensor that cannot require grad.
            return torch.tensor(float(adex_params.get(name, default)))

        self.tau_m = nn.Parameter(scalar('tau_m', 20.0))
        self.tau_w = nn.Parameter(scalar('tau_w', 144.0))
        self.a = nn.Parameter(scalar('a', 4.0))
        self.b = nn.Parameter(scalar('b', 0.08))
        self.V_th = nn.Parameter(scalar('V_th', -50.0), requires_grad=False)
        self.V_reset = nn.Parameter(scalar('V_reset', -70.0), requires_grad=False)
        self.V_rest = nn.Parameter(scalar('V_rest', -65.0), requires_grad=False)
        self.delta_T = nn.Parameter(scalar('delta_T', 2.0))

    def forward(self, input_current, V, w):
        """Advance the neuron state by one time step.

        Args:
            input_current: Input drive for this step.
            V: Current membrane potential.
            w: Current adaptation variable.

        Returns:
            tuple: ``(spike, V_final, w_final)`` where ``spike`` comes from
            ``surrogate_spike``, ``V_final`` is ``V_reset`` wherever a spike
            occurred, and ``w_final`` is incremented by ``b`` on spikes.
        """
        dt = 1.0  # Assumed time step
        # Stability clamp, applied to the exponent rather than to exp()'s
        # result: clamping afterwards still lets exp() overflow to inf, and
        # the backward pass then evaluates 0 * inf = NaN.
        exponent = ((V - self.V_th) / self.delta_T).clamp(max=math.log(self._EXP_TERM_MAX))
        exp_term = torch.exp(exponent)
        dV = (dt / self.tau_m) * (-(V - self.V_rest) + self.delta_T * exp_term - w + input_current)
        V_new = V + dV
        dw = (dt / self.tau_w) * (self.a * (V - self.V_rest) - w)
        w_new = w + dw
        spike = surrogate_spike(V_new - self.V_th)
        # Reset the membrane potential of every unit that fired.
        V_final = torch.where(spike > 0.5, self.V_reset, V_new)
        w_final = w_new + self.b * spike
        return spike, V_final, w_final

# --------------------------------------------------------------------------------
# 3) DLPFCLayer
# --------------------------------------------------------------------------------
class DLPFCLayer(nn.Module):
    """Feeds hidden states, one time step at a time, through AdEx neurons.

    Each step is linearly projected into an input AdEx population and then
    passed through ``num_recurrent_layers`` further (linear projection, AdEx
    population) stages. Neuron state is carried across time steps within a
    single forward call and re-initialised on every call.

    Args:
        input_size: Feature size of the incoming hidden states.
        output_size: Number of units in every AdEx population.
        num_recurrent_layers: How many stacked stages follow the input one.
        adex_params: Keyword arguments forwarded to every
            ``DLPFCAdExNeuron``; ``None`` means defaults.
        dropout_prob: Dropout probability applied to each stage's spikes.
    """

    def __init__(self, input_size, output_size, num_recurrent_layers=1, adex_params=None, dropout_prob=0.1):
        super().__init__()
        neuron_kwargs = {} if adex_params is None else adex_params
        self.output_size = output_size
        self.num_recurrent_layers = num_recurrent_layers
        self.projection = nn.Linear(input_size, output_size)
        self.adex0 = DLPFCAdExNeuron(**neuron_kwargs)
        stage_projections = [nn.Linear(output_size, output_size) for _ in range(num_recurrent_layers)]
        self.recurrent_projections = nn.ModuleList(stage_projections)
        stage_neurons = [DLPFCAdExNeuron(**neuron_kwargs) for _ in range(num_recurrent_layers)]
        self.recurrent_neurons = nn.ModuleList(stage_neurons)
        self.dropout = nn.Dropout(p=dropout_prob)

    def forward(self, hidden_states):
        """Return the spike train for ``hidden_states``.

        Args:
            hidden_states: Tensor unpacked as (batch, seq_len, features).

        Returns:
            Tensor of shape [batch, seq_len, output_size] holding the
            (dropout-processed) spikes of the last stage at every step.
        """
        batch_size, seq_len, _ = hidden_states.size()
        device = hidden_states.device
        state_shape = (batch_size, self.output_size)

        # Every population starts at its reset potential with zero adaptation.
        V0 = torch.full(state_shape, self.adex0.V_reset.item(), device=device)
        w0 = torch.zeros(state_shape, device=device)
        V_rec, w_rec = [], []
        for neuron in self.recurrent_neurons:
            V_rec.append(torch.full(state_shape, neuron.V_reset.item(), device=device))
        for _ in self.recurrent_neurons:
            w_rec.append(torch.zeros(state_shape, device=device))

        per_step_spikes = []
        for t in range(seq_len):
            # Input stage.
            drive = self.projection(hidden_states[:, t, :])
            spk0, V0, w0 = self.adex0(drive, V0, w0)
            stage_output = self.dropout(spk0)
            # Stacked stages; each consumes the previous stage's spikes.
            for i in range(self.num_recurrent_layers):
                stage_drive = self.recurrent_projections[i](stage_output)
                stage_spikes, V_rec[i], w_rec[i] = self.recurrent_neurons[i](stage_drive, V_rec[i], w_rec[i])
                stage_output = self.dropout(stage_spikes)
            per_step_spikes.append(stage_output)
        # Stack along a new time axis -> [batch, seq_len, output_size]
        return torch.stack(per_step_spikes, dim=1)

# --------------------------------------------------------------------------------
# 4) Hyperdimensional Memory Module
# --------------------------------------------------------------------------------
class HyperdimensionalMemoryModule(nn.Module):
    """Summarises a spike train as one memory-bias vector per batch item.

    The spike train is averaged over its second (time) axis, multiplied by a
    random projection matrix of shape (input_dim, hdm_dim), and decoded by a
    two-layer MLP down to ``output_dim``. The projection matrix is registered
    as a buffer rather than a parameter.

    Args:
        input_dim: Feature size of the spike train.
        hdm_dim: Width of the random projection.
        output_dim: Size of the returned bias vector.
    """

    def __init__(self, input_dim, hdm_dim, output_dim):
        super().__init__()
        # Drawn before the MLP is built so seeded initialisation is unchanged.
        random_projection = torch.randn(input_dim, hdm_dim)
        self.register_buffer("proj_matrix", random_projection)
        bottleneck_dim = hdm_dim // 2
        mlp_layers = [
            nn.Linear(hdm_dim, bottleneck_dim),
            nn.ReLU(),
            nn.Linear(bottleneck_dim, output_dim),
        ]
        self.mlp = nn.Sequential(*mlp_layers)

    def forward(self, spike_train):
        """Map ``spike_train`` [batch, time, input_dim] to [batch, output_dim]."""
        time_averaged = spike_train.mean(dim=1)          # [batch, input_dim]
        projected = time_averaged @ self.proj_matrix     # [batch, hdm_dim]
        return self.mlp(projected)                       # [batch, output_dim]

# --------------------------------------------------------------------------------
# 5) DLPFCTransformer
# --------------------------------------------------------------------------------
class DLPFCTransformer(nn.Module):
    """GPT-2 backbone followed by the DLPFC spiking layer and memory module.

    Pipeline: GPT-2 last hidden state -> ``DLPFCLayer`` spike trains ->
    ``HyperdimensionalMemoryModule`` bias (added to every time step) ->
    LayerNorm -> dropout -> linear head over the GPT-2 vocabulary.

    NOTE(review): the memory bias is pooled over every time step of the spike
    train and then added to all positions, so the logits at a given position
    also depend on later tokens and on padded positions.

    Args:
        hparams: Mapping that must provide ``model_name``,
            ``dlpfc_output_size``, ``num_recurrent_layers``, ``adex_params``,
            ``dropout_prob`` and ``hdm_dim``.
    """

    def __init__(self, hparams):
        super().__init__()
        self.hparams = hparams
        self.gpt2 = GPT2Model.from_pretrained(hparams['model_name'])
        backbone_config = self.gpt2.config
        spike_dim = hparams['dlpfc_output_size']
        self.dlpfc = DLPFCLayer(
            input_size=backbone_config.hidden_size,
            output_size=spike_dim,
            num_recurrent_layers=hparams['num_recurrent_layers'],
            adex_params=hparams['adex_params'],
            dropout_prob=hparams['dropout_prob'],
        )
        # The bias has the same width as the spikes so the two can be summed.
        self.memory_module = HyperdimensionalMemoryModule(
            input_dim=spike_dim,
            hdm_dim=hparams['hdm_dim'],
            output_dim=spike_dim,
        )
        self.dropout = nn.Dropout(p=hparams['dropout_prob'])
        self.layer_norm = nn.LayerNorm(spike_dim)
        self.lm_head = nn.Linear(spike_dim, backbone_config.vocab_size)

    def forward(self, input_ids, attention_mask=None):
        """Return ``(logits, spk_trains)`` for a batch of token ids.

        ``logits`` is [batch, seq_len, vocab_size]; ``spk_trains`` is the raw
        DLPFC output [batch, seq_len, dlpfc_output_size].
        """
        backbone_out = self.gpt2(input_ids=input_ids, attention_mask=attention_mask)
        spk_trains = self.dlpfc(backbone_out.last_hidden_state)
        # [batch, dim] -> [batch, 1, dim] so it broadcasts over the time axis.
        sequence_bias = self.memory_module(spk_trains)[:, None, :]
        normalized = self.layer_norm(spk_trains + sequence_bias)
        logits = self.lm_head(self.dropout(normalized))
        return logits, spk_trains

# --------------------------------------------------------------------------------
# 6) Training Function (Modified for Checkpointing)
# --------------------------------------------------------------------------------
def _safe_perplexity(mean_loss):
    """Return ``exp(mean_loss)``, or ``inf`` when it cannot be represented."""
    try:
        return math.exp(mean_loss)
    except (OverflowError, ValueError):
        return float('inf')


def _finite_or_none(value):
    """Return ``value`` when it is a finite number, otherwise ``None``."""
    return value if value is not None and math.isfinite(value) else None


def _masked_lm_losses(model, input_ids, attention_mask, criterion, l1_lambda):
    """Run one forward pass and return ``(cross_entropy, l1_penalty)``.

    Targets whose ``attention_mask`` entry is 0 are replaced by the
    criterion's ``ignore_index`` so padding does not contribute to the loss.
    Returns ``None`` (without running the model) when the batch contains no
    target token at all, because the mean over zero targets would be NaN.
    """
    target_mask = attention_mask[..., 1:] != 0
    if not bool(target_mask.any()):
        return None
    logits, spk_trains = model(input_ids, attention_mask=attention_mask)
    # Next-token prediction: position t is scored against token t + 1.
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = input_ids[..., 1:].masked_fill(~target_mask, criterion.ignore_index).contiguous()
    loss_xent = criterion(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
    loss_l1 = l1_lambda * torch.mean(torch.abs(spk_trains))  # L1 spike penalty
    return loss_xent, loss_l1


def _run_validation(model, val_loader, criterion, device, l1_lambda, description):
    """Evaluate ``model`` and return ``(loss_sum, l1_sum, completed_batches)``."""
    model.eval()  # Set model to evaluation mode
    val_loss, val_l1, val_steps = 0.0, 0.0, 0
    with torch.no_grad():
        for batch in tqdm(val_loader, desc=description, leave=False):
            try:
                input_ids = batch['input_ids'].to(device, non_blocking=True)
                attention_mask = batch['attention_mask'].to(device, non_blocking=True)
                losses = _masked_lm_losses(model, input_ids, attention_mask, criterion, l1_lambda)
                if losses is None:
                    continue  # Nothing to score in an all-padding batch
                batch_loss_xent, batch_loss_l1 = losses
                val_loss += batch_loss_xent.item()
                val_l1 += batch_loss_l1.item()
                val_steps += 1
            except Exception as e:
                print(f"\nError during validation step: {e}")
                continue
    return val_loss, val_l1, val_steps


def _record_epoch(training_history, epoch_num, metrics):
    """Store ``metrics`` for ``epoch_num``; return True if an entry was overwritten."""
    epochs = training_history['epoch']
    if epoch_num in epochs:
        idx = epochs.index(epoch_num)
        for key, value in metrics.items():
            training_history[key][idx] = value
        return True
    epochs.append(epoch_num)
    for key, value in metrics.items():
        training_history[key].append(value)
    return False


def _plot_training_curves(training_history, output_dir):
    """Save and show the loss/perplexity and L1 curves found in the history."""
    try:
        epochs = training_history.get('epoch', [])

        def series(key):
            # Each curve keeps its own epochs, so a missing value in one
            # series can never shift or truncate another series.
            points = [(e, v) for e, v in zip(epochs, training_history.get(key, []))
                      if v is not None and math.isfinite(v)]
            return [e for e, _ in points], [v for _, v in points]

        curves = {key: series(key) for key in (
            'train_loss', 'val_loss', 'train_perplexity', 'val_perplexity',
            'train_l1_loss', 'val_l1_loss',
        )}
        if not any(xs for xs, _ in curves.values()):
            print("No valid training history found to plot.")
            return

        fig, axs = plt.subplots(1, 2, figsize=(16, 5))
        axs[0].plot(curves['train_loss'][0], curves['train_loss'][1], 'o-', label="Train Loss")
        axs[0].plot(curves['val_loss'][0], curves['val_loss'][1], 'x-', label="Val Loss")
        axs[0].set_xlabel("Epoch")
        axs[0].set_ylabel("Loss")
        axs[0].set_title("Loss")
        axs[0].legend()
        axs[0].grid(True)

        axs[1].plot(curves['train_perplexity'][0], curves['train_perplexity'][1], 'o-', label="Train PPL")
        axs[1].plot(curves['val_perplexity'][0], curves['val_perplexity'][1], 'x-', label="Val PPL")
        axs[1].set_xlabel("Epoch")
        axs[1].set_ylabel("Perplexity")
        axs[1].set_title("Perplexity")
        axs[1].legend()
        axs[1].grid(True)
        axs[1].set_yscale('log')
        plt.tight_layout()
        plot_path = os.path.join(output_dir, "loss_perplexity_curves.png")
        plt.savefig(plot_path)
        print(f"Loss/perplexity plot saved to {plot_path}")
        plt.show()

        plt.figure(figsize=(8, 5))
        plt.plot(curves['train_l1_loss'][0], curves['train_l1_loss'][1], 'o-', label="Train L1")
        plt.plot(curves['val_l1_loss'][0], curves['val_l1_loss'][1], 'x-', label="Val L1")
        plt.xlabel("Epoch")
        plt.ylabel("L1 Loss")
        plt.title("Spike L1 Regularization")
        plt.legend()
        plt.grid(True)
        plt.tight_layout()
        l1_plot_path = os.path.join(output_dir, "l1_loss_curve.png")
        plt.savefig(l1_plot_path)
        print(f"L1 loss plot saved to {l1_plot_path}")
        plt.show()
    except Exception as plot_err:
        print(f"Error generating plots: {plot_err}")


def train_model(model, train_loader, val_loader, optimizer, scheduler, device, hparams, start_epoch, best_val_loss, training_history):
    """Train the model, handling checkpoints, logging and plotting.

    For every epoch from ``start_epoch`` up to ``hparams['num_epochs']`` the
    function trains on ``train_loader``, evaluates on ``val_loader``, records
    the metrics in ``training_history`` (mutated in place), saves the best
    model weights, a resumable checkpoint via ``save_checkpoint`` and a JSON
    copy of the history. Plots are produced once training finishes.

    Padded target positions (``attention_mask == 0``) are excluded from the
    cross-entropy, and batches whose loss is not finite are skipped instead
    of being back-propagated.

    Args:
        model: Module returning ``(logits, spk_trains)`` when called with
            ``input_ids`` and ``attention_mask``.
        train_loader: Iterable of dict batches with ``input_ids`` and
            ``attention_mask`` tensors; must support ``len()``.
        val_loader: Iterable of validation batches of the same form.
        optimizer: Optimizer stepped once per successful batch.
        scheduler: LR scheduler stepped once per successful batch.
        device: Device that batches are moved to.
        hparams: Reads ``log_interval``, ``output_dir``,
            ``checkpoint_filename``, ``best_model_filename``,
            ``history_filename``, ``num_epochs`` and ``l1_lambda``.
        start_epoch: 0-indexed epoch to start from.
        best_val_loss: Best validation loss seen so far.
        training_history: History dict as built by ``initialize_history``.

    Returns:
        dict: The updated ``training_history``.
    """
    criterion = nn.CrossEntropyLoss()
    log_interval = hparams['log_interval']
    output_dir = hparams['output_dir']
    checkpoint_path = os.path.join(output_dir, hparams['checkpoint_filename'])
    best_model_path = os.path.join(output_dir, hparams['best_model_filename'])
    history_path = os.path.join(output_dir, hparams['history_filename'])
    num_epochs = hparams['num_epochs']

    print(f"\n--- Starting Training (Epochs {start_epoch+1} to {num_epochs}) ---")
    total_start_time = time.time()

    if start_epoch >= num_epochs:
        print(f"Start epoch ({start_epoch}) is >= total epochs ({num_epochs}). Training already completed.")
        return training_history  # Return history without further training

    l1_lambda = hparams['l1_lambda']

    for epoch in range(start_epoch, num_epochs):
        current_epoch_num = epoch + 1
        epoch_start_time = time.time()
        model.train()  # Set model to training mode
        running_loss, running_l1, steps = 0.0, 0.0, 0
        skipped_batches = 0
        last_log_time = time.time()

        batch_iterator = tqdm(train_loader, desc=f"Epoch {current_epoch_num}/{num_epochs} Training", leave=False)
        for batch in batch_iterator:
            # Ensure batch items are tensors and move to device
            try:
                input_ids = batch['input_ids'].to(device, non_blocking=True)
                attention_mask = batch['attention_mask'].to(device, non_blocking=True)
            except Exception as e:
                print(f"\nError processing batch: {e}")
                if isinstance(batch, dict):
                    print(f"Batch keys: {batch.keys()}")
                    if 'input_ids' in batch:
                        print(f"Input IDs type: {type(batch['input_ids'])}")
                    if 'attention_mask' in batch:
                        print(f"Attn Mask type: {type(batch['attention_mask'])}")
                else:
                    # Membership tests on an arbitrary object could raise.
                    print("Batch keys: Not a dict")
                continue  # Skip this batch

            optimizer.zero_grad(set_to_none=True)  # Use set_to_none for potential memory savings

            try:
                losses = _masked_lm_losses(model, input_ids, attention_mask, criterion, l1_lambda)
                if losses is None:
                    skipped_batches += 1
                    continue
                loss_xent, loss_l1 = losses
                total_loss = loss_xent + loss_l1

                # A NaN/inf loss would poison every weight on optimizer.step()
                # and then be written into the checkpoint.
                if not bool(torch.isfinite(total_loss)):
                    skipped_batches += 1
                    continue

                # Backward pass and optimization
                total_loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # Gradient clipping
                optimizer.step()
                scheduler.step()  # Update learning rate

                running_loss += loss_xent.item()
                running_l1 += loss_l1.item()
                steps += 1
            except Exception as e:
                print(f"\nError during train step: {e}")
                traceback.print_exc()
                continue  # Try to continue with next batch

            # Log Progress within Epoch
            if steps > 0 and (steps % log_interval == 0 or steps == len(train_loader)):
                current_time = time.time()
                elapsed = current_time - last_log_time
                batches_per_sec = log_interval / elapsed if elapsed > 0 else 0
                avg_loss = running_loss / steps
                avg_l1 = running_l1 / steps
                perplexity = _safe_perplexity(avg_loss)
                batch_iterator.set_postfix({
                    'Step': f'{steps}/{len(train_loader)}',
                    'Avg Loss': f'{avg_loss:.4f}',
                    'Avg PPL': f'{perplexity:.2f}',
                    'Avg L1': f'{avg_l1:.6f}',
                    'LR': f'{scheduler.get_last_lr()[0]:.2e}',
                    'Batch/s': f'{batches_per_sec:.2f}'
                })
                last_log_time = time.time()

        # --- End of Training Epoch ---
        if skipped_batches:
            print(f"\nEpoch {current_epoch_num}: skipped {skipped_batches} training batch(es) "
                  f"with no target tokens or a non-finite loss.")
        if steps == 0:
            print(f"Epoch {current_epoch_num} had no completed steps. Skipping validation and saving.")
            continue  # Skip to next epoch if no steps were successful

        avg_train_loss = running_loss / steps
        avg_train_l1 = running_l1 / steps
        train_perplexity = _safe_perplexity(avg_train_loss)

        # --- Validation Phase ---
        val_loss, val_l1, val_steps = _run_validation(
            model, val_loader, criterion, device, l1_lambda,
            f"Epoch {current_epoch_num}/{num_epochs} Validation",
        )

        if val_steps == 0:
            print(f"Epoch {current_epoch_num} had no completed validation steps. Using NaN for validation metrics.")
            avg_val_loss, avg_val_l1, val_perplexity = float('nan'), float('nan'), float('nan')
        else:
            avg_val_loss = val_loss / val_steps
            avg_val_l1 = val_l1 / val_steps
            val_perplexity = _safe_perplexity(avg_val_loss)

        epoch_duration = time.time() - epoch_start_time

        # --- Log Epoch Results ---
        print(f"\nEpoch {current_epoch_num}/{num_epochs} completed in {epoch_duration:.2f}s")
        print(f"  Train Loss: {avg_train_loss:.4f} | Train PPL: {train_perplexity:.2f} | Train L1: {avg_train_l1:.6f}")
        print(f"  Val Loss:   {avg_val_loss:.4f} | Val PPL:   {val_perplexity:.2f} | Val L1:   {avg_val_l1:.6f}")

        # --- Update Training History ---
        # Non-finite values are stored as None so the history stays JSON-safe.
        overwritten = _record_epoch(training_history, current_epoch_num, {
            'train_loss': avg_train_loss,
            'train_perplexity': _finite_or_none(train_perplexity),
            'train_l1_loss': avg_train_l1,
            'val_loss': _finite_or_none(avg_val_loss),
            'val_perplexity': _finite_or_none(val_perplexity),
            'val_l1_loss': _finite_or_none(avg_val_l1),
        })
        if overwritten:
            print(f"  Overwriting history for epoch {current_epoch_num}")

        # --- Checkpoint Saving ---
        if math.isfinite(avg_val_loss) and avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            print(f"  * New best validation loss found! Saving best model state to '{best_model_path}'")
            try:
                torch.save(model.state_dict(), best_model_path)
            except Exception as e:
                print(f"  Warning: Failed to save best model state: {e}")

        checkpoint_state = {
            'epoch': epoch,  # Save 0-indexed epoch number *completed*
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'best_val_loss': best_val_loss,  # Persist the best loss found so far
            'training_history': training_history,
            'hparams': hparams
        }
        save_checkpoint(checkpoint_state, checkpoint_path)

        try:
            serializable_history = {
                key: [_finite_or_none(x) for x in values]
                for key, values in training_history.items()
            }
            with open(history_path, 'w') as f:
                json.dump(serializable_history, f, indent=2)
        except Exception as e:
            print(f"Warning: Could not save training history JSON: {e}")

    total_duration = time.time() - total_start_time
    total_epochs_trained = num_epochs - start_epoch
    print(f"\n--- Training Finished ({total_epochs_trained} Epochs Trained) ---")
    if total_epochs_trained > 0:
        print(f"Total training time: {total_duration/3600:.2f} hours")
    print(f"Best validation loss achieved: {best_val_loss:.4f}")

    # --- Plotting ---
    _plot_training_curves(training_history, output_dir)

    return training_history

# --------------------------------------------------------------------------------
# 7) Main Execution Block
# --------------------------------------------------------------------------------
if __name__ == "__main__":
    # Script entry point: prepare data, build the model, resume from a
    # checkpoint when one exists, train, and persist the final artefacts.
    #
    # Fatal setup errors abort with ``raise SystemExit(1)``. The bare
    # ``exit()`` used previously reports status 0 (success) and is only a
    # convenience builtin installed by the ``site`` module.

    # --- Basic Setup ---
    set_seed(HPARAMS['seed'])
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    output_dir = HPARAMS['output_dir']

    # --- !!! IMPORTANT FOR COLAB PERSISTENCE !!! ---
    # Mount Google Drive before running this cell if using Colab.
    try:
        os.makedirs(output_dir, exist_ok=True)
        print(f"Output directory confirmed: '{output_dir}'")
    except OSError as e:
        print(f"CRITICAL Error creating output directory '{output_dir}': {e}")
        print("Please ensure the path is valid and accessible. Exiting.")
        raise SystemExit(1)

    # --- Save Hyperparameters ---
    hparams_path = os.path.join(output_dir, HPARAMS['hparams_filename'])
    try:
        with open(hparams_path, "w") as f:
            json.dump(HPARAMS, f, indent=2)
        print(f"Hyperparameters saved to '{hparams_path}'")
    except Exception as e:
        print(f"Warning: Could not save hyperparameters: {e}")

    # --- Load and Tokenize Data ---
    print("\nLoading and preparing dataset...")
    try:
        raw_data = load_dataset("wikitext", "wikitext-2-raw-v1")
        # Drop empty / whitespace-only lines before tokenization.
        raw_data = raw_data.filter(lambda x: x['text'] and x['text'].strip())
        if not raw_data['train']:
            print("CRITICAL Error: No valid training data found after filtering empty lines. Exiting.")
            raise SystemExit(1)
    except Exception as e:
        # SystemExit is not an Exception subclass, so the abort above
        # propagates instead of being reported as a load failure.
        print(f"Failed to load dataset: {e}. Exiting.")
        raise SystemExit(1)

    print("Loading tokenizer...")
    try:
        tokenizer = GPT2Tokenizer.from_pretrained(HPARAMS['model_name'])
        if tokenizer.pad_token is None:
            # Reuse EOS as the padding token when the tokenizer defines none.
            tokenizer.pad_token = tokenizer.eos_token
            print(f"Set tokenizer pad_token to eos_token ({tokenizer.eos_token})")
    except Exception as e:
        print(f"Failed to load tokenizer '{HPARAMS['model_name']}': {e}. Exiting.")
        raise SystemExit(1)

    def tokenize_function(examples):
        # Every example is padded/truncated to exactly HPARAMS['seq_length'].
        return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=HPARAMS['seq_length'])

    print("Tokenizing dataset (this might take a while)...")
    try:
        num_cpus = os.cpu_count()
        num_proc = min(4, num_cpus if num_cpus is not None else 1)
        print(f"Using {num_proc} processes for tokenization.")
        tokenized = raw_data.map(tokenize_function, batched=True, num_proc=num_proc, remove_columns=raw_data["train"].column_names)
        train_data = tokenized["train"]
        val_data = tokenized["validation"]
        train_data.set_format(type='torch', columns=['input_ids', 'attention_mask'])
        val_data.set_format(type='torch', columns=['input_ids', 'attention_mask'])
        if len(train_data) == 0:
            print("CRITICAL Error: Training data is empty after tokenization. Exiting.")
            raise SystemExit(1)
    except Exception as e:
        print(f"Failed during dataset tokenization: {e}. Exiting.")
        traceback.print_exc()
        raise SystemExit(1)

    # --- DataLoaders ---
    try:
        pin_memory_flag = device.type == 'cuda'
        num_workers = min(2, num_cpus if num_cpus is not None else 1)
        print(f"Using {num_workers} workers for DataLoaders.")
        train_loader = DataLoader(train_data, batch_size=HPARAMS['batch_size'], shuffle=True,
                                  num_workers=num_workers, pin_memory=pin_memory_flag, drop_last=True)
        val_loader = DataLoader(val_data, batch_size=HPARAMS['batch_size'], shuffle=False,
                                num_workers=num_workers, pin_memory=pin_memory_flag)
        print(f"DataLoaders ready: Train batches={len(train_loader)}, Val batches={len(val_loader)}")
        if len(train_loader) == 0:
            print("CRITICAL Error: Training DataLoader has zero batches. Check batch size and dataset size. Exiting.")
            raise SystemExit(1)
    except Exception as e:
        print(f"Failed to create DataLoaders: {e}. Exiting.")
        raise SystemExit(1)

    # --- Model, Optimizer, Scheduler ---
    print("\nInstantiating model, optimizer, and scheduler...")
    try:
        model = DLPFCTransformer(HPARAMS).to(device)
        optimizer = AdamW(model.parameters(), lr=HPARAMS['learning_rate'], weight_decay=HPARAMS['weight_decay'])
        # The schedule always spans the full run, so a resumed run continues
        # the original schedule once the scheduler state is restored.
        full_total_steps = len(train_loader) * HPARAMS['num_epochs']
        scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=HPARAMS['warmup_steps'],
                                                    num_training_steps=full_total_steps)
    except Exception as e:
        print(f"Failed to instantiate model/optimizer/scheduler: {e}. Exiting.")
        traceback.print_exc()
        raise SystemExit(1)

    # --- Model Summary (best effort) ---
    print("\n--- Model Architecture ---")
    try:
        summary_batch_size = HPARAMS['batch_size']
        sample_input_ids = torch.zeros((summary_batch_size, HPARAMS['seq_length']), dtype=torch.long).to(device)
        sample_attention_mask = torch.ones((summary_batch_size, HPARAMS['seq_length']), dtype=torch.long).to(device)
        summary(model, input_data=(sample_input_ids, sample_attention_mask), depth=5,
                col_names=["input_size", "output_size", "num_params", "mult_adds"], row_settings=["var_names"])
    except ImportError:
        print("torchinfo not found. Install (`pip install torchinfo`) for detailed summary.")
        print(model)
    except Exception as e:
        print(f"Could not generate model summary: {e}\n{model}")
    try:
        num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(f"Total Trainable Parameters: {num_params:,}")
    except Exception:
        # Purely informational; never block training on it.
        pass

    # --- Resume From Checkpoint (if any) ---
    checkpoint_path = os.path.join(output_dir, HPARAMS['checkpoint_filename'])
    start_epoch, best_val_loss, training_history = load_checkpoint(checkpoint_path, model, optimizer, scheduler, device)
    if not isinstance(training_history, dict) or not all(k in training_history for k in initialize_history().keys()):
        print("Loaded training history is invalid or incomplete. Reinitializing.")
        training_history = initialize_history()

    # --- Train ---
    try:
        training_history = train_model(
            model, train_loader, val_loader, optimizer, scheduler, device,
            HPARAMS, start_epoch, best_val_loss, training_history
        )
    except Exception as train_err:
        # Fall through so the current weights and history are still saved.
        print("\n--- CRITICAL ERROR DURING TRAINING ---")
        print(f"{train_err}")
        traceback.print_exc()
        print("Attempting to save final state before exiting...")

    # --- Save Final Artefacts ---
    print("\nSaving final model components...")
    final_model_path = os.path.join(output_dir, HPARAMS['final_model_filename'])
    try:
        torch.save(model.state_dict(), final_model_path)
        print(f"Final model state_dict saved to '{final_model_path}'")
    except Exception as e:
        print(f"Warning: Could not save final model state: {e}")

    try:
        tokenizer.save_pretrained(output_dir)
        print(f"Tokenizer saved to '{output_dir}'")
    except Exception as e:
        print(f"Warning: Could not save tokenizer: {e}")

    history_path = os.path.join(output_dir, HPARAMS['history_filename'])
    try:
        serializable_history = copy.deepcopy(training_history)
        for key in serializable_history:
            if isinstance(serializable_history[key], list):
                # Anything that is not a finite number is written as null.
                serializable_history[key] = [
                    x if x is not None and isinstance(x, (int, float)) and math.isfinite(x) else None
                    for x in serializable_history[key]
                ]
        with open(history_path, 'w') as f:
            json.dump(serializable_history, f, indent=2)
        print(f"Final training history saved to '{history_path}'")
    except Exception as e:
        print(f"Warning: Could not save final training history JSON: {e}")

    print("\n--- Script Execution Complete ---")
    print(f"Find results, checkpoints, and plots in: {output_dir}")
