| """ |
| Training utilities for LMCODE (Language Model with Memory CODE). |
| |
| Implements memory-aware training with: |
| - Experience replay from long-term memory |
| - Memory consolidation |
| - Gradient clipping for memory stability |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.optim as optim |
| from torch.utils.data import Dataset, DataLoader |
| import numpy as np |
| from typing import Optional, Dict, List, Tuple |
| from model_architecture import LMCODE, LMCODEConfig |
| import math |
|
|
|
|
class MemoryDataset(Dataset):
    """
    Dataset that can sample from both current data and long-term memory.

    Each example is a dict holding an 'input_ids' tensor and, optionally, a
    'labels' tensor. The class works as a plain map-style dataset and also
    offers `sample_with_memory` for drawing a padded random batch.

    NOTE: retrieval of replay samples from the model's long-term memory is
    not implemented yet; `sample_with_memory` currently draws every sample
    from the current data.
    """

    def __init__(self, data: List[Dict], memory_sample_ratio: float = 0.2):
        """
        Args:
            data: List of training examples (dicts with 'input_ids', 'labels')
            memory_sample_ratio: Fraction of batch to sample from memory
        """
        self.data = data
        self.memory_sample_ratio = memory_sample_ratio

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> Dict:
        return self.data[idx]

    def sample_with_memory(self, model: "LMCODE", batch_size: int) -> Dict[str, torch.Tensor]:
        """
        Sample a padded batch of `batch_size` examples.

        Args:
            model: LMCODE model to query memory from (currently unused, see
                the TODO below)
            batch_size: Total batch size

        Returns:
            Batch dictionary with 'input_ids' and 'labels' tensors, each
            holding `batch_size` rows

        Raises:
            ValueError: If `batch_size` is not positive or the dataset is empty
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be positive, got {batch_size}")
        if len(self.data) == 0:
            raise ValueError("cannot sample from an empty dataset")

        # TODO: retrieve int(batch_size * self.memory_sample_ratio) replay
        # samples from the model's long-term memory once it exposes a
        # retrieval API. Until then the memory share is back-filled with
        # current data so callers always receive `batch_size` rows (and a
        # ratio of 1.0 no longer produces an empty, unpaddable batch).
        indices = torch.randint(0, len(self.data), (batch_size,))
        return self._pad_batch([self.data[i] for i in indices.tolist()])

    @staticmethod
    def _as_1d(tensor: torch.Tensor) -> torch.Tensor:
        """Return `tensor` as a 1-D sequence, dropping a leading batch dim."""
        if tensor.dim() == 0:
            return tensor.reshape(1)
        if tensor.dim() > 1:
            return tensor.squeeze(0)
        # Already 1-D: do NOT squeeze, or a one-token sequence would
        # collapse into a 0-dim scalar.
        return tensor

    def _pad_batch(self, batch: List[Dict]) -> Dict[str, torch.Tensor]:
        """
        Pad a batch of sequences to the same length.

        Inputs are right-padded with 0 and labels with -100. Examples without
        'labels' use a copy of their (unpadded) input ids as labels.

        Raises:
            ValueError: If `batch` is empty
        """
        if not batch:
            raise ValueError("cannot pad an empty batch")

        sequences = []
        for item in batch:
            input_ids = self._as_1d(item['input_ids'])
            labels = item.get('labels')
            labels = input_ids.clone() if labels is None else self._as_1d(labels)
            sequences.append((input_ids, labels))

        # Consider labels as well, so a label sequence longer than its
        # inputs cannot break the final stack.
        max_len = max(
            max(input_ids.shape[-1], labels.shape[-1])
            for input_ids, labels in sequences
        )

        padded_inputs = []
        padded_labels = []
        for input_ids, labels in sequences:
            input_pad = max_len - input_ids.shape[-1]
            if input_pad > 0:
                # new_zeros / new_full keep the dtype AND device of the source.
                input_ids = torch.cat([input_ids, input_ids.new_zeros(input_pad)])

            label_pad = max_len - labels.shape[-1]
            if label_pad > 0:
                labels = torch.cat([labels, labels.new_full((label_pad,), -100)])

            padded_inputs.append(input_ids)
            padded_labels.append(labels)

        return {
            'input_ids': torch.stack(padded_inputs),
            'labels': torch.stack(padded_labels)
        }
