Feature Extraction
Transformers
Joblib
Safetensors
MOJO
bulk RNA-seq
DNA methylation
biology
transcriptomics
epigenomics
multimodal
custom_code
Instructions to use InstaDeepAI/MOJO with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use InstaDeepAI/MOJO with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("feature-extraction", model="InstaDeepAI/MOJO", trust_remote_code=True)

# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("InstaDeepAI/MOJO", trust_remote_code=True, dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| import logging | |
| import math | |
| from dataclasses import dataclass | |
| from typing import Any, Optional, Tuple | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F # noqa: N812 | |
| from transformers import PretrainedConfig, PreTrainedModel | |
@dataclass
class RotaryEmbeddingConfig:
    """
    Parameters used to initialize the RotaryEmbedding layer.

    The rescaling factor allows the rotary embeddings to be adapted to sequences
    longer than those seen during training. One such strategy is presented in the
    YaRN paper: https://arxiv.org/pdf/2309.00071.pdf.

    Args:
        rescaling_factor: When not ``None``, ``RotaryEmbedding`` replaces its
            frequency base ``upper_freq`` by
            ``upper_freq * rescaling_factor ** (dim / (dim - 2))``.
            ``None`` disables rescaling.
    """

    # Declared as a dataclass field so that the class can be constructed as
    # ``RotaryEmbeddingConfig(rescaling_factor=...)``.
    rescaling_factor: Optional[float]
