Instructions for using nvidia/E-RADIO with libraries, inference providers, notebooks, and local apps. Follow the links below to get started.

- Libraries
  - Transformers
- Notebooks
  - Google Colab
  - Kaggle

How to use nvidia/E-RADIO with Transformers:

```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("feature-extraction", model="nvidia/E-RADIO", trust_remote_code=True)
```

```python
# Load the model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("nvidia/E-RADIO", trust_remote_code=True, dtype="auto")
```
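Once loaded, the model can be applied to a batch of images. A minimal inference sketch follows; the 224×224 input resolution and the returned `(summary, spatial_features)` pair follow the RADIO family's usual interface, so treat both as assumptions for E-RADIO rather than confirmed API:

```python
import torch
from transformers import AutoModel

model = AutoModel.from_pretrained("nvidia/E-RADIO", trust_remote_code=True)
model.eval()

# Dummy RGB batch; real inputs need the model's expected preprocessing.
x = torch.rand(1, 3, 224, 224)

with torch.no_grad():
    summary, spatial_features = model(x)  # assumed RADIO-style output pair

print(summary.shape, spatial_features.shape)
```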
The repository also includes its spectral reparametrization module (named by the `spectral_reparam` logger below), which wraps attention and MLP weight matrices in a spectral-norm parametrization with a learnable scale:

```python
from logging import getLogger
import math
import os

import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.utils import parametrize
from torch.nn.utils.parametrizations import _SpectralNorm

from timm.models.vision_transformer import Attention, Mlp

_EPS = 1e-5
```
```python
class _SNReweight(_SpectralNorm):
    def __init__(self, weight: torch.Tensor, *args, init_norm_to_current: bool = False, alpha: float = 0.05, version: int = 2, **kwargs):
        super().__init__(weight, *args, **kwargs)

        self.alpha = alpha
        self.version = version
        self.register_buffer('_sn_version', torch.tensor(version))

        if init_norm_to_current:
            # This will set the numerator to match the denominator, which should preserve the original values
            init_scale = self._get_sigma(weight).item()
        else:
            init_scale = 1.0

        if version == 1:
            init_value = init_scale
        elif version == 2:
            t = init_scale - alpha
            if t < _EPS:
                getLogger("spectral_reparam").warning(f'The initialized spectral norm {init_scale} is too small to be represented. Setting to {_EPS} instead.')
                t = _EPS

            init_value = math.log(math.exp(t) - 1)
        else:
            raise ValueError(f'Unsupported version: {version}')

        # Make 2D so that weight decay gets applied
        self.scale = nn.Parameter(torch.tensor([[init_value]], dtype=torch.float32, device=weight.device))

    # Re-implementing this because we need to make division by sigma safe
    def _get_sigma(self, weight: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        if weight.ndim == 1:
            # Faster and more exact path, no need to approximate anything
            sigma = weight.norm()
        else:
            weight_mat = self._reshape_weight_to_matrix(weight)
            if self.training:
                self._power_method(weight_mat, self.n_power_iterations)
            # See above on why we need to clone
            u = self._u.clone(memory_format=torch.contiguous_format)
            v = self._v.clone(memory_format=torch.contiguous_format)
            # The proper way of computing this should be through F.bilinear, but
            # it seems to have some efficiency issues:
            # https://github.com/pytorch/pytorch/issues/58093
            sigma = torch.dot(u, torch.mv(weight_mat, v))

        return sigma + self.eps

    def forward(self, weight: torch.Tensor, *args, **kwargs):
        dtype = weight.dtype
        sigma = self._get_sigma(weight, *args, **kwargs)

        if self.version == 1:
            scale = self.scale
        elif self.version == 2:
            scale = F.softplus(self.scale) + self.alpha
        else:
            raise ValueError(f'Unsupported version: {self.version}')

        scale = scale.float() / sigma.float()

        y = weight * scale

        if dtype in (torch.float16, torch.bfloat16):
            y = y.to(dtype)
        return y

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
        version_key = f'{prefix}_sn_version'
        if version_key not in state_dict:
            self.version = 1
            state_dict[version_key] = torch.tensor(1)
        return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
```
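In version 2, `forward` rescales the weight by `(softplus(scale) + alpha) / sigma`, where `sigma` is the estimated spectral norm. A minimal standalone sketch of the behavior, reusing `_SNReweight` from the listing above; the layer and sizes are invented for illustration:

```python
import torch
from torch import nn
from torch.nn.utils import parametrize

lin = nn.Linear(16, 16)
w0 = lin.weight.detach().clone()

parametrize.register_parametrization(
    lin, 'weight',
    _SNReweight(lin.weight, 1, dim=0, eps=1e-6, init_norm_to_current=True),
)

# Freeze the power iteration so repeated weight accesses are deterministic.
lin.eval()

# With init_norm_to_current=True the numerator is initialized to match the
# current spectral norm, so the parametrized weight starts out nearly unchanged.
print(torch.allclose(lin.weight, w0, atol=1e-4))
```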
```python
class _AttnSNReweight(nn.Module):
    def __init__(self, weight: torch.Tensor, *args, init_norm_to_current: bool = False, renorm_values: bool = False, **kwargs):
        super().__init__()

        # The fused qkv weight is three stacked [dim, dim] blocks: q, k, v.
        parts = weight.split(weight.shape[0] // 3, dim=0)

        # By default only q and k are renormalized; v passes through unchanged.
        ct = 2 if not renorm_values else 3

        self.parts = nn.ModuleList([
            _SNReweight(p, *args, init_norm_to_current=init_norm_to_current, **kwargs) if i < ct else nn.Identity()
            for i, p in enumerate(parts)
        ])

    def forward(self, weight: torch.Tensor, *args, **kwargs):
        parts = weight.split(weight.shape[0] // 3, dim=0)

        parts = [
            fn(p)
            for fn, p in zip(self.parts, parts)
        ]

        return torch.cat(parts, dim=0)
```
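A hedged sketch of `_AttnSNReweight` on a fused qkv projection; the embedding size is invented for illustration:

```python
from torch import nn
from torch.nn.utils import parametrize

# Hypothetical fused qkv projection: embed_dim=8, so the weight is [24, 8].
qkv = nn.Linear(8, 24)

parametrize.register_parametrization(
    qkv, 'weight',
    _AttnSNReweight(qkv.weight, 1, dim=0, eps=1e-6),
)

# The first two thirds (q and k) go through _SNReweight; the last third (v)
# passes through nn.Identity.
print([type(m).__name__ for m in qkv.parametrizations.weight[0].parts])
```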
```python
def enable_spectral_reparam(model: nn.Module,
                            n_power_iterations: int = 1,
                            eps: float = 1e-6,
                            init_norm_to_current: bool = False,
                            renorm_values: bool = True,
                            renorm_mlp: bool = True):
    for mod in model.modules():
        if isinstance(mod, Attention):
            parametrize.register_parametrization(
                mod.qkv,
                'weight',
                _AttnSNReweight(mod.qkv.weight, n_power_iterations, dim=0, eps=eps, init_norm_to_current=init_norm_to_current, renorm_values=renorm_values),
            )
        elif isinstance(mod, Mlp) and renorm_mlp:
            parametrize.register_parametrization(
                mod.fc1,
                'weight',
                _SNReweight(mod.fc1.weight, n_power_iterations, dim=0, eps=eps, init_norm_to_current=init_norm_to_current),
            )
            parametrize.register_parametrization(
                mod.fc2,
                'weight',
                _SNReweight(mod.fc2.weight, n_power_iterations, dim=0, eps=eps, init_norm_to_current=init_norm_to_current),
            )
```
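A hedged usage sketch for `enable_spectral_reparam`; the timm model name is an arbitrary choice, and anything built from timm's `Attention`/`Mlp` blocks should match the `isinstance` checks above:

```python
import timm

vit = timm.create_model('vit_small_patch16_224', pretrained=False)

# init_norm_to_current=True keeps the effective weights unchanged at the
# moment the parametrization is attached (useful for pretrained models).
enable_spectral_reparam(vit, init_norm_to_current=True)

print(vit.blocks[0].attn.qkv.parametrizations.weight)
```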
```python
def configure_spectral_reparam_from_args(model: nn.Module, args):
    spectral_reparam = getattr(args, 'spectral_reparam', False)
    if isinstance(spectral_reparam, bool) and spectral_reparam:
        enable_spectral_reparam(model, init_norm_to_current=args.pretrained)
    elif isinstance(spectral_reparam, dict):
        enable_spectral_reparam(
            model,
            n_power_iterations=spectral_reparam.get('n_power_iterations', 1),
            eps=spectral_reparam.get('eps', 1e-12),
            init_norm_to_current=args.pretrained,
        )


def disable_spectral_reparam(model: nn.Module):
    for mod in model.modules():
        if isinstance(mod, Attention):
            parametrize.remove_parametrizations(mod.qkv, 'weight')
        elif isinstance(mod, Mlp):
            parametrize.remove_parametrizations(mod.fc1, 'weight')
            parametrize.remove_parametrizations(mod.fc2, 'weight')
```
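Note that `remove_parametrizations` defaults to `leave_parametrized=True`, so disabling bakes the current reweighted tensor back into a plain `weight` parameter. A hedged round-trip check, continuing the hypothetical `vit` from the previous sketch:

```python
vit.eval()  # freeze the power iteration so weight accesses are deterministic
w_before = vit.blocks[0].attn.qkv.weight.detach().clone()

disable_spectral_reparam(vit)

# The qkv weight is an ordinary parameter again, numerically equal to the
# reweighted tensor the parametrization was producing.
assert not parametrize.is_parametrized(vit.blocks[0].attn.qkv)
print(torch.allclose(vit.blocks[0].attn.qkv.weight, w_before))
```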
```python
if __name__ == '__main__':
    import argparse

    from . import radio_model as create_model

    parser = argparse.ArgumentParser(description='Remove parametrization from state dict')
    parser.add_argument('--checkpoint', type=str, required=True, help='The checkpoint to load')
    parser.add_argument('--output', type=str, default='', help='Where to store the checkpoint')
    parser.add_argument('--release', default=False, action='store_true', help='Prune extraneous checkpoint fields')
    parser.add_argument('--strict', default=False, action='store_true', help='Strictly load the state dict')
    args = parser.parse_args()

    if not args.output:
        chk_dir, chk_name = os.path.split(args.checkpoint)
        args.output = os.path.join(chk_dir, f'clean_{chk_name}')
        print(f'Set output to "{args.output}"')

    chk = torch.load(args.checkpoint, map_location='cpu', mmap=True)

    model = create_model.create_model_from_args(chk['args'])

    key = 'base_model.'
    mod_state = dict()
    extra_state = dict()
    for k, v in chk['state_dict'].items():
        if k.startswith(key):
            mod_state[k[len(key):]] = v
        else:
            extra_state[k] = v

    chk_load_info = model.load_state_dict(mod_state, strict=args.strict)
    if chk_load_info.unexpected_keys or chk_load_info.missing_keys:
        print(chk_load_info)

    if chk['args'].spectral_reparam:
        disable_spectral_reparam(model)

    if hasattr(chk['args'], 'dtype'):
        model.to(dtype=chk['args'].dtype)

    mod_state = model.state_dict()
    final_state = dict()
    final_state.update({f'{key}{k}': v for k, v in mod_state.items()})
    final_state.update(extra_state)

    chk['state_dict'] = final_state
    chk['args'].spectral_reparam = False

    if args.release:
        chk = {
            'arch': chk['arch'],
            'epoch': chk['epoch'],
            'state_dict': chk['state_dict'],
            'args': chk['args'],
        }

    torch.save(chk, args.output)
```
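Because of the relative import, this cleanup script is meant to be run as a module from within its package. After it writes the cleaned checkpoint, no parametrization state should remain; a hedged sanity check with a hypothetical path:

```python
import torch

# Hypothetical path; the script writes 'clean_<name>' next to the input file.
chk = torch.load('checkpoints/clean_checkpoint.pth', map_location='cpu')

# Parametrized modules store tensors under '...parametrizations.weight...'
# keys; none should survive disable_spectral_reparam plus re-export.
assert not any('parametrizations' in k for k in chk['state_dict'])
assert chk['args'].spectral_reparam is False
```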