# lm_memory_code / training.py
# Uploaded by userkuku (revision 3e15a02).
"""
Training utilities for LMCODE (Language Model with Memory CODE).
Implements memory-aware training with:
- Experience replay from long-term memory
- Memory consolidation
- Gradient clipping for memory stability
"""
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
from typing import Optional, Dict, List, Tuple
from model_architecture import LMCODE, LMCODEConfig
import math
class MemoryDataset(Dataset):
    """
    Dataset that can sample from both current data and long-term memory.

    Intended to implement experience replay by mixing current training
    examples with examples retrieved from the model's long-term memory.
    NOTE: retrieval from memory is not implemented yet (see
    ``_retrieve_memory_examples``), so sampled batches currently consist of
    current data only.
    """

    def __init__(self, data: List[Dict], memory_sample_ratio: float = 0.2):
        """
        Args:
            data: List of training examples: dicts holding a tensor under
                'input_ids' and, optionally, a tensor under 'labels'.
            memory_sample_ratio: Fraction of each sampled batch that should
                come from long-term memory once replay is implemented.
        """
        self.data = data
        self.memory_sample_ratio = memory_sample_ratio

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> Dict:
        return self.data[idx]

    def sample_with_memory(self, model: "LMCODE", batch_size: int) -> Dict[str, torch.Tensor]:
        """
        Sample a padded batch mixing current data and memory samples.

        Up to ``int(batch_size * memory_sample_ratio)`` rows are requested
        from the model's long-term memory. Every row memory cannot supply is
        drawn uniformly with replacement from the current data, so the
        result always has exactly ``batch_size`` rows.

        Args:
            model: LMCODE model to query memory from.
            batch_size: Total number of rows in the returned batch.

        Returns:
            Dict with stacked, right-padded 'input_ids' and 'labels'.

        Raises:
            ValueError: if ``batch_size`` is not positive or the dataset is
                empty.
        """
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")
        if not self.data:
            raise ValueError("cannot sample a batch from an empty dataset")

        memory_batch_size = int(batch_size * self.memory_sample_ratio)
        memory_examples = self._retrieve_memory_examples(model, memory_batch_size)[:memory_batch_size]

        # Fill the remainder from current data so the caller gets the batch
        # size it asked for, even while memory replay yields nothing.
        current_batch_size = batch_size - len(memory_examples)
        indices = torch.randint(0, len(self.data), (current_batch_size,))
        batch = [self.data[i] for i in indices.tolist()]
        batch.extend(memory_examples)
        return self._pad_batch(batch)

    def _retrieve_memory_examples(self, model, count: int) -> List[Dict]:
        """
        Return up to ``count`` replay examples from the model's long-term memory.

        Placeholder: retrieval is not implemented, so this always returns an
        empty list.
        """
        if count <= 0 or not hasattr(model, 'long_term_memory_size'):
            return []
        # TODO: retrieve examples from the model's long-term memory.
        return []

    def _pad_batch(self, batch: List[Dict]) -> Dict[str, torch.Tensor]:
        """
        Right-pad a batch of examples to a common length and stack them.

        The common length is the longest 'input_ids' in the batch. Inputs are
        padded with token id 0 and labels with -100. A leading singleton
        dimension, e.g. shape ``(1, L)``, is removed from both tensors first.
        Examples without 'labels' (or with ``labels=None``) use a copy of
        their input ids.

        Raises:
            ValueError: if ``batch`` is empty.
        """
        if not batch:
            raise ValueError("cannot pad an empty batch")

        inputs = []
        targets = []
        for item in batch:
            input_ids = self._drop_leading_singleton(item['input_ids'])
            labels = item.get('labels')
            # Squeeze labels exactly like inputs; otherwise (1, L) labels
            # cannot be concatenated with 1-D padding or stacked.
            labels = input_ids.clone() if labels is None else self._drop_leading_singleton(labels)
            inputs.append(input_ids)
            targets.append(labels)

        max_len = max(t.shape[-1] for t in inputs)
        return {
            'input_ids': torch.stack([self._right_pad(t, max_len, 0) for t in inputs]),
            'labels': torch.stack([self._right_pad(t, max_len, -100) for t in targets]),
        }

    @staticmethod
    def _drop_leading_singleton(tensor: torch.Tensor) -> torch.Tensor:
        """Reduce a ``(1, L)`` tensor to ``(L,)``; leave 1-D tensors (even length 1) untouched."""
        if tensor.dim() > 1 and tensor.shape[0] == 1:
            return tensor.squeeze(0)
        return tensor

    @staticmethod
    def _right_pad(tensor: torch.Tensor, length: int, value: int) -> torch.Tensor:
        """Append ``value`` to ``tensor`` until it reaches ``length`` (no-op if already long enough)."""
        pad_len = length - tensor.shape[-1]
        if pad_len <= 0:
            return tensor
        padding = torch.full((pad_len,), value, dtype=tensor.dtype, device=tensor.device)
        return torch.cat([tensor, padding])
