| """ |
| Utility functions for LMCODE (Language Model with Memory CODE). |
| |
| Includes memory visualization, analysis, and helper functions. |
| """ |
|
|
| import torch |
| import numpy as np |
| import matplotlib.pyplot as plt |
| from typing import Dict, List, Optional, Tuple |
| from collections import defaultdict |
| import json |
|
|
|
|
def _sequence_representation(tensor: torch.Tensor) -> torch.Tensor:
    """Return a floating-point representation of *tensor* for similarity checks.

    Non-floating tensors (e.g. integer token ids) are cast to float so they
    can be averaged and normalized; tensors of rank > 2 are averaged over
    dim 1 (presumably the sequence axis — confirm against the model).
    """
    if not tensor.is_floating_point():
        tensor = tensor.float()
    return tensor.mean(dim=1) if tensor.dim() > 2 else tensor


def analyze_memory_capacity(model, test_sequences: List[torch.Tensor],
                            retrieval_threshold: float = 0.8) -> Dict:
    """
    Analyze the memory capacity and retrieval accuracy of the model.

    Each sequence is first written with ``model.store_experience(seq)`` and
    then used as a query with ``model.query_memory(seq, top_k=5)``. The
    cosine similarity between the query representation and the retrieved
    representation is used as the retrieval score.

    Args:
        model: LMCODE model. Must expose ``store_experience``,
            ``query_memory`` and ``layers`` (each layer having
            ``long_term_memory.num_slots``).
        test_sequences: List of test sequences. Non-float tensors (such as
            token ids) are cast to float before similarities are computed.
        retrieval_threshold: Cosine-similarity threshold above which a
            retrieval counts as successful.

    Returns:
        Dictionary with ``total_memories``, ``successful_retrievals``,
        ``average_similarity``, ``similarity_std`` and
        ``capacity_utilization``. The similarity statistics stay 0 when no
        retrieval could be scored (feature sizes never matched), and
        ``capacity_utilization`` is 0.0 when the model reports no slots.
    """
    results = {
        'total_memories': 0,
        'successful_retrievals': 0,
        'average_similarity': 0,
        'similarity_std': 0,
        'capacity_utilization': 0,
    }
    similarities = []

    # Pure analysis: don't build autograd graphs for the memory writes/reads.
    with torch.no_grad():
        for seq in test_sequences:
            model.store_experience(seq)
            retrieved, _ = model.query_memory(seq, top_k=5)

            seq_repr = _sequence_representation(seq)
            retrieved_repr = _sequence_representation(retrieved)

            # Only score when the feature sizes agree; otherwise the memory is
            # still counted as stored but contributes no similarity.
            if seq_repr.shape[-1] == retrieved_repr.shape[-1]:
                seq_norm = torch.nn.functional.normalize(seq_repr, dim=-1)
                retrieved_norm = torch.nn.functional.normalize(retrieved_repr, dim=-1)
                similarity = (seq_norm * retrieved_norm).sum(dim=-1).mean().item()
                similarities.append(similarity)
                if similarity > retrieval_threshold:
                    results['successful_retrievals'] += 1

            results['total_memories'] += 1

    if similarities:
        results['average_similarity'] = np.mean(similarities)
        results['similarity_std'] = np.std(similarities)

    total_slots = sum(layer.long_term_memory.num_slots for layer in model.layers)
    # Guard against models that report no long-term slots at all.
    results['capacity_utilization'] = (
        results['total_memories'] / total_slots if total_slots > 0 else 0.0
    )

    return results
|
|
|
|
def visualize_memory_attention(attention_weights: torch.Tensor,
                               save_path: Optional[str] = None) -> plt.Figure:
    """
    Visualize memory attention patterns.

    The left panel shows a heat map of one attention map. The right panel is
    currently a labeled placeholder for long-term retrieval weights.

    Args:
        attention_weights: Attention weights tensor. Rank-4 input is averaged
            over dim 1 (presumably heads — confirm); for rank >= 3 the first
            entry along dim 0 is plotted; a rank-2 tensor is plotted as-is.
            The tensor may require grad; it is detached before plotting.
        save_path: Optional path to save figure

    Returns:
        Matplotlib figure
    """
    fig, axes = plt.subplots(1, 2, figsize=(12, 5))

    # Detach so tensors that still require grad can be converted to numpy.
    weights = attention_weights.detach()
    if weights.dim() == 4:
        weights = weights.mean(dim=1)
    # A 2-D tensor is already a single map; indexing it would yield 1-D data
    # that imshow cannot draw.
    attention_map = weights[0] if weights.dim() > 2 else weights

    # .float() lets low-precision dtypes (e.g. bfloat16) convert to numpy.
    im1 = axes[0].imshow(attention_map.float().cpu().numpy(), cmap='hot', aspect='auto')
    axes[0].set_title('Short-Term Memory Attention')
    axes[0].set_xlabel('Memory Slots')
    axes[0].set_ylabel('Sequence Position')
    fig.colorbar(im1, ax=axes[0])

    # Placeholder panel: no long-term retrieval weights are passed in here.
    axes[1].text(0.5, 0.5, 'Long-Term Memory\nRetrieval Weights',
                 ha='center', va='center', transform=axes[1].transAxes)
    axes[1].set_title('Long-Term Memory Retrieval')

    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches='tight')

    return fig
