""" Training utilities for LMCODE (Language Model with Memory CODE). Implements memory-aware training with: - Experience replay from long-term memory - Memory consolidation - Gradient clipping for memory stability """ import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import Dataset, DataLoader import numpy as np from typing import Optional, Dict, List, Tuple from model_architecture import LMCODE, LMCODEConfig import math class MemoryDataset(Dataset): """ Dataset that can sample from both current data and long-term memory. Implements experience replay by mixing current training examples with retrieved memories from the model's long-term memory. """ def __init__(self, data: List[Dict], memory_sample_ratio: float = 0.2): """ Args: data: List of training examples (dicts with 'input_ids', 'labels') memory_sample_ratio: Fraction of batch to sample from memory """ self.data = data self.memory_sample_ratio = memory_sample_ratio def __len__(self) -> int: return len(self.data) def __getitem__(self, idx: int) -> Dict: return self.data[idx] def sample_with_memory(self, model: LMCODE, batch_size: int) -> Dict[str, torch.Tensor]: """ Sample a batch mixing current data and memory samples. Args: model: LMCODE model to query memory from batch_size: Total batch size Returns: Batch dictionary with mixed data """ # Sample from current data memory_batch_size = int(batch_size * self.memory_sample_ratio) current_batch_size = batch_size - memory_batch_size # Sample current data current_indices = torch.randint(0, len(self.data), (current_batch_size,)) current_batch = [self.data[i] for i in current_indices.tolist()] # Pad sequences current_batch_padded = self._pad_batch(current_batch) # Sample from long-term memory (if available) memory_batch_padded = None if memory_batch_size > 0 and hasattr(model, 'long_term_memory_size'): # In practice, retrieve from model's long-term memory # For now, return None pass return current_batch_padded def _pad_batch(self, batch: List[Dict]) -> Dict[str, torch.Tensor]: """Pad a batch of sequences to the same length.""" max_len = max(item['input_ids'].shape[-1] for item in batch) padded_inputs = [] padded_labels = [] for item in batch: input_ids = item['input_ids'].squeeze(0) labels = item.get('labels', input_ids.clone()) # Pad input pad_len = max_len - input_ids.shape[-1] if pad_len > 0: input_ids = torch.cat([input_ids, torch.zeros(pad_len, dtype=input_ids.dtype)]) # Pad labels if labels.shape[-1] < max_len: pad_len = max_len - labels.shape[-1] labels = torch.cat([labels, torch.full((pad_len,), -100, dtype=labels.dtype)]) padded_inputs.append(input_ids) padded_labels.append(labels) return { 'input_ids': torch.stack(padded_inputs), 'labels': torch.stack(padded_labels) } class MemoryAwareTrainer: """ Trainer for LMCODE with memory-aware training. Features: - Memory consolidation scheduling - Gradient clipping for memory parameters - Experience replay - Memory importance updates """ def __init__(self, model: LMCODE, config: Dict): """ Initialize trainer. Args: model: LMCODE model to train config: Training configuration dictionary """ self.model = model self.config = config # Training parameters self.lr = config.get('learning_rate', 1e-4) self.weight_decay = config.get('weight_decay', 0.01) self.gradient_clip = config.get('gradient_clip', 1.0) self.memory_consolidation_interval = config.get('memory_consolidation_interval', 1000) self.warmup_steps = config.get('warmup_steps', 1000) # Optimizer with separate learning rates for memory parameters self.optimizer = self._create_optimizer() # Learning rate scheduler self.scheduler = self._create_scheduler() # Training state self.global_step = 0 self.best_loss = float('inf') # Loss tracking self.loss_history = [] self.memory_stats = [] def _create_optimizer(self) -> optim.Optimizer: """Create optimizer with parameter groups.""" # Separate memory parameters from model parameters memory_params = [] model_params = [] for name, param in self.model.named_parameters(): if 'memory' in name: memory_params.append(param) else: model_params.append(param) # Higher learning rate for memory parameters param_groups = [ {'params': model_params, 'lr': self.lr, 'weight_decay': self.weight_decay}, {'params': memory_params, 'lr': self.lr * 2, 'weight_decay': 0.0} # No weight decay for memory ] return optim.AdamW(param_groups) def _create_scheduler(self): """Create learning rate scheduler with warmup.""" def lr_lambda(current_step): if current_step < self.warmup_steps: return float(current_step) / float(max(1, self.warmup_steps)) return max( 0.0, float(self.config.get('total_steps', 10000) - current_step) / float(max(1, self.config.get('total_steps', 10000) - self.warmup_steps)) ) return optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda) def train_step(self, batch: Dict[str, torch.Tensor]) -> Dict[str, float]: """ Perform a single training step. Args: batch: Batch of training data Returns: Dictionary with loss and memory statistics """ self.model.train() # Move batch to device device = next(self.model.parameters()).device input_ids = batch['input_ids'].to(device) labels = batch['labels'].to(device) # Determine whether to store in long-term memory # Store periodically (e.g., every 10 steps) store_long_term = (self.global_step % 10 == 0) # Forward pass outputs = self.model( input_ids=input_ids, labels=labels, use_long_term_memory=True, store_long_term=store_long_term ) loss = outputs['loss'] # Backward pass self.optimizer.zero_grad() loss.backward() # Gradient clipping torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.gradient_clip) # Memory-specific gradient clipping self._clip_memory_gradients() # Optimizer step self.optimizer.step() self.scheduler.step() # Update memory importance if self.global_step % 50 == 0: self._update_memory_importance(outputs) # Consolidate memories periodically if self.global_step % self.memory_consolidation_interval == 0: self._consolidate_memories() # Track statistics stats = { 'loss': loss.item(), 'learning_rate': self.scheduler.get_last_lr()[0], 'global_step': self.global_step, 'store_long_term': store_long_term } # Add memory statistics memory_stats = self._get_memory_stats() stats.update(memory_stats) self.loss_history.append(loss.item()) self.memory_stats.append(memory_stats) self.global_step += 1 return stats def _clip_memory_gradients(self): """Apply special gradient clipping for memory parameters.""" for name, param in self.model.named_parameters(): if 'memory' in name and param.grad is not None: # More aggressive clipping for memory parameters torch.nn.utils.clip_grad_norm_([param], max_norm=0.5) def _update_memory_importance(self, outputs: Dict): """ Update memory importance based on usage in forward pass. Importance increases when memories are retrieved with high weight. """ # Iterate through layers for layer_output in outputs.get('long_term_outputs', []): if layer_output is None: continue # Get retrieval weights retrieval_weights = layer_output.get('retrieval_weights') if retrieval_weights is not None: # Update importance based on average retrieval weight # This is a simplified version - in practice, you'd need # to track which specific memories were retrieved pass def _consolidate_memories(self): """Consolidate long-term memories across all layers.""" for layer in self.model.layers: layer.long_term_memory.consolidate_memories() def _get_memory_stats(self) -> Dict[str, float]: """Get statistics about memory usage.""" stats = {} for i, layer in enumerate(self.model.layers): # Short-term memory statistics st_memory = layer.short_term_memory.memory stats[f'layer_{i}_st_memory_mean'] = st_memory.mean().item() stats[f'layer_{i}_st_memory_std'] = st_memory.std().item() # Long-term memory statistics lt_keys = layer.long_term_memory.memory_keys lt_values = layer.long_term_memory.memory_values lt_importance = layer.long_term_memory.memory_importance stats[f'layer_{i}_lt_keys_mean'] = lt_keys.mean().item() stats[f'layer_{i}_lt_importance_mean'] = torch.sigmoid(lt_importance).mean().item() # Count active memories active_count = (torch.sigmoid(lt_importance) > 0.1).sum().item() stats[f'layer_{i}_lt_active_count'] = active_count return stats def train(self, train_dataset: MemoryDataset, num_epochs: int, batch_size: int = 32, eval_dataset: Optional[MemoryDataset] = None) -> Dict: """ Train the model. Args: train_dataset: Training dataset num_epochs: Number of training epochs batch_size: Batch size eval_dataset: Optional evaluation dataset Returns: Training history """ history = { 'train_loss': [], 'eval_loss': [], 'memory_stats': [] } for epoch in range(num_epochs): self.model.train() epoch_loss = 0 num_batches = 0 # Create data loader dataloader = DataLoader( train_dataset, batch_size=batch_size, shuffle=True ) for batch_idx, batch in enumerate(dataloader): # Perform training step stats = self.train_step(batch) epoch_loss += stats['loss'] num_batches += 1 # Log progress if batch_idx % 100 == 0: print(f"Epoch {epoch+1}/{num_epochs}, " f"Batch {batch_idx}/{len(dataloader)}, " f"Loss: {stats['loss']:.4f}") # Average epoch loss avg_epoch_loss = epoch_loss / num_batches history['train_loss'].append(avg_epoch_loss) # Evaluate if eval_dataset is not None: eval_loss = self.evaluate(eval_dataset) history['eval_loss'].append(eval_loss) print(f"Epoch {epoch+1} - Train Loss: {avg_epoch_loss:.4f}, " f"Eval Loss: {eval_loss:.4f}") else: print(f"Epoch {epoch+1} - Train Loss: {avg_epoch_loss:.4f}") # Save best model if avg_epoch_loss < self.best_loss: self.best_loss = avg_epoch_loss self.save_checkpoint('best_model.pt') return history def evaluate(self, dataset: MemoryDataset) -> float: """ Evaluate the model on a dataset. Args: dataset: Evaluation dataset Returns: Average loss """ self.model.eval() total_loss = 0 num_batches = 0 dataloader = DataLoader(dataset, batch_size=32, shuffle=False) with torch.no_grad(): for batch in dataloader: # Move to device device = next(self.model.parameters()).device input_ids = batch['input_ids'].to(device) labels = batch['labels'].to(device) # Forward pass (no memory storage during eval) outputs = self.model( input_ids=input_ids, labels=labels, use_long_term_memory=True, store_long_term=False ) total_loss += outputs['loss'].item() num_batches += 1 return total_loss / num_batches def save_checkpoint(self, path: str): """ Save model checkpoint. Args: path: Path to save checkpoint """ checkpoint = { 'model_state_dict': self.model.state_dict(), 'optimizer_state_dict': self.optimizer.state_dict(), 'scheduler_state_dict': self.scheduler.state_dict(), 'global_step': self.global_step, 'best_loss': self.best_loss, 'config': self.config, 'loss_history': self.loss_history } torch.save(checkpoint, path) print(f"Checkpoint saved to {path}") def load_checkpoint(self, path: str): """ Load model checkpoint. Args: path: Path to checkpoint file """ checkpoint = torch.load(path, map_location='cpu') self.model.load_state_dict(checkpoint['model_state_dict']) self.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) self.scheduler.load_state_dict(checkpoint['scheduler_state_dict']) self.global_step = checkpoint['global_step'] self.best_loss = checkpoint['best_loss'] self.loss_history = checkpoint.get('loss_history', []) print(f"Checkpoint loaded from {path}") def create_synthetic_dataset(num_samples: int = 1000, seq_len: int = 50, vocab_size: int = 50257) -> List[Dict]: """ Create a synthetic dataset for testing. Args: num_samples: Number of samples seq_len: Sequence length vocab_size: Vocabulary size Returns: List of training examples """ dataset = [] for _ in range(num_samples): # Generate random sequence input_ids = torch.randint(0, vocab_size, (seq_len,)) # Create labels (shifted by 1 for next-token prediction) labels = torch.cat([ input_ids[1:], torch.zeros(1, dtype=input_ids.dtype) ], dim=0) dataset.append({ 'input_ids': input_ids, 'labels': labels }) return dataset if __name__ == '__main__': # Create model config = LMCODEConfig( vocab_size=50257, hidden_size=256, # Smaller for testing num_layers=4, num_heads=4, short_term_memory_size=256, long_term_memory_slots=1000 ) model = LMCODE(config) # Create synthetic dataset train_data = create_synthetic_dataset(num_samples=100, seq_len=32) train_dataset = MemoryDataset(train_data, memory_sample_ratio=0.2) # Create trainer trainer_config = { 'learning_rate': 1e-4, 'weight_decay': 0.01, 'gradient_clip': 1.0, 'memory_consolidation_interval': 50, 'warmup_steps': 10, 'total_steps': 1000 } trainer = MemoryAwareTrainer(model, trainer_config) # Train for 2 epochs print("Starting training...") history = trainer.train(train_dataset, num_epochs=2, batch_size=8) # Save model trainer.save_checkpoint('lm_memory_model.pt') print("Training complete!") print(f"Final loss: {history['train_loss'][-1]:.4f}")