class MemoryAwareTrainer:
    """
    Trainer for LMCODE with memory-aware training.

    Features:
    - Separate optimizer parameter group for parameters whose name contains
      'memory' (twice the base learning rate, no weight decay)
    - Linear warmup followed by linear learning-rate decay
    - Global gradient clipping plus tighter per-parameter clipping of
      memory gradients
    - Periodic long-term memory storage and consolidation
    - Memory importance updates (placeholder; currently changes nothing)

    Config keys read (all optional, defaults in parentheses):
        learning_rate (1e-4), weight_decay (0.01), gradient_clip (1.0),
        memory_gradient_clip (0.5), memory_consolidation_interval (1000),
        long_term_store_interval (10), warmup_steps (1000),
        total_steps (10000). A non-positive consolidation or store interval
        disables that action.
    """

    def __init__(self, model: "LMCODE", config: Dict):
        """
        Initialize trainer.

        Args:
            model: LMCODE model to train.
            config: Training configuration dictionary (see class docstring).
        """
        self.model = model
        self.config = config

        # Training parameters
        self.lr = config.get('learning_rate', 1e-4)
        self.weight_decay = config.get('weight_decay', 0.01)
        self.gradient_clip = config.get('gradient_clip', 1.0)
        self.memory_gradient_clip = config.get('memory_gradient_clip', 0.5)
        self.memory_consolidation_interval = config.get('memory_consolidation_interval', 1000)
        self.long_term_store_interval = config.get('long_term_store_interval', 10)
        self.warmup_steps = config.get('warmup_steps', 1000)

        # Optimizer with separate learning rates for memory parameters
        self.optimizer = self._create_optimizer()
        # Learning rate scheduler (reads self.warmup_steps, so it is created after it)
        self.scheduler = self._create_scheduler()

        # Training state
        self.global_step = 0
        self.best_loss = float('inf')

        # Per-step histories; each train_step call appends one entry
        self.loss_history = []
        self.memory_stats = []

    def _create_optimizer(self) -> optim.Optimizer:
        """
        Create an AdamW optimizer with two parameter groups.

        Group 0 holds every parameter whose name does not contain 'memory'
        (base LR, configured weight decay). Group 1 holds the 'memory'
        parameters (2x base LR, no weight decay).
        """
        memory_params = []
        model_params = []
        for name, param in self.model.named_parameters():
            if 'memory' in name:
                memory_params.append(param)
            else:
                model_params.append(param)

        param_groups = [
            {'params': model_params, 'lr': self.lr, 'weight_decay': self.weight_decay},
            {'params': memory_params, 'lr': self.lr * 2, 'weight_decay': 0.0},  # No weight decay for memory
        ]
        return optim.AdamW(param_groups)

    def _create_scheduler(self):
        """
        Create a LambdaLR scheduler.

        The LR multiplier rises linearly from 0 to 1 over ``warmup_steps``,
        then decays linearly to 0 at config['total_steps'] (default 10000).
        """
        def lr_lambda(current_step):
            if current_step < self.warmup_steps:
                return float(current_step) / float(max(1, self.warmup_steps))
            return max(
                0.0,
                float(self.config.get('total_steps', 10000) - current_step) /
                float(max(1, self.config.get('total_steps', 10000) - self.warmup_steps))
            )

        return optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda)

    def _is_due(self, interval: int) -> bool:
        """Return True when global_step is a multiple of ``interval``; a non-positive interval means 'never'."""
        return interval > 0 and self.global_step % interval == 0

    def train_step(self, batch: Dict[str, torch.Tensor]) -> Dict[str, float]:
        """
        Perform a single optimization step.

        Args:
            batch: Dict with 'input_ids' and 'labels' tensors.

        Returns:
            Dict with 'loss', 'learning_rate' (of parameter group 0),
            'global_step' (the step just performed), 'store_long_term', and
            the per-layer entries from ``_get_memory_stats``.
        """
        self.model.train()

        # Move batch to the device the model's parameters live on
        device = next(self.model.parameters()).device
        input_ids = batch['input_ids'].to(device)
        labels = batch['labels'].to(device)

        # Store in long-term memory only every `long_term_store_interval` steps
        store_long_term = self._is_due(self.long_term_store_interval)

        outputs = self.model(
            input_ids=input_ids,
            labels=labels,
            use_long_term_memory=True,
            store_long_term=store_long_term
        )
        loss = outputs['loss']

        self.optimizer.zero_grad()
        loss.backward()

        # Global clipping first, then tighter per-parameter clipping of memory grads
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.gradient_clip)
        self._clip_memory_gradients()

        self.optimizer.step()
        self.scheduler.step()

        if self.global_step % 50 == 0:
            self._update_memory_importance(outputs)

        if self._is_due(self.memory_consolidation_interval):
            self._consolidate_memories()

        loss_value = loss.item()
        memory_stats = self._get_memory_stats()
        stats = {
            'loss': loss_value,
            'learning_rate': self.scheduler.get_last_lr()[0],
            'global_step': self.global_step,
            'store_long_term': store_long_term
        }
        stats.update(memory_stats)

        self.loss_history.append(loss_value)
        self.memory_stats.append(memory_stats)
        self.global_step += 1
        return stats

    def _clip_memory_gradients(self):
        """Clip each 'memory' parameter's gradient norm to ``memory_gradient_clip``, one parameter at a time."""
        for name, param in self.model.named_parameters():
            if 'memory' in name and param.grad is not None:
                torch.nn.utils.clip_grad_norm_([param], max_norm=self.memory_gradient_clip)

    def _update_memory_importance(self, outputs: Dict):
        """
        Placeholder for updating memory importance from retrieval weights.

        Walks outputs['long_term_outputs'] and reads each entry's
        'retrieval_weights', but does not modify any state yet.
        """
        for layer_output in outputs.get('long_term_outputs', []):
            if layer_output is None:
                continue
            retrieval_weights = layer_output.get('retrieval_weights')
            if retrieval_weights is not None:
                # TODO: update importance of the specific memories that were
                # retrieved; requires tracking which slots were used.
                pass

    def _consolidate_memories(self):
        """Call ``consolidate_memories()`` on every layer's long-term memory."""
        for layer in self.model.layers:
            layer.long_term_memory.consolidate_memories()

    def _get_memory_stats(self) -> Dict[str, float]:
        """
        Summarize memory parameters per layer.

        Keys per layer i: layer_{i}_st_memory_mean, layer_{i}_st_memory_std,
        layer_{i}_lt_keys_mean, layer_{i}_lt_importance_mean (mean of
        sigmoid(importance)), layer_{i}_lt_active_count (slots with
        sigmoid(importance) > 0.1).
        """
        stats = {}
        # Pure monitoring: avoid recording these reductions in autograd.
        with torch.no_grad():
            for i, layer in enumerate(self.model.layers):
                st_memory = layer.short_term_memory.memory
                stats[f'layer_{i}_st_memory_mean'] = st_memory.mean().item()
                stats[f'layer_{i}_st_memory_std'] = st_memory.std().item()

                lt_memory = layer.long_term_memory
                importance = torch.sigmoid(lt_memory.memory_importance)
                stats[f'layer_{i}_lt_keys_mean'] = lt_memory.memory_keys.mean().item()
                stats[f'layer_{i}_lt_importance_mean'] = importance.mean().item()
                stats[f'layer_{i}_lt_active_count'] = (importance > 0.1).sum().item()
        return stats

    @staticmethod
    def _collate_fn(dataset):
        """
        Return the dataset's padding collate function, if it has one.

        MemoryDataset provides ``_pad_batch`` to right-pad variable-length
        examples. The DataLoader's default collation can only stack
        equal-length tensors. Returns None (default collation) otherwise.
        """
        return getattr(dataset, '_pad_batch', None)

    def train(self, train_dataset: "MemoryDataset",
              num_epochs: int,
              batch_size: int = 32,
              eval_dataset: Optional["MemoryDataset"] = None) -> Dict:
        """
        Train the model.

        After each epoch a checkpoint is written to 'best_model.pt' when the
        selection loss improves on ``best_loss``. The selection loss is the
        eval loss when ``eval_dataset`` is given, otherwise the average train
        loss.

        Args:
            train_dataset: Training dataset.
            num_epochs: Number of training epochs.
            batch_size: Batch size.
            eval_dataset: Optional evaluation dataset.

        Returns:
            History dict with per-epoch 'train_loss', 'eval_loss' and
            'memory_stats' (stats from the last step of each epoch).

        Raises:
            ValueError: if ``train_dataset`` is empty.
        """
        if len(train_dataset) == 0:
            raise ValueError("train_dataset is empty")

        history = {
            'train_loss': [],
            'eval_loss': [],
            'memory_stats': []
        }

        for epoch in range(num_epochs):
            self.model.train()
            epoch_loss = 0.0
            num_batches = 0

            dataloader = DataLoader(
                train_dataset,
                batch_size=batch_size,
                shuffle=True,
                collate_fn=self._collate_fn(train_dataset)
            )

            for batch_idx, batch in enumerate(dataloader):
                stats = self.train_step(batch)
                epoch_loss += stats['loss']
                num_batches += 1

                if batch_idx % 100 == 0:
                    print(f"Epoch {epoch+1}/{num_epochs}, "
                          f"Batch {batch_idx}/{len(dataloader)}, "
                          f"Loss: {stats['loss']:.4f}")

            avg_epoch_loss = epoch_loss / num_batches
            history['train_loss'].append(avg_epoch_loss)
            history['memory_stats'].append(self.memory_stats[-1])

            if eval_dataset is not None:
                eval_loss = self.evaluate(eval_dataset)
                history['eval_loss'].append(eval_loss)
                print(f"Epoch {epoch+1} - Train Loss: {avg_epoch_loss:.4f}, "
                      f"Eval Loss: {eval_loss:.4f}")
                # Held-out loss is the meaningful criterion for "best" model.
                selection_loss = eval_loss
            else:
                print(f"Epoch {epoch+1} - Train Loss: {avg_epoch_loss:.4f}")
                selection_loss = avg_epoch_loss

            if selection_loss < self.best_loss:
                self.best_loss = selection_loss
                self.save_checkpoint('best_model.pt')

        return history

    def evaluate(self, dataset: "MemoryDataset", batch_size: int = 32) -> float:
        """
        Evaluate the model on a dataset without storing long-term memories.

        Args:
            dataset: Evaluation dataset.
            batch_size: Evaluation batch size.

        Returns:
            Mean of the per-batch losses.

        Raises:
            ValueError: if ``dataset`` is empty.
        """
        if len(dataset) == 0:
            raise ValueError("evaluation dataset is empty")

        self.model.eval()
        device = next(self.model.parameters()).device
        dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            collate_fn=self._collate_fn(dataset)
        )

        total_loss = 0.0
        num_batches = 0
        with torch.no_grad():
            for batch in dataloader:
                input_ids = batch['input_ids'].to(device)
                labels = batch['labels'].to(device)

                # Forward pass (no memory storage during eval)
                outputs = self.model(
                    input_ids=input_ids,
                    labels=labels,
                    use_long_term_memory=True,
                    store_long_term=False
                )
                total_loss += outputs['loss'].item()
                num_batches += 1

        return total_loss / num_batches

    def save_checkpoint(self, path: str):
        """
        Save model, optimizer, scheduler and trainer state to ``path``.

        Args:
            path: Path to save checkpoint.
        """
        checkpoint = {
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'global_step': self.global_step,
            'best_loss': self.best_loss,
            'config': self.config,
            'loss_history': self.loss_history
        }
        torch.save(checkpoint, path)
        print(f"Checkpoint saved to {path}")

    def load_checkpoint(self, path: str):
        """
        Restore state written by ``save_checkpoint``.

        Args:
            path: Path to checkpoint file.
        """
        # NOTE(security): torch.load deserializes with pickle; only load
        # checkpoints from trusted sources.
        checkpoint = torch.load(path, map_location='cpu')
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        self.global_step = checkpoint['global_step']
        self.best_loss = checkpoint['best_loss']
        self.loss_history = checkpoint.get('loss_history', [])
        print(f"Checkpoint loaded from {path}")
