RNABERT

A small BERT-style RNA language model pretrained on non-coding RNA sequences from Rfam 14.3, using Masked Language Modeling (MLM) and Structural Alignment Learning (SAL). Designed for RNA clustering and structural alignment tasks.

Architecture

Parameter               Value
Layers                  6
Attention heads         12
Embedding dimension     120
FFN intermediate size   40
Vocabulary size         6 (PAD, MASK, A, U, G, C)
Positional encoding     Learned absolute
Architecture            Post-LN BERT encoder
Max sequence length     440
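In a post-LN encoder, each sublayer's output is added to its input and only then layer-normalized. The module below is a rough sketch of one such block at the sizes in the table: 120 hidden, 12 heads of 10 dimensions each, and an FFN width of 40. It uses torch.nn.MultiheadAttention as a stand-in. The GELU activation, the zero dropout, and the class name PostLNEncoderLayer are assumptions for illustration, not code taken from the port.

```python
import torch
import torch.nn as nn


class PostLNEncoderLayer(nn.Module):
    """Hypothetical post-LN BERT block: residual add first, then LayerNorm."""

    def __init__(self, hidden=120, heads=12, ffn=40, dropout=0.0):
        super().__init__()
        # 120 hidden / 12 heads = 10 dimensions per head
        self.attn = nn.MultiheadAttention(hidden, heads, dropout=dropout, batch_first=True)
        self.ln1 = nn.LayerNorm(hidden)
        # assumption: GELU activation, as in standard BERT
        self.ffn = nn.Sequential(nn.Linear(hidden, ffn), nn.GELU(), nn.Linear(ffn, hidden))
        self.ln2 = nn.LayerNorm(hidden)

    def forward(self, x, attention_mask):
        # key_padding_mask is True where a key position must be ignored (padding)
        pad = attention_mask == 0
        attn_out, _ = self.attn(x, x, x, key_padding_mask=pad, need_weights=False)
        x = self.ln1(x + attn_out)       # post-LN: normalize after the residual add
        x = self.ln2(x + self.ffn(x))    # same pattern for the feed-forward sublayer
        return x


torch.manual_seed(0)
layer = PostLNEncoderLayer()
x = torch.randn(2, 7, 120)                               # (batch, seq_len, hidden)
attention_mask = torch.tensor([[1] * 7, [1] * 5 + [0] * 2])
print(layer(x, attention_mask).shape)                    # torch.Size([2, 7, 120])
```

A pre-LN variant would instead normalize before each sublayer. The distinction matters when porting weights, because the LayerNorm parameters sit at different points in the computation.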

Vocabulary:

Token    ID
<pad>    0
<mask>   1
A        2
U        3
G        4
C        5

No CLS or EOS tokens are added. Sequences are tokenized character-by-character; T is silently converted to U.
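The tokenization rules above are simple enough to sketch in plain Python. The helpers below (encode_rna and batch_encode are hypothetical names) follow the vocabulary table. Lower-case handling, right-padding, the 440-token length check, and rejecting characters outside A/C/G/U/T are assumptions, and the shipped tokenizer may treat those cases differently. The sketch also does not parse the literal string <mask>.

```python
PAD_ID, MASK_ID = 0, 1
VOCAB = {"A": 2, "U": 3, "G": 4, "C": 5}
MAX_LEN = 440  # learned absolute positions cover 440 tokens


def encode_rna(seq: str) -> list[int]:
    """Map one sequence to token IDs: character-level, no CLS/EOS, T -> U."""
    seq = seq.upper().replace("T", "U")  # assumption: case-insensitive input
    if len(seq) > MAX_LEN:
        raise ValueError(f"sequence length {len(seq)} exceeds {MAX_LEN}")
    try:
        return [VOCAB[ch] for ch in seq]
    except KeyError as err:
        # assumption: characters outside A/U/G/C (e.g. N) are rejected
        raise ValueError(f"unsupported nucleotide {err.args[0]!r}") from None


def batch_encode(seqs: list[str]) -> tuple[list[list[int]], list[list[int]]]:
    """Right-pad to the longest sequence; return (input_ids, attention_mask)."""
    encoded = [encode_rna(s) for s in seqs]
    width = max(len(ids) for ids in encoded)
    input_ids = [ids + [PAD_ID] * (width - len(ids)) for ids in encoded]
    attention_mask = [[1] * len(ids) + [0] * (width - len(ids)) for ids in encoded]
    return input_ids, attention_mask


ids, mask = batch_encode(["AUGC", "GAT"])
print(ids)   # [[2, 3, 4, 5], [4, 2, 3, 0]]
print(mask)  # [[1, 1, 1, 1], [1, 1, 1, 0]]
```

Because no CLS or EOS tokens are added, position i in input_ids corresponds directly to nucleotide i of the input.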

Pretraining

  • Objective: Masked Language Modeling (MLM) + Structural Alignment Learning (SAL, a pairwise structural alignment contrastive objective)
  • Data: non-coding RNA sequences from Rfam 14.3, up to ~440 nt long
  • Source checkpoint: bert_mul_2.pth (distributed inside the RNABERT_pretrained.pth zip archive hosted on Google Drive)
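SAL is a pairwise objective built around structural alignment, and the embeddings are intended for aligning RNAs. The paper defines the exact SAL scoring and loss. The sketch below only illustrates the general idea of aligning two sequences by dynamic programming (Needleman-Wunsch) over per-base embedding similarities. Cosine similarity as the match score, the linear gap penalty of -1, and the name align_embeddings are all assumptions.

```python
import torch
import torch.nn.functional as F


def align_embeddings(emb_a, emb_b, gap=-1.0):
    """Global alignment of two non-empty per-base embedding matrices.

    emb_a: (len_a, hidden), emb_b: (len_b, hidden). Match score is the cosine
    similarity between base embeddings; gaps cost `gap` each (assumption).
    Returns (best score, list of aligned index pairs (i, j)).
    """
    a = F.normalize(emb_a, dim=-1)
    b = F.normalize(emb_b, dim=-1)
    sim = (a @ b.T).tolist()
    n, m = len(sim), len(sim[0])
    # dp[i][j] = best score for aligning a[:i] with b[:j]
    dp = [[0.0] * (m + 1) for _ in range(n + 1)]
    for i in range(1, n + 1):
        dp[i][0] = i * gap
    for j in range(1, m + 1):
        dp[0][j] = j * gap
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            dp[i][j] = max(dp[i - 1][j - 1] + sim[i - 1][j - 1],  # match bases
                           dp[i - 1][j] + gap,                    # gap in b
                           dp[i][j - 1] + gap)                    # gap in a
    # Trace back from the bottom-right corner to recover matched pairs
    pairs, i, j = [], n, m
    while i > 0 and j > 0:
        if dp[i][j] == dp[i - 1][j - 1] + sim[i - 1][j - 1]:
            pairs.append((i - 1, j - 1))
            i, j = i - 1, j - 1
        elif dp[i][j] == dp[i - 1][j] + gap:
            i -= 1
        else:
            j -= 1
    return dp[n][m], pairs[::-1]


torch.manual_seed(0)
emb_a = torch.randn(8, 120)  # stand-in for per-base embeddings of sequence A
emb_b = emb_a[[0, 1, 2, 4, 5, 6, 7]] + 0.05 * torch.randn(7, 120)  # B: A minus one base
score, pairs = align_embeddings(emb_a, emb_b)
print(score, pairs)
```

With the real model, per-base embeddings for one unpadded sequence would come from out.last_hidden_state[0, :length] in the embedding example under Usage.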

Checkpoint selection

The original repository publishes a single pretrained checkpoint, and this model is a conversion of it.

Parity Verification

Hidden-state representations match the original implementation to within a maximum absolute difference of 3e-6 at all 7 representation levels (embedding output + 6 transformer layers), with and without padding, for both the eager and SDPA attention backends. Verified on GPU with PyTorch 2.7 / transformers 4.57.6.
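A parity check of this kind reduces to a maximum absolute difference over the per-level hidden states. The helper below is a hypothetical sketch of that comparison, not the script used for the verification above. Ignoring padding positions is an assumption about how the padded case was scored.

```python
import torch


def max_abs_diff(ref_states, port_states, attention_mask):
    """Largest |ref - port| over all representation levels, ignoring padding.

    ref_states / port_states: equal-length sequences of (batch, seq_len, hidden)
    tensors, one per level (embedding output + each transformer layer).
    attention_mask: (batch, seq_len), 1 for real tokens and 0 for padding.
    """
    if len(ref_states) != len(port_states):
        raise ValueError("different number of representation levels")
    keep = attention_mask.bool().unsqueeze(-1)  # (batch, seq_len, 1)
    worst = 0.0
    for ref, port in zip(ref_states, port_states):
        # zero out padding positions so their values cannot affect the result
        diff = (ref - port).abs().masked_fill(~keep, 0.0)
        worst = max(worst, diff.max().item())
    return worst


torch.manual_seed(0)
ref = [torch.randn(2, 5, 120) for _ in range(7)]   # 7 levels, as in the check above
port = [h + 1e-6 for h in ref]                     # tiny numerical drift
mask = torch.tensor([[1, 1, 1, 1, 1], [1, 1, 1, 0, 0]])
port[-1][1, 4] += 10.0                             # garbage at a padding position
print(max_abs_diff(ref, port, mask) < 3e-6)        # True
```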

Related Models

See the full RNABERT collection.

Model              Notes
Taykhoom/RNABERT   This model

Usage

Embedding generation

import torch
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("Taykhoom/RNABERT", trust_remote_code=True)
model = AutoModel.from_pretrained("Taykhoom/RNABERT", trust_remote_code=True)
model.eval()

sequences = ["AUGCAUGCAUGC", "GCUAGCUAGCUA"]
enc = tokenizer(sequences, return_tensors="pt", padding=True)

with torch.no_grad():
    out = model(**enc)

# Token-level embeddings
token_emb = out.last_hidden_state   # (batch, seq_len, 120)

# Mean-pool over non-padding positions
mask = enc["attention_mask"].unsqueeze(-1).float()
mean_emb = (token_emb * mask).sum(1) / mask.sum(1)  # (batch, 120)

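# Intermediate layers: hidden_states has 7 entries,
# index 0 = embedding output, index i = output of transformer layer i
with torch.no_grad():
    out_all = model(**enc, output_hidden_states=True)
layer3_emb = out_all.hidden_states[3]   # (batch, seq_len, 120)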
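Because RNABERT targets RNA clustering, a common next step is to compare the mean-pooled embeddings. Here is a minimal sketch, assuming cosine similarity as the metric. cosine_similarity_matrix and nearest_neighbors are hypothetical helpers, and the random tensor stands in for mean_emb from the block above.

```python
import torch
import torch.nn.functional as F


def cosine_similarity_matrix(emb):
    """Pairwise cosine similarity between the rows of a (n, hidden) tensor."""
    unit = F.normalize(emb, dim=-1)
    return unit @ unit.T


def nearest_neighbors(emb):
    """Index of the most similar *other* sequence for each row."""
    sim = cosine_similarity_matrix(emb)
    sim.fill_diagonal_(float("-inf"))  # exclude self-matches
    return sim.argmax(dim=1)


torch.manual_seed(0)
mean_emb = torch.randn(4, 120)  # stand-in for the (batch, 120) mean_emb above
print(cosine_similarity_matrix(mean_emb).shape)  # torch.Size([4, 4])
print(nearest_neighbors(mean_emb))
```

The similarity matrix can also be converted to distances (1 - similarity) and passed to any clustering routine that accepts a precomputed distance matrix.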

MLM logits

import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

tokenizer = AutoTokenizer.from_pretrained("Taykhoom/RNABERT", trust_remote_code=True)
model = AutoModelForMaskedLM.from_pretrained("Taykhoom/RNABERT", trust_remote_code=True)
model.eval()

enc = tokenizer(["AUG<mask>AUG"], return_tensors="pt")
with torch.no_grad():
    logits = model(**enc).logits   # (1, seq_len, 6)
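To read a prediction out of the logits, select the <mask> positions and take the argmax over the nucleotide IDs from the vocabulary table. The sketch below uses stand-in logits so that it runs without downloading the model. predict_masked is a hypothetical helper, and restricting predictions to A/U/G/C (ignoring the <pad> and <mask> logits) is an assumption.

```python
import torch

ID_TO_TOKEN = {0: "<pad>", 1: "<mask>", 2: "A", 3: "U", 4: "G", 5: "C"}
MASK_ID = 1
NUCLEOTIDE_IDS = torch.tensor([2, 3, 4, 5])  # A, U, G, C


def predict_masked(input_ids, logits):
    """Most likely nucleotide at each <mask> position.

    input_ids: (batch, seq_len); logits: (batch, seq_len, 6).
    Returns a list of (batch index, position, predicted token).
    """
    batch_idx, pos_idx = (input_ids == MASK_ID).nonzero(as_tuple=True)
    # keep only the A/U/G/C columns (assumption: special tokens never predicted)
    nuc_logits = logits[batch_idx, pos_idx][:, NUCLEOTIDE_IDS]  # (n_masked, 4)
    best_ids = NUCLEOTIDE_IDS[nuc_logits.argmax(dim=-1)]
    return [
        (b, p, ID_TO_TOKEN[t])
        for b, p, t in zip(batch_idx.tolist(), pos_idx.tolist(), best_ids.tolist())
    ]


# Stand-in for "AUG<mask>AUG" encoded as [A, U, G, <mask>, A, U, G]
input_ids = torch.tensor([[2, 3, 4, 1, 2, 3, 4]])
logits = torch.zeros(1, 7, 6)
logits[0, 3, 1] = 10.0  # a high <mask> logit is ignored
logits[0, 3, 5] = 5.0   # C is the best nucleotide at position 3
print(predict_masked(input_ids, logits))  # [(0, 3, 'C')]
```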

Fine-tuning

The model has no CLS token, so use mean pooling over non-padding positions for sequence-level tasks.

import torch.nn as nn
from transformers import AutoModel

model = AutoModel.from_pretrained("Taykhoom/RNABERT", trust_remote_code=True)

class RNAClassifier(nn.Module):
    def __init__(self, base, num_labels):
        super().__init__()
        self.base = base
        self.head = nn.Linear(120, num_labels)

    def forward(self, input_ids, attention_mask):
        out = self.base(input_ids, attention_mask=attention_mask)
        mask = attention_mask.unsqueeze(-1).float()
        pooled = (out.last_hidden_state * mask).sum(1) / mask.sum(1)
        return self.head(pooled)
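The classifier above can be trained with an ordinary PyTorch loop. The sketch below runs one optimization step with a hypothetical ToyBase module standing in for the pretrained encoder, so it runs offline. For real use, pass AutoModel.from_pretrained("Taykhoom/RNABERT", trust_remote_code=True) as the base instead. The sketch also clamps the mask count to at least 1 so an all-padding row cannot divide by zero. That guard, the AdamW optimizer, and the learning rate are assumptions, not recommendations from the original authors.

```python
from types import SimpleNamespace

import torch
import torch.nn as nn


class ToyBase(nn.Module):
    """Hypothetical stand-in for the encoder: returns .last_hidden_state (batch, seq_len, 120)."""

    def __init__(self, vocab_size=6, hidden=120):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden, padding_idx=0)

    def forward(self, input_ids, attention_mask=None):
        return SimpleNamespace(last_hidden_state=self.embed(input_ids))


class RNAClassifier(nn.Module):
    def __init__(self, base, num_labels, hidden=120):
        super().__init__()
        self.base = base
        self.head = nn.Linear(hidden, num_labels)

    def forward(self, input_ids, attention_mask):
        out = self.base(input_ids, attention_mask=attention_mask)
        mask = attention_mask.unsqueeze(-1).float()
        # clamp guards against division by zero for an all-padding row
        pooled = (out.last_hidden_state * mask).sum(1) / mask.sum(1).clamp(min=1.0)
        return self.head(pooled)


torch.manual_seed(0)
model = RNAClassifier(ToyBase(), num_labels=2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)  # assumed learning rate
loss_fn = nn.CrossEntropyLoss()

input_ids = torch.tensor([[2, 3, 4, 5], [4, 2, 3, 0]])      # AUGC, GAU + <pad>
attention_mask = torch.tensor([[1, 1, 1, 1], [1, 1, 1, 0]])
labels = torch.tensor([0, 1])

model.train()
logits = model(input_ids, attention_mask)
loss = loss_fn(logits, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(logits.shape, loss.item())
```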

Implementation Notes

This port uses a standalone RNABertModel (custom PreTrainedModel subclass, model_type: "rnabert"). trust_remote_code=True is required for both the tokenizer and the model.

The original implementation computes attention eagerly (explicit matrix multiply and softmax) inside a post-LN BERT encoder. This HF port additionally supports attn_implementation="sdpa" and attn_implementation="flash_attention_2", neither of which was part of the original codebase.

# Faster inference with SDPA (default on modern PyTorch)
model = AutoModel.from_pretrained("Taykhoom/RNABERT", trust_remote_code=True,
                                   attn_implementation="sdpa")

# Flash Attention 2 (requires flash-attn installed)
model = AutoModel.from_pretrained("Taykhoom/RNABERT", trust_remote_code=True,
                                   attn_implementation="flash_attention_2")
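The eager and SDPA backends compute the same quantity, softmax(QK^T / sqrt(d))V with padded keys excluded, which is why they can agree to within floating-point rounding. The sketch below compares a hand-written eager version against torch.nn.functional.scaled_dot_product_attention at RNABERT's head size (120 / 12 = 10). It illustrates the idea on random tensors and is not the port's attention code.

```python
import math

import torch
import torch.nn.functional as F


def eager_attention(q, k, v, key_padding_mask):
    """softmax(Q K^T / sqrt(d)) V with padded keys masked out.

    q, k, v: (batch, heads, seq_len, head_dim); key_padding_mask: (batch, seq_len),
    1 for real tokens and 0 for padding.
    """
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    keep = key_padding_mask.bool()[:, None, None, :]  # broadcast over heads and queries
    scores = scores.masked_fill(~keep, float("-inf"))
    return scores.softmax(dim=-1) @ v


def sdpa_attention(q, k, v, key_padding_mask):
    """Same computation via PyTorch's fused kernel; boolean mask True = attend."""
    keep = key_padding_mask.bool()[:, None, None, :]
    return F.scaled_dot_product_attention(q, k, v, attn_mask=keep)


torch.manual_seed(0)
batch, heads, seq_len, head_dim = 2, 12, 6, 10  # 12 heads x 10 dims = 120 hidden
q, k, v = (torch.randn(batch, heads, seq_len, head_dim) for _ in range(3))
mask = torch.tensor([[1] * 6, [1] * 4 + [0] * 2])
diff = (eager_attention(q, k, v, mask) - sdpa_attention(q, k, v, mask)).abs().max()
print(diff.item() < 1e-5)  # True
```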

Citation

@article{akiyama2022_rnabert,
  title   = {Informative {RNA} base embedding for {RNA} structural alignment and clustering by deep representation learning},
  author  = {Akiyama, Manato and Sakakibara, Yasubumi},
  journal = {NAR Genomics and Bioinformatics},
  volume  = {4},
  number  = {1},
  pages   = {lqac012},
  year    = {2022},
  doi     = {10.1093/nargab/lqac012}
}

Credits

Original model and code by Akiyama and Sakakibara. Source: GitHub. The HF conversion code was authored primarily by Claude Code and reviewed manually by Taykhoom Dalal.

License

No license is specified in the original repository. Please contact the authors before redistributing or using in commercial settings.

Model size: 478k params (Safetensors, F32)