463 changes: 149 additions & 314 deletions src/transformers/models/esm/modeling_esm.py

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions src/transformers/models/esm/modeling_esmfold.py
@@ -1991,6 +1991,10 @@ def distogram(coords, min_bin, max_bin, num_bins):
class EsmForProteinFolding(EsmPreTrainedModel):
_no_split_modules = ["EsmFoldStructureModule", "EsmFoldTriangularSelfAttentionBlock"]
_supports_flash_attn = False
_supports_sdpa = False
_supports_attention_backend = False

_can_record_outputs = None

def __init__(self, config):
super().__init__(config)
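
The four new class attributes opt ESMFold out of the newer attention backends and of hook-based output recording, leaving only the eager attention path. A minimal usage sketch under that reading; the checkpoint name is illustrative and `attn_implementation` is the standard `from_pretrained` keyword:

from transformers import EsmForProteinFolding

# EsmForProteinFolding now declares (see the hunk above):
#   _supports_flash_attn = False         -> flash-attention kernels are not offered
#   _supports_sdpa = False               -> the torch SDPA path is not offered
#   _supports_attention_backend = False  -> no pluggable attention backend at all
#   _can_record_outputs = None           -> hook-based output recording stays off
# so "eager" is the only attention implementation worth requesting here.
model = EsmForProteinFolding.from_pretrained(
    "facebook/esmfold_v1", attn_implementation="eager"
)
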
2 changes: 0 additions & 2 deletions src/transformers/models/evolla/configuration_evolla.py
@@ -76,7 +76,6 @@ def __init__(
initializer_range=0.02,
layer_norm_eps=1e-05,
position_embedding_type="rotary",
use_cache=True,
emb_layer_norm_before=False,
token_dropout=True,
**kwargs,
@@ -94,7 +93,6 @@ def __init__(
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.position_embedding_type = position_embedding_type
self.use_cache = use_cache
self.emb_layer_norm_before = emb_layer_norm_before
self.token_dropout = token_dropout

397 changes: 118 additions & 279 deletions src/transformers/models/evolla/modeling_evolla.py

Large diffs are not rendered by default.

25 changes: 19 additions & 6 deletions src/transformers/models/evolla/modular_evolla.py
@@ -37,7 +37,7 @@
logging,
)
from ...utils.deprecation import deprecate_kwarg
from ...utils.generic import check_model_inputs
from ...utils.generic import OutputRecorder, check_model_inputs
from ..esm.modeling_esm import (
EsmAttention,
EsmEmbeddings,
@@ -122,13 +122,13 @@ def forward(self, q: torch.Tensor, k: torch.Tensor) -> tuple[torch.Tensor, torch
self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k, seq_dimension=-2)

return (
apply_rotary_pos_emb_esm(q, self._cos_cached, self._sin_cached),
apply_rotary_pos_emb_esm(k, self._cos_cached, self._sin_cached),
apply_rotary_pos_emb_esm(q, self._cos_cached, self._sin_cached).to(dtype=q.dtype),
apply_rotary_pos_emb_esm(k, self._cos_cached, self._sin_cached).to(dtype=k.dtype),
)


class EvollaSaProtSelfAttention(EsmSelfAttention, nn.Module):
def __init__(self, config, position_embedding_type=None, layer_idx=None):
def __init__(self, config, position_embedding_type=None, layer_idx=None, is_cross_attention=False):
nn.Module.__init__(self)
self.config = config

@@ -146,7 +146,7 @@ def __init__(self, config, position_embedding_type=None, layer_idx=None):
self.key = nn.Linear(config.hidden_size, self.all_head_size)
self.value = nn.Linear(config.hidden_size, self.all_head_size)

self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
self.dropout = config.attention_probs_dropout_prob
self.position_embedding_type = position_embedding_type or getattr(
config, "position_embedding_type", "absolute"
)
@@ -159,6 +159,8 @@

self.is_decoder = config.is_decoder
self.layer_idx = layer_idx
self.scaling = 1.0
self.is_causal = self.is_decoder and not is_cross_attention


class EvollaSaProtSelfOutput(EsmSelfOutput):
@@ -193,6 +195,17 @@ class EvollaSaProtPooler(EsmPooler):
class EvollaSaProtPreTrainedModel(PreTrainedModel):
config: SaProtConfig
_no_split_modules = ["EvollaSaProtLayer"]
_supports_flash_attn = True
_supports_sdpa = True
_supports_attention_backend = True

_can_record_outputs = {
"hidden_states": EvollaSaProtLayer,
"attentions": [OutputRecorder(EvollaSaProtSelfAttention, index=1, layer_name="attention")],
"cross_attentions": [
OutputRecorder(EvollaSaProtSelfAttention, index=1, layer_name="crossattention"),
],
}

def _init_weights(self, module):
"""Initialize the weights"""
@@ -230,7 +243,7 @@ class PreTrainedModel
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)

@can_return_tuple
@check_model_inputs
def forward(
self,
input_ids: Optional[torch.Tensor],
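
Together with swapping `@can_return_tuple` for `@check_model_inputs`, the `_can_record_outputs` table is what removes the need to thread `output_hidden_states`/`output_attentions` through every layer: roughly, the decorator hooks the listed modules and collects what they return into the output object. A schematic sketch of the pattern as added above; the import path assumes the generated `modeling_evolla` module, and the class body is elided:

from transformers.models.evolla.modeling_evolla import (
    EvollaSaProtLayer,
    EvollaSaProtPreTrainedModel,
    EvollaSaProtSelfAttention,
)
from transformers.utils.generic import OutputRecorder, check_model_inputs


class MySaProtEncoder(EvollaSaProtPreTrainedModel):
    # Map each output field to the module whose forward results should be captured.
    # index=1 selects the attention-probabilities entry of the module's return value;
    # layer_name distinguishes the self-attention from the cross-attention submodule.
    _can_record_outputs = {
        "hidden_states": EvollaSaProtLayer,
        "attentions": [OutputRecorder(EvollaSaProtSelfAttention, index=1, layer_name="attention")],
    }

    @check_model_inputs  # replaces @can_return_tuple and also resolves use_cache/return_dict
    def forward(self, input_ids, attention_mask=None, **kwargs):
        ...
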
22 changes: 11 additions & 11 deletions src/transformers/utils/generic.py
@@ -974,22 +974,22 @@ def check_model_inputs(func):

@wraps(func)
def wrapper(self, *args, **kwargs):
use_cache = kwargs.get("use_cache")
if use_cache is None:
use_cache = getattr(self.config, "use_cache", False)
use_cache = (
kwargs["use_cache"] if kwargs.get("use_cache") is not None else getattr(self.config, "use_cache", None)
)
if use_cache is not None:
if getattr(self, "gradient_checkpointing", False) and self.training and use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
)
use_cache = False

kwargs["use_cache"] = use_cache

return_dict = kwargs.pop("return_dict", None)
if return_dict is None:
return_dict = getattr(self.config, "return_dict", True)

if getattr(self, "gradient_checkpointing", False) and self.training and use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
)
use_cache = False

kwargs["use_cache"] = use_cache

all_args = kwargs.copy()
if "kwargs" in all_args:
for k, v in all_args["kwargs"].items():
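
The rewrite changes the config fallback from `False` to `None` and applies the gradient-checkpointing override (and the kwarg write-back) only when a cache setting actually exists, which is what lets encoder-only configs such as the ESM/SaProt ones above drop `use_cache` entirely. A standalone sketch of that resolution logic as I read the diff; the real code sits inside the `check_model_inputs` wrapper and uses `logger.warning_once`:

import logging

logger = logging.getLogger(__name__)


def resolve_use_cache(model, kwargs: dict) -> dict:
    # Prefer an explicit use_cache kwarg; otherwise fall back to the config,
    # where None now means "this model has no cache at all".
    use_cache = (
        kwargs["use_cache"]
        if kwargs.get("use_cache") is not None
        else getattr(model.config, "use_cache", None)
    )
    if use_cache is not None:
        if getattr(model, "gradient_checkpointing", False) and model.training and use_cache:
            logger.warning(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False
        # Only write the kwarg back for models that actually use a cache.
        kwargs["use_cache"] = use_cache
    return kwargs
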
1 change: 1 addition & 0 deletions tests/models/esm/test_modeling_esm.py
@@ -238,6 +238,7 @@ def test_model_various_embeddings(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
for type in ["absolute", "relative_key", "relative_key_query"]:
config_and_inputs[0].position_embedding_type = type
config_and_inputs[0]._attn_implementation = "eager"
self.model_tester.create_and_check_model(*config_and_inputs)

def test_for_masked_lm(self):
1 change: 0 additions & 1 deletion tests/test_modeling_common.py
@@ -4391,7 +4391,6 @@ def test_flash_attention_2_continue_generate_with_position_ids(self):
next_token_logits_from_generate = generation_out.logits[-1]

# acceptable numerical instability
# print(next_token_logits_from_generate, next_token_logits)
tol = torch.finfo(torch.bfloat16).eps
torch.testing.assert_close(next_token_logits_from_generate, next_token_logits, rtol=tol, atol=tol)