def create_synthetic_dataset(num_samples: int = 1000,
                             seq_len: int = 50,
                             vocab_size: int = 50257,
                             seed: Optional[int] = None) -> List[Dict]:
    """
    Create a synthetic next-token-prediction dataset of random token ids.

    Each example's labels are its input ids shifted left by one position.
    The final label has no following token, so it is set to -100, the same
    value ``MemoryDataset._pad_batch`` uses for label padding. This lets the
    loss skip it instead of teaching the model to predict token 0 there.
    This assumes the model's loss ignores -100 (the default ``ignore_index``
    of ``torch.nn.CrossEntropyLoss``); confirm against the model.

    Args:
        num_samples: Number of samples.
        seq_len: Sequence length (must be >= 1).
        vocab_size: Vocabulary size; token ids are drawn from [0, vocab_size).
        seed: Optional seed for reproducible data; None uses torch's global RNG.

    Returns:
        List of dicts with 1-D int64 'input_ids' and 'labels' of length seq_len.

    Raises:
        ValueError: if ``seq_len`` < 1.
    """
    if seq_len < 1:
        raise ValueError(f"seq_len must be >= 1, got {seq_len}")

    generator = torch.Generator().manual_seed(seed) if seed is not None else None

    dataset = []
    for _ in range(num_samples):
        input_ids = torch.randint(0, vocab_size, (seq_len,), generator=generator)
        labels = torch.full_like(input_ids, -100)
        labels[:-1] = input_ids[1:]
        dataset.append({
            'input_ids': input_ids,
            'labels': labels
        })
    return dataset