|
|
|
|
class MemoryAwareTrainer:
    """
    Trainer for LMCODE with memory-aware training.

    Features:
    - Memory consolidation scheduling
    - Gradient clipping for memory parameters
    - Periodic long-term memory storage
    - Memory importance updates (placeholder, not implemented yet)

    As accessed by this class, the model is expected to expose `layers`,
    where every layer has `short_term_memory.memory` and a
    `long_term_memory` with `memory_keys`, `memory_importance` and
    `consolidate_memories()`; its forward call must return a dict with a
    'loss' entry.
    """

    def __init__(self, model: "LMCODE", config: Dict):
        """
        Initialize trainer.

        Args:
            model: LMCODE model to train
            config: Training configuration dictionary. Recognised keys:
                'learning_rate', 'weight_decay', 'gradient_clip',
                'memory_consolidation_interval', 'warmup_steps' and
                'total_steps'.
        """
        self.model = model
        self.config = config

        # Hyper-parameters (with defaults).
        self.lr = config.get('learning_rate', 1e-4)
        self.weight_decay = config.get('weight_decay', 0.01)
        self.gradient_clip = config.get('gradient_clip', 1.0)
        self.memory_consolidation_interval = config.get('memory_consolidation_interval', 1000)
        self.warmup_steps = config.get('warmup_steps', 1000)

        # The scheduler wraps the optimizer, so the order matters.
        self.optimizer = self._create_optimizer()
        self.scheduler = self._create_scheduler()

        # Training state.
        self.global_step = 0
        self.best_loss = float('inf')

        # Per-step bookkeeping.
        self.loss_history = []
        self.memory_stats = []

    def _create_optimizer(self) -> optim.Optimizer:
        """
        Create an AdamW optimizer with two parameter groups.

        Parameters whose name contains 'memory' get twice the base learning
        rate and no weight decay; all others use the configured values.
        """
        memory_params = []
        model_params = []

        for name, param in self.model.named_parameters():
            if 'memory' in name:
                memory_params.append(param)
            else:
                model_params.append(param)

        # Keep both groups (even if one is empty) so the group layout stays
        # stable across checkpoints.
        param_groups = [
            {'params': model_params, 'lr': self.lr, 'weight_decay': self.weight_decay},
            {'params': memory_params, 'lr': self.lr * 2, 'weight_decay': 0.0}
        ]

        return optim.AdamW(param_groups)

    def _create_scheduler(self):
        """
        Create learning rate scheduler with warmup.

        The multiplier rises linearly from 0 to 1 over `warmup_steps`, then
        decays linearly to 0 at config['total_steps'] (default 10000).
        """
        def lr_lambda(current_step):
            if current_step < self.warmup_steps:
                return float(current_step) / float(max(1, self.warmup_steps))
            total_steps = self.config.get('total_steps', 10000)
            return max(
                0.0,
                float(total_steps - current_step) /
                float(max(1, total_steps - self.warmup_steps))
            )

        return optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda)

    def train_step(self, batch: Dict[str, torch.Tensor]) -> Dict[str, float]:
        """
        Perform a single training step.

        Args:
            batch: Batch of training data with 'input_ids' and 'labels'

        Returns:
            Dictionary with loss, learning rate, step counter, the
            `store_long_term` flag and per-layer memory statistics
        """
        self.model.train()

        # Move the batch to wherever the model lives.
        device = next(self.model.parameters()).device
        input_ids = batch['input_ids'].to(device)
        labels = batch['labels'].to(device)

        # Only every 10th step asks the model to store to long-term memory.
        store_long_term = (self.global_step % 10 == 0)

        outputs = self.model(
            input_ids=input_ids,
            labels=labels,
            use_long_term_memory=True,
            store_long_term=store_long_term
        )

        loss = outputs['loss']

        self.optimizer.zero_grad()
        loss.backward()

        # Global clip first, then the tighter per-parameter clip for memory.
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.gradient_clip)
        self._clip_memory_gradients()

        self.optimizer.step()
        self.scheduler.step()

        if self.global_step % 50 == 0:
            self._update_memory_importance(outputs)

        # A non-positive interval disables consolidation instead of raising
        # a modulo-by-zero error.
        interval = self.memory_consolidation_interval
        if interval > 0 and self.global_step % interval == 0:
            self._consolidate_memories()

        # Read the scalar once; every .item() call forces a device sync.
        loss_value = loss.item()

        stats = {
            'loss': loss_value,
            'learning_rate': self.scheduler.get_last_lr()[0],
            'global_step': self.global_step,
            'store_long_term': store_long_term
        }

        memory_stats = self._get_memory_stats()
        stats.update(memory_stats)

        self.loss_history.append(loss_value)
        self.memory_stats.append(memory_stats)
        self.global_step += 1

        return stats

    def _clip_memory_gradients(self):
        """Clip each memory parameter's gradient norm to 0.5 individually."""
        for name, param in self.model.named_parameters():
            if 'memory' in name and param.grad is not None:
                torch.nn.utils.clip_grad_norm_([param], max_norm=0.5)

    def _update_memory_importance(self, outputs: Dict):
        """
        Update memory importance based on usage in forward pass.

        Intended behaviour: importance increases when memories are retrieved
        with high weight. The update itself is not implemented yet, so this
        method currently only walks the outputs and changes nothing.
        """
        for layer_output in outputs.get('long_term_outputs') or []:
            if layer_output is None:
                continue

            retrieval_weights = layer_output.get('retrieval_weights')
            if retrieval_weights is not None:
                # TODO: fold the retrieval weights into the layer's
                # memory_importance once the update rule is defined.
                pass

    def _consolidate_memories(self):
        """Consolidate long-term memories across all layers."""
        for layer in self.model.layers:
            layer.long_term_memory.consolidate_memories()

    def _get_memory_stats(self) -> Dict[str, float]:
        """
        Get statistics about memory usage.

        Returns:
            Per-layer mean/std of the short-term memory, mean of the
            long-term keys, mean sigmoid importance, and the number of
            slots whose sigmoid importance exceeds 0.1
        """
        stats = {}

        for i, layer in enumerate(self.model.layers):
            st_memory = layer.short_term_memory.memory
            stats[f'layer_{i}_st_memory_mean'] = st_memory.mean().item()
            stats[f'layer_{i}_st_memory_std'] = st_memory.std().item()

            lt_keys = layer.long_term_memory.memory_keys
            # Compute the sigmoid once; it feeds both importance statistics.
            importance = torch.sigmoid(layer.long_term_memory.memory_importance)

            stats[f'layer_{i}_lt_keys_mean'] = lt_keys.mean().item()
            stats[f'layer_{i}_lt_importance_mean'] = importance.mean().item()
            stats[f'layer_{i}_lt_active_count'] = (importance > 0.1).sum().item()

        return stats

    def train(self, train_dataset: "MemoryDataset",
              num_epochs: int,
              batch_size: int = 32,
              eval_dataset: Optional["MemoryDataset"] = None) -> Dict:
        """
        Train the model.

        Args:
            train_dataset: Training dataset
            num_epochs: Number of training epochs
            batch_size: Batch size
            eval_dataset: Optional evaluation dataset

        Returns:
            Training history with per-epoch 'train_loss', 'eval_loss' (only
            when `eval_dataset` is given) and 'memory_stats' (memory
            statistics of each epoch's last step)

        Raises:
            ValueError: If the training dataset yields no batches
        """
        history = {
            'train_loss': [],
            'eval_loss': [],
            'memory_stats': []
        }

        # Use the dataset's own padding collate when it has one: the default
        # collate cannot stack sequences of different lengths. The loader
        # reshuffles on every pass, so one instance serves all epochs.
        dataloader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            collate_fn=getattr(train_dataset, '_pad_batch', None)
        )
        total_batches = len(dataloader)

        for epoch in range(num_epochs):
            self.model.train()
            epoch_loss = 0
            num_batches = 0

            for batch_idx, batch in enumerate(dataloader):
                stats = self.train_step(batch)

                epoch_loss += stats['loss']
                num_batches += 1

                if batch_idx % 100 == 0:
                    print(f"Epoch {epoch+1}/{num_epochs}, "
                          f"Batch {batch_idx}/{total_batches}, "
                          f"Loss: {stats['loss']:.4f}")

            if num_batches == 0:
                raise ValueError("training dataset produced no batches")

            avg_epoch_loss = epoch_loss / num_batches
            history['train_loss'].append(avg_epoch_loss)
            history['memory_stats'].append(self.memory_stats[-1])

            if eval_dataset is not None:
                eval_loss = self.evaluate(eval_dataset)
                history['eval_loss'].append(eval_loss)
                print(f"Epoch {epoch+1} - Train Loss: {avg_epoch_loss:.4f}, "
                      f"Eval Loss: {eval_loss:.4f}")
            else:
                print(f"Epoch {epoch+1} - Train Loss: {avg_epoch_loss:.4f}")

            # The best checkpoint is tracked on the average training loss.
            if avg_epoch_loss < self.best_loss:
                self.best_loss = avg_epoch_loss
                self.save_checkpoint('best_model.pt')

        return history

    def evaluate(self, dataset: "MemoryDataset", batch_size: int = 32) -> float:
        """
        Evaluate the model on a dataset.

        Args:
            dataset: Evaluation dataset
            batch_size: Evaluation batch size

        Returns:
            Average loss over all batches

        Raises:
            ValueError: If the dataset yields no batches
        """
        self.model.eval()
        total_loss = 0
        num_batches = 0

        dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            collate_fn=getattr(dataset, '_pad_batch', None)
        )
        device = next(self.model.parameters()).device

        with torch.no_grad():
            for batch in dataloader:
                input_ids = batch['input_ids'].to(device)
                labels = batch['labels'].to(device)

                # Never write to long-term memory during evaluation.
                outputs = self.model(
                    input_ids=input_ids,
                    labels=labels,
                    use_long_term_memory=True,
                    store_long_term=False
                )

                total_loss += outputs['loss'].item()
                num_batches += 1

        if num_batches == 0:
            raise ValueError("cannot evaluate on an empty dataset")

        return total_loss / num_batches

    def save_checkpoint(self, path: str):
        """
        Save model checkpoint.

        Stores model, optimizer and scheduler state together with the step
        counter, best loss, config and loss history.

        Args:
            path: Path to save checkpoint
        """
        checkpoint = {
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'global_step': self.global_step,
            'best_loss': self.best_loss,
            'config': self.config,
            'loss_history': self.loss_history
        }

        torch.save(checkpoint, path)
        print(f"Checkpoint saved to {path}")

    def load_checkpoint(self, path: str):
        """
        Load model checkpoint.

        Args:
            path: Path to checkpoint file
        """
        # SECURITY: torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        checkpoint = torch.load(path, map_location='cpu')

        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        self.global_step = checkpoint['global_step']
        self.best_loss = checkpoint['best_loss']
        self.loss_history = checkpoint.get('loss_history', [])

        print(f"Checkpoint loaded from {path}")