|
|
|
|
def plot_training_history(history: Dict, save_path: Optional[str] = None) -> plt.Figure:
    """
    Draw loss curves and, when available, layer-0 memory utilization.

    Args:
        history: Mapping with optional keys ``'train_loss'`` and
            ``'eval_loss'`` (sequences of loss values) and ``'memory_stats'``
            (a list of per-step dicts from which ``'layer_0_lt_active_count'``
            is plotted, defaulting to 0 when an entry lacks it).
        save_path: If given, the figure is also written to this path.

    Returns:
        The Matplotlib figure containing both panels.
    """
    fig, (loss_ax, memory_ax) = plt.subplots(1, 2, figsize=(12, 4))

    # Left panel: training loss, plus evaluation loss when it was recorded.
    loss_ax.plot(history.get('train_loss', []), label='Train Loss', alpha=0.8)
    eval_curve = history.get('eval_loss')
    if eval_curve:
        loss_ax.plot(eval_curve, label='Eval Loss', alpha=0.8)
    loss_ax.set_xlabel('Epoch')
    loss_ax.set_ylabel('Loss')
    loss_ax.set_title('Training Loss')
    loss_ax.legend()
    loss_ax.grid(True, alpha=0.3)

    # Right panel: drawn only for a non-empty list whose first entry is a dict
    # (which also guarantees at least one plotted point).
    stats = history.get('memory_stats')
    if stats and isinstance(stats, list) and isinstance(stats[0], dict):
        active_counts = [
            entry.get('layer_0_lt_active_count', 0)
            for entry in stats
            if isinstance(entry, dict)
        ]
        memory_ax.plot(active_counts, label='Active Memories (Layer 0)', alpha=0.8)
        memory_ax.set_xlabel('Step')
        memory_ax.set_ylabel('Active Memory Count')
        memory_ax.set_title('Memory Utilization')
        memory_ax.legend()
        memory_ax.grid(True, alpha=0.3)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')

    return fig
|
|
|
|
def compute_memory_efficiency(model, baseline_model=None) -> Dict:
    """
    Compute memory efficiency metrics.

    A parameter counts as a "memory parameter" when its qualified name (from
    ``named_parameters``) contains the substring ``'memory'``.

    Args:
        model: LMCODE model (must expose ``parameters``,
            ``named_parameters`` and ``layers`` whose entries have
            ``long_term_memory.num_slots``).
        baseline_model: Optional baseline model for comparison.

    Returns:
        Dictionary with ``total_parameters``, ``memory_parameters``,
        ``memory_parameter_ratio``, ``total_memory_slots`` and
        ``parameters_per_memory_slot``; ratios fall back to 0 instead of
        dividing by zero. ``parameter_savings`` is added when a baseline with
        at least one parameter is given.
    """
    total_params = sum(p.numel() for p in model.parameters())
    memory_params = sum(
        p.numel() for name, p in model.named_parameters() if 'memory' in name
    )
    total_memory_slots = sum(
        layer.long_term_memory.num_slots for layer in model.layers
    )

    metrics = {
        'total_parameters': total_params,
        'memory_parameters': memory_params,
        # A parameter-free model would otherwise raise ZeroDivisionError.
        'memory_parameter_ratio': memory_params / total_params if total_params > 0 else 0.0,
        'total_memory_slots': total_memory_slots,
        'parameters_per_memory_slot': (
            memory_params / total_memory_slots if total_memory_slots > 0 else 0
        ),
    }

    # Compare against None explicitly: container modules such as
    # nn.Sequential / nn.ModuleList define __len__, so truthiness is unreliable.
    if baseline_model is not None:
        baseline_params = sum(p.numel() for p in baseline_model.parameters())
        # Savings are undefined for a parameter-free baseline; omit the key.
        if baseline_params > 0:
            metrics['parameter_savings'] = 1 - (total_params / baseline_params)

    return metrics