if __name__ == '__main__':
    # Small demo run so the script finishes quickly on CPU.
    num_epochs = 2
    batch_size = 8

    # Create model
    config = LMCODEConfig(
        vocab_size=50257,
        hidden_size=256,  # Smaller for testing
        num_layers=4,
        num_heads=4,
        short_term_memory_size=256,
        long_term_memory_slots=1000
    )
    model = LMCODE(config)

    # Create synthetic dataset
    train_data = create_synthetic_dataset(num_samples=100, seq_len=32)
    train_dataset = MemoryDataset(train_data, memory_sample_ratio=0.2)

    # Derive the schedule length from the actual run so the linear LR decay
    # spans the whole training. A fixed total_steps of 1000 against
    # ceil(100 / 8) * 2 = 26 real steps left the LR almost undecayed.
    steps_per_epoch = math.ceil(len(train_dataset) / batch_size)
    trainer_config = {
        'learning_rate': 1e-4,
        'weight_decay': 0.01,
        'gradient_clip': 1.0,
        'memory_consolidation_interval': 50,
        'warmup_steps': 10,
        'total_steps': num_epochs * steps_per_epoch
    }
    trainer = MemoryAwareTrainer(model, trainer_config)

    print("Starting training...")
    history = trainer.train(train_dataset, num_epochs=num_epochs, batch_size=batch_size)

    # Save model
    trainer.save_checkpoint('lm_memory_model.pt')
    print("Training complete!")
    print(f"Final loss: {history['train_loss'][-1]:.4f}")