|
|
|
|
def create_synthetic_dataset(num_samples: int = 1000,
                             seq_len: int = 50,
                             vocab_size: int = 50257) -> List[Dict]:
    """
    Create a synthetic dataset for testing.

    Every example holds uniformly random token ids plus next-token labels:
    the inputs shifted left by one position. The final position has no
    next token, so its label is -100, the same value `_pad_batch` uses for
    label padding and the default ignore index of PyTorch's cross-entropy
    loss.

    Args:
        num_samples: Number of samples
        seq_len: Sequence length
        vocab_size: Vocabulary size

    Returns:
        List of training examples (dicts with 'input_ids' and 'labels',
        both 1-D tensors of length `seq_len`)
    """
    dataset = []

    for _ in range(num_samples):
        input_ids = torch.randint(0, vocab_size, (seq_len,))

        # Start from "ignore everywhere" and copy in the shifted targets;
        # this keeps labels the same length as inputs even for seq_len == 0.
        labels = torch.full_like(input_ids, -100)
        labels[:-1] = input_ids[1:]

        dataset.append({
            'input_ids': input_ids,
            'labels': labels
        })

    return dataset
|
|
|
|
if __name__ == '__main__':
    # Demo run: build a small LMCODE model, train it for two epochs on
    # random token data and write the final checkpoint to disk.

    # Model first, data second: this keeps the random-number draws in the
    # same order on every run.
    config = LMCODEConfig(
        vocab_size=50257,
        hidden_size=256,
        num_layers=4,
        num_heads=4,
        short_term_memory_size=256,
        long_term_memory_slots=1000,
    )
    model = LMCODE(config)

    # 100 random sequences of 32 tokens each.
    train_data = create_synthetic_dataset(num_samples=100, seq_len=32)
    train_dataset = MemoryDataset(train_data, memory_sample_ratio=0.2)

    # Short warmup and frequent consolidation suit such a tiny run.
    trainer_config = dict(
        learning_rate=1e-4,
        weight_decay=0.01,
        gradient_clip=1.0,
        memory_consolidation_interval=50,
        warmup_steps=10,
        total_steps=1000,
    )
    trainer = MemoryAwareTrainer(model, trainer_config)

    print("Starting training...")
    history = trainer.train(train_dataset, num_epochs=2, batch_size=8)

    trainer.save_checkpoint('lm_memory_model.pt')

    final_loss = history['train_loss'][-1]
    print("Training complete!")
    print(f"Final loss: {final_loss:.4f}")