|
|
|
|
def generate_memory_report(model, dataset, output_path: str = 'memory_report.json',
                           num_samples: int = 10):
    """
    Generate a comprehensive memory report and write it as JSON.

    Args:
        model: LMCODE model; ``model.config`` may provide ``to_dict()``, be a
            plain dict, or be any object with a ``__dict__``.
        dataset: Evaluation dataset — a list/tuple of dicts (the
            ``'input_ids'`` entry is used) or of tensors. Other dataset types
            yield an empty capacity analysis.
        output_path: Path to save report
        num_samples: Maximum number of leading samples fed to the capacity
            analysis (negative values are treated as 0).

    Returns:
        The report dictionary that was written to ``output_path``.
    """
    config = model.config
    if hasattr(config, 'to_dict'):
        config_dict = config.to_dict()
    elif isinstance(config, dict):
        config_dict = dict(config)
    else:
        # Copy so later edits to the report don't mutate the live config.
        config_dict = dict(vars(config))

    report = {
        'model_config': config_dict,
        'memory_analysis': {},
        'efficiency_metrics': {}
    }

    test_sequences = []
    if isinstance(dataset, (list, tuple)):
        samples = list(dataset[:max(num_samples, 0)])
        if samples:
            if isinstance(samples[0], dict):
                test_sequences = [d['input_ids'] for d in samples]
            else:
                test_sequences = samples

    report['memory_analysis'] = analyze_memory_capacity(model, test_sequences)
    report['efficiency_metrics'] = compute_memory_efficiency(model)

    # default=str stringifies any value json can't encode natively.
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(report, f, indent=2, default=str)

    print(f"Memory report saved to {output_path}")
    return report
|
|
|
|
def visualize_memory_flow(model, input_sequence: torch.Tensor,
                          save_path: Optional[str] = None) -> plt.Figure:
    """
    Visualize memory flow through the network.

    Runs one forward pass (batch of one, ``use_long_term_memory=True``)
    without gradients. It then plots whichever of these the model returned:
    per-layer mean hidden states, layer-0 self-attention, layer-0 long-term
    retrieval weights, and layer-0 short-term read weights. Panels for
    missing outputs stay empty.

    Args:
        model: LMCODE model
        input_sequence: Unbatched input sequence; a batch axis is prepended.
        save_path: Optional path to save figure

    Returns:
        Matplotlib figure
    """
    # Evaluate in eval mode, but restore the caller's mode afterwards so a
    # model that is mid-training is not left in eval mode.
    was_training = getattr(model, 'training', None)
    model.eval()
    try:
        with torch.no_grad():
            outputs = model(input_sequence.unsqueeze(0), use_long_term_memory=True)
    finally:
        if was_training is not None:
            model.train(was_training)

    fig, axes = plt.subplots(2, 2, figsize=(12, 10))

    hidden_states = outputs.get('hidden_states')
    if hidden_states:
        # Average over every axis except the last (feature) one so each layer
        # contributes exactly one row, with or without a batch axis.
        layer_means = torch.stack([
            hs.reshape(-1, hs.shape[-1]).float().mean(dim=0) for hs in hidden_states
        ])
        im1 = axes[0, 0].imshow(layer_means.cpu().numpy(), aspect='auto', cmap='viridis')
        axes[0, 0].set_title('Hidden State Evolution Across Layers')
        axes[0, 0].set_xlabel('Hidden Dimension')
        axes[0, 0].set_ylabel('Layer')
        fig.colorbar(im1, ax=axes[0, 0])

    attention_weights = outputs.get('attention_weights')
    if attention_weights:
        attn = attention_weights[0]
        if attn.dim() == 4:
            attn = attn.mean(dim=1)
        im2 = axes[0, 1].imshow(attn[0].float().cpu().numpy(), aspect='auto', cmap='plasma')
        axes[0, 1].set_title('Self-Attention Weights (Layer 0)')
        axes[0, 1].set_xlabel('Key Position')
        axes[0, 1].set_ylabel('Query Position')
        fig.colorbar(im2, ax=axes[0, 1])

    long_term_outputs = outputs.get('long_term_outputs')
    if long_term_outputs and long_term_outputs[0]:
        lt_output = long_term_outputs[0]
        if 'retrieval_weights' in lt_output:
            weights = lt_output['retrieval_weights']
            im3 = axes[1, 0].imshow(weights[0].float().cpu().numpy(), aspect='auto', cmap='coolwarm')
            axes[1, 0].set_title('Long-Term Memory Retrieval Weights')
            axes[1, 0].set_xlabel('Memory Slot')
            axes[1, 0].set_ylabel('Sequence Position')
            fig.colorbar(im3, ax=axes[1, 0])

    short_term_outputs = outputs.get('short_term_outputs')
    # Check the entry itself too: a None entry would crash the `in` test.
    if short_term_outputs and short_term_outputs[0]:
        st_output = short_term_outputs[0]
        if 'read_weights' in st_output:
            weights = st_output['read_weights']
            im4 = axes[1, 1].imshow(weights[0].float().cpu().numpy(), aspect='auto', cmap='YlOrRd')
            axes[1, 1].set_title('Short-Term Memory Read Weights')
            axes[1, 1].set_xlabel('Memory Slot')
            axes[1, 1].set_ylabel('Sequence Position')
            fig.colorbar(im4, ax=axes[1, 1])

    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches='tight')

    return fig