class RotaryEmbedding(torch.nn.Module):
    """
    Rotary position embeddings in the style of
    [RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer).

    Query and key heads are rotated by position-dependent angles, so that their
    dot products depend on relative positions. The module has no learnable
    parameters. The cos/sin tables are rebuilt on every ``forward`` call, and the
    most recent ones are kept in ``_cos_cached`` / ``_sin_cached``.
    """

    def __init__(self, dim: int, rotary_embedding_config: RotaryEmbeddingConfig):
        super().__init__()
        self.rescaling_factor = rotary_embedding_config.rescaling_factor
        # Base of the geometric frequency progression (before optional rescaling).
        self.upper_freq = 10000
        self.dim = dim
        # Filled in by ``_compute_cos_sin_tables``.
        self._seq_len_cached = None
        self._cos_cached = None
        self._sin_cached = None

    def _apply_rotary_pos_emb(
        self,
        heads: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
    ) -> torch.Tensor:
        """Rotate the (first half, second half) pairs of the last axis by (cos, sin)."""
        half = heads.shape[-1] // 2
        lo = heads[..., :half]
        hi = heads[..., half:]
        rotated_lo = lo * cos - hi * sin
        rotated_hi = hi * cos + lo * sin
        return torch.cat((rotated_lo, rotated_hi), dim=-1)

    def _compute_cos_sin_tables(
        self, x: torch.Tensor, inv_freq: torch.Tensor, seq_dimension: int = 2
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Build cos/sin tables for every position along ``seq_dimension`` of ``x``.

        The tables are recomputed unconditionally (nothing is reused between
        calls), stored on the module, and returned. Both have shape
        ``(1, seq_len, 1, len(inv_freq))``.
        """
        n_positions = x.shape[seq_dimension]
        self._seq_len_cached = n_positions
        positions = torch.arange(n_positions, device=x.device).type_as(inv_freq)
        angles = torch.outer(positions, inv_freq)
        self._cos_cached = torch.cos(angles)[None, :, None, :]
        self._sin_cached = torch.sin(angles)[None, :, None, :]
        return self._cos_cached, self._sin_cached

    def forward(
        self, q: torch.Tensor, k: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Rotate ``q`` and ``k`` with the same position tables.

        The sequence axis is taken to be axis -3 of ``q``. ``k`` is assumed to
        share that length (TODO confirm if this is ever used for cross-attention).
        """
        if self.rescaling_factor is not None:
            # With rescaling_factor > 1 this enlarges the frequency base.
            base = self.upper_freq * (
                self.rescaling_factor ** (self.dim / (self.dim - 2))
            )
        else:
            base = self.upper_freq
        exponents = torch.arange(0, self.dim, 2, device=q.device).float() / self.dim
        inv_freq = 1.0 / (base**exponents)
        cos, sin = self._compute_cos_sin_tables(q, inv_freq, seq_dimension=-3)
        self._cos_cached, self._sin_cached = cos, sin
        return (
            self._apply_rotary_pos_emb(q, cos, sin),
            self._apply_rotary_pos_emb(k, cos, sin),
        )
class ResidualConvBlock(nn.Module):
    """
    A ``ConvBlock`` wrapped in an additive skip connection.

    The input is reshaped to the block output's shape before the addition, so in
    practice ``dim_in`` is expected to equal ``dim_out``.
    """

    def __init__(
        self, dim_in: int, dim_out: int, layer_norm_shape: int, kernel_size: int = 1
    ):
        super().__init__()
        self.conv_block = ConvBlock(
            dim_in=dim_in,
            dim_out=dim_out,
            layer_norm_shape=layer_norm_shape,
            kernel_size=kernel_size,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x + conv_block(x)``."""
        out = self.conv_block(x)
        skip = x.reshape(out.shape)
        return skip + out
class ConvBlock(nn.Module):
    """
    Pre-normalised 1-D convolution block.

    Applies LayerNorm across the channel axis, then a ``Conv1d`` with ``"same"``
    padding, then a tanh-approximated GELU. Inputs are channel-first
    ``(batch, dim_in, length)`` (the layout ``nn.Conv1d`` expects), so
    ``layer_norm_shape`` has to equal ``dim_in``.
    """

    def __init__(
        self, dim_in: int, dim_out: int, layer_norm_shape: int, kernel_size: int = 1
    ):
        super().__init__()
        self.conv = nn.Conv1d(
            in_channels=dim_in,
            out_channels=dim_out,
            kernel_size=kernel_size,
            padding="same",
        )
        self.layer_norm = nn.LayerNorm(layer_norm_shape, eps=1e-5)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # LayerNorm normalises the last axis, so move channels there and back.
        channels_last = x.permute(0, 2, 1)
        normed = self.layer_norm(channels_last).permute(0, 2, 1)
        return F.gelu(self.conv(normed), approximate="tanh")
class ConvTowerBlock(nn.Module):
    """
    One downsampling stage of the conv tower.

    Runs ``ConvBlock`` -> ``ResidualConvBlock`` (kernel size 1) -> average pooling
    with window 2 and stride 2. ``forward`` returns the pooled activations together
    with the stage's unmodified input, which callers can use as a skip connection.
    """

    def __init__(
        self,
        dim_in: int,
        dim_out: int,
        conv_layer_norm_shape: int,
        resconv_layer_norm_shape: int,
        kernel_size: int,
    ) -> None:
        super().__init__()
        self.conv_layer = ConvBlock(
            dim_in=dim_in,
            dim_out=dim_out,
            layer_norm_shape=conv_layer_norm_shape,
            kernel_size=kernel_size,
        )
        self.res_conv = ResidualConvBlock(
            dim_in=dim_out,
            dim_out=dim_out,
            layer_norm_shape=resconv_layer_norm_shape,
            kernel_size=1,
        )
        self.avg_pool = nn.AvgPool1d(kernel_size=2, stride=2)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        skip = x
        hidden = self.res_conv(self.conv_layer(x))
        return self.avg_pool(hidden), skip
class ResidualDeConvBlock(nn.Module):
    """
    A ``DeConvBlock`` wrapped in an additive skip connection.

    The input is reshaped to the block output's shape before the addition, so the
    block is only meaningful when the output keeps the input's size (e.g. kernel
    size 1 and stride 1, as used in ``DeConvTowerBlock``).
    """

    def __init__(
        self,
        dim_in: int,
        dim_out: int,
        layer_norm_shape: int,
        kernel_size: int = 1,
        stride: int = 1,
    ):
        super().__init__()
        self.deconv_block = DeConvBlock(
            dim_in=dim_in,
            dim_out=dim_out,
            layer_norm_shape=layer_norm_shape,
            kernel_size=kernel_size,
            stride=stride,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x + deconv_block(x)``."""
        out = self.deconv_block(x)
        skip = x.reshape(out.shape)
        return skip + out
class DeConvBlock(nn.Module):
    """
    Pre-normalised 1-D transposed-convolution block.

    Applies LayerNorm across the channel axis, then a ``ConvTranspose1d`` with no
    padding, then a tanh-approximated GELU. Inputs are channel-first
    ``(batch, dim_in, length)``, so ``layer_norm_shape`` has to equal ``dim_in``.

    With ``padding=0`` PyTorch produces ``(length - 1) * stride + kernel_size``
    output positions. For ``kernel_size == 5`` the result is cropped by one position
    at the start and two at the end, which for ``stride == 2`` gives exactly
    ``2 * length`` positions.
    """

    def __init__(
        self,
        dim_in: int,
        dim_out: int,
        layer_norm_shape: int,
        kernel_size: int = 1,
        stride: int = 1,
    ):
        super().__init__()
        self.deconv = nn.ConvTranspose1d(
            in_channels=dim_in,
            out_channels=dim_out,
            kernel_size=kernel_size,
            stride=stride,
            padding=0,
        )
        self.layer_norm = nn.LayerNorm(layer_norm_shape)
        self.kernel_size = kernel_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        normed = self.layer_norm(x.permute(0, 2, 1)).permute(0, 2, 1)
        out = self.deconv(normed)
        if self.kernel_size == 5:
            # Mimic the Haiku transposed conv, which removes this padding itself.
            out = out[:, :, 1:-2]
        return F.gelu(out, approximate="tanh")
class DeConvTowerBlock(nn.Module):
    """
    One upsampling stage of the deconv tower.

    Runs ``DeConvBlock`` (with the given kernel size and stride) ->
    ``ResidualDeConvBlock`` (kernel size 1), then adds the skip tensor ``res``.
    """

    def __init__(
        self,
        dim_in: int,
        dim_out: int,
        kernel_size: int,
        conv_layer_norm_shape: int,
        resconv_layer_norm_shape: int,
        stride: int = 2,
    ):
        super().__init__()
        self.deconv_block = DeConvBlock(
            dim_in=dim_in,
            dim_out=dim_out,
            layer_norm_shape=conv_layer_norm_shape,
            kernel_size=kernel_size,
            stride=stride,
        )
        self.res_deconv_block = ResidualDeConvBlock(
            dim_in=dim_out,
            dim_out=dim_out,
            layer_norm_shape=resconv_layer_norm_shape,
            kernel_size=1,
        )

    def forward(self, x: torch.Tensor, res: torch.Tensor) -> torch.Tensor:
        upsampled = self.res_deconv_block(self.deconv_block(x))
        return upsampled + res
class MultiHeadAttention(nn.Module):
    """
    Multi-head scaled dot-product attention with optional rotary embeddings.

    Args:
        num_heads: Number of attention heads.
        key_size: Per-head dimension of queries and keys.
        rotary_embedding_config: When truthy, a ``RotaryEmbedding`` is applied to
            the query and key heads before the attention scores are computed.
        add_bias_kv: Stored on the module but not used by this implementation.
        value_size: Per-head dimension of values; falls back to ``key_size``.
        model_size: Input/output feature dimension; falls back to ``key_size``.
        name: Optional label stored on the module.
    """

    def __init__(
        self,
        num_heads: int,
        key_size: int,
        rotary_embedding_config: Optional["RotaryEmbeddingConfig"] = None,
        add_bias_kv: bool = False,
        value_size: Optional[int] = None,
        model_size: Optional[int] = None,
        name: Optional[str] = None,
    ):
        super().__init__()
        # Falsy sizes (None or 0) fall back to key_size.
        if not model_size:
            model_size = key_size
        if not value_size:
            value_size = key_size
        self.model_size = model_size
        self.key_size = key_size
        self.value_size = value_size
        self.add_bias_kv = add_bias_kv
        self.name = name
        self.num_heads = num_heads
        self._rotary_embedding_config = rotary_embedding_config
        # Creation order is kept stable so parameter initialisation and
        # state_dict layout are unchanged.
        self.w_k = nn.Linear(self.model_size, self.num_heads * self.key_size)
        self.w_q = nn.Linear(self.model_size, self.num_heads * self.key_size)
        self.w_v = nn.Linear(self.model_size, self.num_heads * self.value_size)
        self.output = nn.Linear(self.num_heads * self.value_size, self.model_size)
        if self._rotary_embedding_config:
            self._rotary_embedding = RotaryEmbedding(
                self.key_size, self._rotary_embedding_config
            )

    def apply_rotary_embeddings(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply the module's rotary embedding to the query and key heads."""
        query, key = self._rotary_embedding(query, key)
        return query, key

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        attention_weight_bias: Optional[torch.Tensor] = None,
    ) -> dict[str, torch.Tensor]:
        """
        Compute attention of ``query`` over ``key``/``value``.

        Args:
            query, key, value: Tensors whose last axis is ``model_size``. The
                heads are split from that axis.
            attention_mask: Optional boolean tensor, broadcastable to the score
                tensor ``(..., num_heads, T_query, T_key)``. Positions where it is
                ``False`` receive a score of -1e30 before the softmax.
            attention_weight_bias: Optional tensor, broadcastable to the scores,
                that is added before the softmax.

        Returns:
            Dictionary with ``"attention_weights"`` (post-softmax, shape
            ``(..., num_heads, T_query, T_key)``) and ``"embeddings"``
            (shape ``(..., T_query, model_size)``).
        """
        key_heads = self.w_k(key).reshape(
            (*key.shape[:-1], self.num_heads, self.key_size)
        )
        query_heads = self.w_q(query).reshape(
            (*query.shape[:-1], self.num_heads, self.key_size)
        )
        value_heads = self.w_v(value).reshape(
            (*value.shape[:-1], self.num_heads, self.value_size)
        )
        if self._rotary_embedding_config:
            query_heads, key_heads = self.apply_rotary_embeddings(
                query_heads, key_heads
            )
        attention_weights = torch.einsum(
            "...thd, ...Thd -> ...htT", query_heads, key_heads
        )
        sqrt_key_size = np.sqrt(self.key_size)
        attention_weights = attention_weights / sqrt_key_size
        # Compare against None explicitly: truth-testing a tensor with more than
        # one element raises a RuntimeError.
        if attention_mask is not None:
            attention_weights = torch.where(attention_mask, attention_weights, -1e30)
        if attention_weight_bias is not None:
            attention_weights = attention_weights + attention_weight_bias
        attention_weights = F.softmax(attention_weights, dim=-1)
        value_out = torch.einsum(
            "...htT, ...Thd->...thd", attention_weights, value_heads
        )
        # Merge the head axis back into the feature axis.
        value_out = value_out.reshape((*value_out.shape[:-2], -1))
        embeddings = self.output(value_out)
        return {"attention_weights": attention_weights, "embeddings": embeddings}
class SelfAttentionBlock(nn.Module):
    """
    Transformer layer: multi-head self-attention followed by a feed-forward MLP.

    Each sub-layer is wrapped in a residual connection. LayerNorm is applied
    before the sub-layer when ``pre_layer_norm`` is True, and after the residual
    sum otherwise.

    Args:
        num_heads: Number of attention heads.
        embed_dim: Model width (input and output feature dimension).
        ffn_embed_dim: Hidden width of the MLP.
        key_size: Per-head key size; defaults to ``embed_dim // num_heads``.
        add_bias_kv: Forwarded to ``MultiHeadAttention``.
        add_bias_fnn: Whether the MLP linear layers have biases.
        ffn_activation_name: ``"swish"``, ``"gelu-no-approx"``, or the name of a
            ``torch.nn`` activation module class (e.g. ``"ReLU"``). A class name
            is instantiated with no arguments.
        use_glu_in_ffn: Use a gated linear unit in the MLP.
        layer_norm_eps: Epsilon of both LayerNorms.
        pre_layer_norm: Pre-LN if True, post-LN otherwise.
        name: Accepted for API compatibility; not used.
        rotary_embedding_config: Forwarded to ``MultiHeadAttention``.

    Raises:
        ValueError: If ``key_size`` is None and ``embed_dim`` is not divisible by
            ``num_heads``.
    """

    def __init__(
        self,
        num_heads: int,
        embed_dim: int,
        ffn_embed_dim: int,
        key_size: Optional[int] = None,
        add_bias_kv: bool = False,
        add_bias_fnn: bool = True,
        ffn_activation_name: str = "gelu-no-approx",
        use_glu_in_ffn: bool = False,
        layer_norm_eps: float = 1e-5,  # this is the default haiku value
        pre_layer_norm: bool = True,
        name: Optional[str] = None,
        rotary_embedding_config: Optional["RotaryEmbeddingConfig"] = None,
    ):
        super().__init__()
        if key_size is None:
            if embed_dim % num_heads != 0:
                raise ValueError(
                    f"The embedding dimension should be divisible by the number of "
                    f"heads, however provided embedding dimension is {embed_dim} and "
                    f"the number of heads is {num_heads}."
                )
            key_size = embed_dim // num_heads

        self._pre_layer_norm = pre_layer_norm
        self._use_glu_in_fnn = use_glu_in_ffn

        # Define layers (creation order kept stable for initialisation/state_dict).
        if use_glu_in_ffn:
            # The user should multiply ffn_embed_dim by 2/3 when using GLU to keep
            # the total number of parameters equal;
            # see https://arxiv.org/pdf/2002.05202.pdf for more details.
            # fc1 outputs twice the width because the GLU splits it in two.
            self.fc1 = nn.Linear(embed_dim, int(2 * ffn_embed_dim), bias=add_bias_fnn)
        else:
            self.fc1 = nn.Linear(embed_dim, ffn_embed_dim, bias=add_bias_fnn)
        self.fc2 = nn.Linear(ffn_embed_dim, embed_dim, bias=add_bias_fnn)

        # layer_norm_eps was previously ignored; its default equals nn.LayerNorm's.
        self.layer_norm_self_attention = nn.LayerNorm(embed_dim, eps=layer_norm_eps)
        self.layer_norm_mlp = nn.LayerNorm(embed_dim, eps=layer_norm_eps)

        if ffn_activation_name == "swish":
            self._ffn_activation_fn = nn.SiLU()
        elif ffn_activation_name == "gelu-no-approx":
            # Parameter-free module (picklable, unlike a lambda).
            self._ffn_activation_fn = nn.GELU(approximate="none")
        else:
            # getattr returns the module *class*; instantiate it so that calling
            # it on a tensor applies the activation.
            self._ffn_activation_fn = getattr(torch.nn, ffn_activation_name)()

        self.mha = MultiHeadAttention(
            num_heads=num_heads,
            key_size=key_size,
            add_bias_kv=add_bias_kv,
            model_size=embed_dim,
            name="self_attention",
            rotary_embedding_config=rotary_embedding_config,
        )

    def mlp(self, embed: torch.Tensor) -> torch.Tensor:
        """
        Apply the feed-forward sub-layer.

        With pre-LN this returns only the MLP branch (the caller adds the
        residual). With post-LN it returns ``LayerNorm(mlp(embed) + embed)``.
        """
        if self._pre_layer_norm:
            x = self.layer_norm_mlp(embed)
        else:
            x = embed
        if self._use_glu_in_fnn:
            x = self.fc1(x)
            x1, x2 = torch.split(x, split_size_or_sections=x.shape[-1] // 2, dim=-1)
            x = self._ffn_activation_fn(x1) * x2
        else:
            x = self._ffn_activation_fn(self.fc1(x))
        x = self.fc2(x)
        if not self._pre_layer_norm:
            x = self.layer_norm_mlp(x + embed)
        return x

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        attention_weight_bias: Optional[torch.Tensor] = None,
    ) -> dict[str, torch.Tensor]:
        """
        Run attention and MLP sub-layers.

        Returns the attention output dict with ``"embeddings"`` replaced by the
        layer's final output. ``"attention_weights"`` is passed through
        unchanged.
        """
        res = x
        if self._pre_layer_norm:
            x = self.layer_norm_self_attention(x)
        output = self.mha(
            x,
            x,
            x,
            attention_mask=attention_mask,
            attention_weight_bias=attention_weight_bias,
        )
        if self._pre_layer_norm:
            x = res + output["embeddings"]
            x = x + self.mlp(x)
        else:
            # Post-LN: the norm is applied to the residual sum, and mlp() adds
            # its own residual.
            x = self.layer_norm_self_attention(output["embeddings"] + res)
            x = self.mlp(x)
        output["embeddings"] = x
        return output
class LMHead(nn.Module):
    """
    Prediction head.

    Applies a GELU to the input, then ``max(1, num_hidden_layers)`` Linear+GELU
    stages (the first maps ``dim_in -> embed_dim``, the rest
    ``embed_dim -> embed_dim``), then a final Linear projection to ``dim_out``.
    All GELUs use the tanh approximation.
    """

    def __init__(
        self, dim_in: int, embed_dim: int, dim_out: int, num_hidden_layers: int
    ) -> None:
        super().__init__()
        self.num_hidden_layers = num_hidden_layers
        # The previous code fed a generator of one-element *lists* to
        # nn.ModuleList. That raised TypeError as soon as num_hidden_layers >= 2.
        hidden = [nn.Linear(dim_in, embed_dim)]
        hidden.extend(
            nn.Linear(embed_dim, embed_dim) for _ in range(num_hidden_layers - 1)
        )
        self.linear_layers = nn.ModuleList(hidden)
        self.linear_out = nn.Linear(embed_dim, dim_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.gelu(x, approximate="tanh")
        for layer in self.linear_layers:
            x = F.gelu(layer(x), approximate="tanh")
        return self.linear_out(x)
class MOJOConfig(PretrainedConfig):  # noqa: N801
    """
    Configuration for the MOJO model.

    Every attribute is read from ``kwargs`` with the default shown in
    ``__init__``. ``__post_init__`` then validates the values and (re)computes
    the derived fields ``key_size`` (when None), ``fixed_sequence_length`` and
    ``filter_list``.
    """

    model_type = "MOJO"

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        # Vocabulary size per omic; the model builds one embedding table and one
        # LM head per key.
        self.alphabet_size = kwargs.get(
            "alphabet_size", {"rnaseq": 66, "methylation": 66}
        )
        self.token_embed_dim = kwargs.get("token_embed_dim", 256)
        self.init_gene_embed_dim = kwargs.get("init_gene_embed_dim", 200)
        self.use_gene_embedding = kwargs.get("use_gene_embedding", True)
        self.project_gene_embedding = kwargs.get("project_gene_embedding", True)
        self.sequence_length = kwargs.get("sequence_length", 17_116)  # n_genes
        # Overwritten by __post_init__.
        self.fixed_sequence_length = kwargs.get("fixed_sequence_length", None)
        self.num_downsamples = kwargs.get("num_downsamples", 8)
        self.conv_init_embed_dim = kwargs.get("conv_init_embed_dim", 512)
        self.stem_kernel_shape = kwargs.get("stem_kernel_shape", 15)
        self.embed_dim = kwargs.get("embed_dim", 512)
        # Overwritten by __post_init__.
        self.filter_list = kwargs.get("filter_list", [])
        self.num_attention_heads = kwargs.get("num_attention_heads", 16)
        self.key_size = kwargs.get("key_size", None)
        self.ffn_embed_dim = kwargs.get("ffn_embed_dim", 1_024)
        self.num_layers = kwargs.get("num_layers", 8)
        self.num_hidden_layers_head = kwargs.get("num_hidden_layers_head", 1)
        # Which intermediate outputs the model should return.
        self.embeddings_layers_to_save: tuple[int, ...] = kwargs.get(
            "embeddings_layers_to_save", ()
        )
        self.attention_maps_to_save: list[tuple[int, int]] = kwargs.get(
            "attention_maps_to_save", []
        )
        self.__post_init__()

    def __post_init__(self) -> None:
        """
        Validate the configuration and fill in derived fields.

        Raises:
            ValueError: If ``key_size`` is None and ``embed_dim`` is not divisible
                by ``num_attention_heads``.
        """
        # Derive the per-head key size when it is not given explicitly.
        if self.key_size is None:
            embed_dim = self.embed_dim
            num_attention_heads = self.num_attention_heads
            if embed_dim % num_attention_heads != 0:
                raise ValueError(
                    f"When no key size is provided, the embedding dimension should be "
                    f"divisible by the number of heads, however provided embedding "
                    f"dimension is {embed_dim} and the number of heads is "
                    f"{num_attention_heads}."
                )
            self.key_size = embed_dim // num_attention_heads

        # Gene embeddings whose width differs from the token embeddings must be
        # projected before being added to them.
        if (
            self.use_gene_embedding
            and self.init_gene_embed_dim != self.token_embed_dim
            and not self.project_gene_embedding
        ):
            logging.getLogger(__name__).warning(
                "Init gene embedding dimension (%s) different than token embedding "
                "dimension (%s). Setting `project_gene_embedding` to True",
                self.init_gene_embed_dim,
                self.token_embed_dim,
            )
            self.project_gene_embedding = True

        # Round sequence_length up to a multiple of 2**num_downsamples; the conv
        # tower halves the length once per downsampling stage.
        downsample_factor = 2**self.num_downsamples
        self.fixed_sequence_length = (
            math.ceil(self.sequence_length / downsample_factor) * downsample_factor
        )

        # Channel widths of the conv tower: num_downsamples + 1 values linearly
        # spaced from conv_init_embed_dim to embed_dim, truncated to int.
        self.filter_list = (
            np.linspace(
                self.conv_init_embed_dim,
                self.embed_dim,
                self.num_downsamples + 1,
            )
            .astype(int)
            .tolist()
        )
class MOJO(PreTrainedModel):  # noqa: N801
    """
    Multi-omic encoder with a U-Net-like layout.

    The per-omic token embeddings are summed, optionally plus learned per-gene
    position embeddings. The result then passes through a convolutional
    downsampling tower, a stack of transformer layers, and a deconvolutional
    upsampling tower that consumes the conv tower's skip tensors. Finally, one LM
    head per omic decodes the output.

    ``forward`` returns a dict of intermediate activations plus per-omic logits.
    """

    config_class = MOJOConfig

    def __init__(self, config: MOJOConfig):
        super().__init__(config=config)
        # Embeddings
        # One token-embedding table per omic key in ``config.alphabet_size``.
        self.embedding_layers = nn.ModuleDict(
            {
                omic: nn.Embedding(config.alphabet_size[omic], config.token_embed_dim)
                for omic in config.alphabet_size
            }
        )
        # One learned vector per position of the padded sequence.
        self.gene_embedding_layer = nn.Embedding(
            config.fixed_sequence_length,
            config.init_gene_embed_dim,
        )
        # Always created; only applied in forward when
        # ``config.project_gene_embedding`` is set.
        self.fc_gene_embedding = nn.Linear(
            config.init_gene_embed_dim, config.token_embed_dim
        )
        # Convolutions
        self.stem_conv = nn.Sequential(
            nn.Conv1d(
                in_channels=config.token_embed_dim,
                out_channels=config.conv_init_embed_dim,
                kernel_size=config.stem_kernel_shape,
                padding="same",
            ),
            nn.GELU(approximate="tanh"),
        )
        # Downsampling tower: block i maps filter_list[i] -> filter_list[i + 1]
        # channels and halves the length (AvgPool1d, stride 2, in ConvTowerBlock).
        self.conv_tower = nn.ModuleList(
            [
                ConvTowerBlock(
                    dim_in=config.filter_list[i],
                    dim_out=config.filter_list[i + 1],
                    kernel_size=5,
                    conv_layer_norm_shape=config.filter_list[i],
                    resconv_layer_norm_shape=config.filter_list[i + 1],
                )
                for i in range(len(config.filter_list) - 1)
            ]
        )
        # Transformer
        # ``attention_maps_to_save`` holds (layer, map_number) pairs. Layer indices
        # are 1-based (compared against ``layer_idx + 1`` in forward). Group the
        # map numbers by layer.
        attention_maps_to_save = config.attention_maps_to_save
        self._attention_layers_to_save = list({t[0] for t in attention_maps_to_save})
        self._attention_maps_per_layer_to_save = {
            layer: [t[1] for t in attention_maps_to_save if t[0] == layer]
            for layer in self._attention_layers_to_save
        }
        max_layer = max(self._attention_layers_to_save + [0])
        if max_layer > config.num_layers:
            raise ValueError(
                f"You are requiring attention maps for layer {max_layer}, "
                f"while the model has {config.num_layers} layers only."
            )
        # Every attention layer uses rotary embeddings without rescaling.
        self._rotary_embedding_config = RotaryEmbeddingConfig(rescaling_factor=None)
        self.transformer_layers = nn.ModuleList(
            [
                SelfAttentionBlock(
                    num_heads=config.num_attention_heads,
                    embed_dim=config.embed_dim,
                    ffn_embed_dim=config.ffn_embed_dim,
                    key_size=config.key_size,
                    add_bias_kv=False,
                    add_bias_fnn=False,
                    ffn_activation_name="swish",
                    use_glu_in_ffn=True,
                    layer_norm_eps=1e-5,  # this is the default haiku value
                    pre_layer_norm=True,
                    name=f"attention_layer_{layer_idx}",
                    rotary_embedding_config=self._rotary_embedding_config,
                )
                for layer_idx in range(config.num_layers)
            ]
        )
        # Deconvolutions
        # Mirror of the conv tower: walks filter_list from the end back to the start.
        self.deconv_tower = nn.ModuleList(
            [
                DeConvTowerBlock(
                    dim_in=config.filter_list[-1 - i],
                    dim_out=config.filter_list[-1 - i - 1],
                    kernel_size=5,
                    stride=2,
                    conv_layer_norm_shape=config.filter_list[-1 - i],
                    resconv_layer_norm_shape=config.filter_list[-1 - i - 1],
                )
                for i in range(len(config.filter_list) - 1)
            ]
        )
        # Language Modeling heads
        # dim_in is conv_init_embed_dim: the deconv tower ends at filter_list[0],
        # which MOJOConfig.__post_init__ sets (via np.linspace) to
        # conv_init_embed_dim.
        self.omic_lm_heads = nn.ModuleDict(
            {
                omic: LMHead(
                    dim_in=config.conv_init_embed_dim,
                    embed_dim=config.embed_dim,
                    dim_out=config.alphabet_size[omic],
                    num_hidden_layers=config.num_hidden_layers_head,
                )
                for omic in self.config.alphabet_size
            }
        )

    def get_embeddings(
        self,
        input_ids: dict[str, torch.Tensor],
    ) -> dict[str, torch.Tensor]:
        """
        Look up token embeddings for each omic in ``input_ids``.

        Every key must be present in ``config.alphabet_size`` (otherwise the
        ModuleDict lookup raises KeyError).
        """
        omic_embeddings = {}
        for omic, omic_tokens in input_ids.items():
            omic_embeddings[omic] = self.embedding_layers[omic](omic_tokens)
        return omic_embeddings

    def forward(self, input_ids: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """
        Run the full encoder / decoder.

        Args:
            input_ids: Mapping from omic name to token ids. Each tensor is presumably
                (batch, seq_len), and seq_len must equal
                ``config.fixed_sequence_length`` when gene embeddings are added
                (TODO confirm for all callers).

        Returns:
            Dict of intermediate activations (omic and summed embeddings, stem,
            conv tower output and residuals, requested per-layer embeddings and
            attention maps, transformer and deconv tower outputs) plus
            ``"logits"``, a dict mapping each omic in ``config.alphabet_size`` to
            its LM-head output.
        """
        outs = {}
        embeddings = self.get_embeddings(input_ids)
        outs["omic_embeddings"] = embeddings
        # Sum the embeddings of all omics element-wise.
        x = torch.stack(list(embeddings.values()), dim=0).sum(dim=0)  # [B, T, C]
        outs["embeddings"] = x
        if self.config.use_gene_embedding:
            gene_indices = torch.arange(
                self.config.fixed_sequence_length, device=x.device
            )
            gene_embedding = self.gene_embedding_layer(gene_indices)
            if self.config.project_gene_embedding:
                gene_embedding = self.fc_gene_embedding(gene_embedding)
            # Broadcast the per-position embeddings over the batch axis.
            x = x + gene_embedding
            outs["embeddings_with_gene_embedding"] = x
        # Channel-first layout for the Conv1d-based stem and tower.
        x = x.permute(0, 2, 1)
        x = self.stem_conv(x)
        outs["stem"] = x
        residuals = []
        for conv_block in self.conv_tower:
            x, res = conv_block(x)
            residuals.append(res)
        # Back to (batch, length, channels) for the transformer.
        x = x.permute(0, 2, 1)
        outs["conv_tower"] = x
        outs["conv_tower_residuals"] = residuals  # type: ignore
        # The deconv tower consumes the skip tensors deepest-first.
        residuals = residuals[::-1]
        for layer_idx, transformer in enumerate(self.transformer_layers):
            output = transformer(x)
            x = output["embeddings"]
            if (layer_idx + 1) in self.config.embeddings_layers_to_save:
                outs[f"embeddings_{(layer_idx + 1)}"] = output["embeddings"]
            if (layer_idx + 1) in self._attention_layers_to_save:
                for map_number in self._attention_maps_per_layer_to_save[layer_idx + 1]:
                    dkey = f"attention_map_layer_{layer_idx + 1}_number_{map_number}"
                    # NOTE(review): selects head index ``map_number + 1``. Confirm
                    # whether map numbers are meant to be 0- or 1-based; with
                    # either convention the largest map number indexes past the
                    # last head.
                    outs[dkey] = output["attention_weights"][:, map_number + 1]
        outs["after_transformer_embedding"] = x
        # Channel-first layout for the ConvTranspose1d-based tower.
        x = x.permute(0, 2, 1)
        for deconv_block, res in zip(self.deconv_tower, residuals):
            x = deconv_block(x, res)
        x = x.permute(0, 2, 1)
        outs["deconv_tower"] = x
        outs["logits"] = {
            omic: self.omic_lm_heads[omic](x) for omic in self.config.alphabet_size
        }
        return outs