|
|
|
|
class MemoryMonitor:
    """
    Monitor memory usage and performance during training.

    Recorded values live in ``self.history`` as
    ``{metric_name: [(step, value), ...]}``.
    """

    def __init__(self, model):
        # Kept for callers; the methods below do not read it.
        self.model = model
        self.history = defaultdict(list)

    def record_step(self, step: int, outputs: Dict):
        """
        Record memory statistics for a training step.

        Records the loss (when present and not None) and, for every layer,
        the mean of its ``retrieval_weights`` (long-term) and ``read_weights``
        (short-term) entries. Output lists may be missing or None.

        Args:
            step: Current training step
            outputs: Model outputs
        """
        loss = outputs.get('loss')
        if loss is not None:
            # float() accepts single-element tensors as well as plain numbers.
            self.history['loss'].append((step, float(loss)))

        self._record_layer_means(step, outputs.get('long_term_outputs'),
                                 'retrieval_weights', 'retrieval_weight')
        self._record_layer_means(step, outputs.get('short_term_outputs'),
                                 'read_weights', 'read_weight')

    def _record_layer_means(self, step: int, layer_outputs, weight_key: str, suffix: str):
        """Append the mean of ``weight_key`` for each layer output that has it."""
        for i, layer_output in enumerate(layer_outputs or []):
            if layer_output and weight_key in layer_output:
                avg_weight = layer_output[weight_key].mean().item()
                self.history[f'layer_{i}_{suffix}'].append((step, avg_weight))

    def get_statistics(self) -> Dict:
        """
        Get aggregated memory statistics.

        Returns:
            Dictionary mapping each non-empty metric to its mean, std, min,
            max and most recent value.
        """
        stats = {}

        for key, values in self.history.items():
            if values:
                vals = [v for _, v in values]
                stats[key] = {
                    'mean': np.mean(vals),
                    'std': np.std(vals),
                    'min': np.min(vals),
                    'max': np.max(vals),
                    'latest': vals[-1]
                }

        return stats

    def plot_history(self, save_path: Optional[str] = None) -> plt.Figure:
        """
        Plot monitoring history, one subplot per metric in a two-column grid.

        Args:
            save_path: Optional path to save figure (also honored when no
                data has been recorded).

        Returns:
            Matplotlib figure
        """
        n_metrics = len(self.history)
        if n_metrics == 0:
            fig, ax = plt.subplots(1, 1, figsize=(8, 6))
            ax.text(0.5, 0.5, 'No data recorded',
                    ha='center', va='center', transform=ax.transAxes)
        else:
            n_cols = min(2, n_metrics)
            n_rows = (n_metrics + n_cols - 1) // n_cols

            # squeeze=False always yields a 2-D array, so one flatten handles
            # every grid shape without special cases.
            fig, axes = plt.subplots(n_rows, n_cols,
                                     figsize=(6 * n_cols, 4 * n_rows), squeeze=False)
            flat_axes = axes.ravel()

            for ax, (key, values) in zip(flat_axes, self.history.items()):
                ax.set_title(key.replace('_', ' ').title())
                ax.set_xlabel('Step')
                ax.set_ylabel('Value')
                ax.grid(True, alpha=0.3)
                # A defaultdict key can exist with no samples; zip(*[]) would
                # fail to unpack, so leave that panel empty.
                if values:
                    steps, vals = zip(*values)
                    ax.plot(steps, vals, alpha=0.7)

            # Hide the unused cell when the metric count is odd.
            for ax in flat_axes[n_metrics:]:
                ax.axis('off')

            fig.tight_layout()

        if save_path:
            fig.savefig(save_path, dpi=150, bbox_inches='tight')

        return fig