14 changes: 8 additions & 6 deletions src/transformers/generation/configuration_utils.py
@@ -43,7 +43,7 @@

logger = logging.get_logger(__name__)
METADATA_FIELDS = ("_from_model_config", "_commit_hash", "_original_object_hash", "transformers_version")
NEED_SETUP_CACHE_CLASSES_MAPPING = {}
STATIC_CACHE_CLASSES_MAPPING = {}
QUANT_BACKEND_CLASSES_MAPPING = {}
ALL_CACHE_IMPLEMENTATIONS = []

@@ -60,7 +60,7 @@
)
from .logits_process import SynthIDTextWatermarkLogitsProcessor, WatermarkLogitsProcessor

NEED_SETUP_CACHE_CLASSES_MAPPING = {
STATIC_CACHE_CLASSES_MAPPING = {
"static": StaticCache,
"offloaded_static": OffloadedStaticCache,
"sliding_window": SlidingWindowCache,
@@ -70,7 +70,7 @@
"offloaded_hybrid_chunked": OffloadedHybridCache,
}
QUANT_BACKEND_CLASSES_MAPPING = {"quanto": QuantoQuantizedCache, "HQQ": HQQQuantizedCache}
ALL_CACHE_IMPLEMENTATIONS = list(NEED_SETUP_CACHE_CLASSES_MAPPING.keys()) + ["offloaded", "dynamic", "quantized"]
ALL_CACHE_IMPLEMENTATIONS = list(STATIC_CACHE_CLASSES_MAPPING.keys()) + ["offloaded", "dynamic", "quantized"]


class GenerationMode(ExplicitEnum):
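For context, a minimal sketch of how these mappings surface to users: the `cache_implementation` string passed to `generate` (or set on a `GenerationConfig`) is looked up in `STATIC_CACHE_CLASSES_MAPPING`, while `"dynamic"`, `"offloaded"` and `"quantized"` are handled outside of it. The checkpoint and prompt below are placeholders, not part of this PR.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder checkpoint for illustration; any recent decoder-only causal LM works the same way.
model_id = "hf-internal-testing/tiny-random-LlamaForCausalLM"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("The static cache", return_tensors="pt")
# "static" is resolved to StaticCache through STATIC_CACHE_CLASSES_MAPPING.
outputs = model.generate(**inputs, max_new_tokens=8, cache_implementation="static")
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```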
@@ -1536,8 +1536,10 @@ class CompileConfig:
See [`torch.compile`](https://pytorch.org/docs/stable/generated/torch.compile.html) for more details on the arguments.

Args:
fullgraph (`bool`, *optional*, defaults to `True`):
If `True`, requires that the whole forward be capturable in a single graph.
fullgraph (`bool`, *optional*, defaults to `False`):
If `False` (default), `torch.compile` attempts to discover compileable regions that it will optimize. If `True`,
requires that the entire function be capturable into a single graph. If this is not possible (that is, if there
are graph breaks), an error is raised.
dynamic (`bool` or `None`, *optional*):
Whether to try to use dynamic shape graphs.
backend (`str` or `Callable`, *optional*, defaults to `"inductor"`):
@@ -1566,7 +1568,7 @@ class CompileConfig:
```
"""

fullgraph: bool = True
fullgraph: bool = False
dynamic: Optional[bool] = None
backend: Union[str, Callable] = "inductor"
mode: str = "reduce-overhead"
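Since the default flips from `fullgraph=True` to `fullgraph=False`, callers who still want the old fail-fast behaviour have to opt in explicitly. A minimal sketch, using a placeholder checkpoint and assuming `compile_config` is forwarded through `generate` the same way the tests below do:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.configuration_utils import CompileConfig

model_id = "hf-internal-testing/tiny-random-LlamaForCausalLM"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("Hello", return_tensors="pt")
# fullgraph=True restores the previous fail-fast behaviour: any graph break raises.
compile_config = CompileConfig(fullgraph=True, dynamic=False)
outputs = model.generate(
    **inputs,
    max_new_tokens=8,
    cache_implementation="static",  # compilation is only attempted with a compileable cache
    compile_config=compile_config,
)
```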
25 changes: 7 additions & 18 deletions src/transformers/generation/utils.py
@@ -71,9 +71,8 @@
_prepare_token_type_ids,
)
from .configuration_utils import (
NEED_SETUP_CACHE_CLASSES_MAPPING,
QUANT_BACKEND_CLASSES_MAPPING,
CompileConfig,
STATIC_CACHE_CLASSES_MAPPING,
GenerationConfig,
GenerationMode,
)
@@ -1826,7 +1825,7 @@ def _get_cache(self, cache_implementation: str, batch_size: int, max_cache_len:
if cache_implementation == "hybrid" and "llama4" in getattr(self.config, "model_type", ""):
cache_implementation = "hybrid_chunked"

cache_cls: Cache = NEED_SETUP_CACHE_CLASSES_MAPPING[cache_implementation]
cache_cls: Cache = STATIC_CACHE_CLASSES_MAPPING[cache_implementation]
requires_cross_attention_cache = (
self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None
)
@@ -1958,12 +1957,7 @@ def _prepare_cache_for_generation(
else {}
)
if generation_config.cache_implementation is not None:
if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING:
if generation_config.cache_implementation == "static" and not self._can_compile_fullgraph:
raise ValueError(
"This model does not support `cache_implementation='static'`. Please check the following "
"issue: https://github.com/huggingface/transformers/issues/28981"
)
if generation_config.cache_implementation in STATIC_CACHE_CLASSES_MAPPING:
model_kwargs[cache_name] = self._get_cache(
cache_implementation=generation_config.cache_implementation,
batch_size=max(generation_config.num_beams, generation_config.num_return_sequences) * batch_size,
@@ -2115,8 +2109,7 @@ def _valid_auto_compile_criteria(self, model_kwargs: dict, generation_config: Ge
using_compilable_cache = (
isinstance(model_kwargs.get("past_key_values"), Cache) and model_kwargs["past_key_values"].is_compileable
)
# TODO @raushan `self._can_compile_fullgraph` can be removed and inferred from model arch (e.g. MoE doesn't support compile)
can_compile = valid_hardware and using_compilable_cache and self._can_compile_fullgraph
can_compile = valid_hardware and using_compilable_cache

# Exception 1: Some quantization methods do not support compilation
if getattr(self, "hf_quantizer", None) is not None:
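In other words, whether `generate` auto-compiles the forward now depends on the cache object rather than on a per-model `_can_compile_fullgraph` flag. A small illustration, assuming the `is_compileable` attribute exposed by the cache classes in this codebase:

```python
from transformers import DynamicCache

# The decision is now per-cache: a cache whose `is_compileable` flag is False leaves the
# forward uncompiled, while the static-shaped caches above opt in (hardware and
# quantization checks permitting).
cache = DynamicCache()
print(type(cache).__name__, cache.is_compileable)
```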
@@ -3475,13 +3468,9 @@ def _sample(
if compile_forward:
os.environ["TOKENIZERS_PARALLELISM"] = "0"
# If we use FA2 and a static cache, we cannot compile with fullgraph
if self.config._attn_implementation == "flash_attention_2" and getattr(
model_kwargs.get("past_key_values"), "is_compileable", False
):
if generation_config.compile_config is None:
generation_config.compile_config = CompileConfig(fullgraph=False)
# only raise warning if the user passed an explicit compile-config (otherwise, simply change the default without confusing the user)
elif generation_config.compile_config.fullgraph:
if self.config._attn_implementation == "flash_attention_2":
# only raise warning if the user passed an explicit compile-config
if generation_config.compile_config is not None and generation_config.compile_config.fullgraph:
logger.warning_once(
"When using Flash Attention 2 and a static cache, you cannot use the option `CompileConfig(fullgraph=True)` as "
"FA2 introduces graph breaks. We overrode the option with `fullgraph=False`."
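Usage-wise, a sketch assuming a CUDA machine with `flash-attn` installed and a placeholder checkpoint: with Flash Attention 2 the new `fullgraph=False` default simply applies, and the warning above only fires when a user explicitly passes `CompileConfig(fullgraph=True)`.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-3.2-1B"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
    device_map="cuda",
)

inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
# No explicit CompileConfig here: the fullgraph=False default applies and no warning is emitted.
outputs = model.generate(**inputs, max_new_tokens=8, cache_implementation="static")
```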
3 changes: 1 addition & 2 deletions src/transformers/models/deepseek_v2/modeling_deepseek_v2.py
@@ -460,8 +460,7 @@ class DeepseekV2PreTrainedModel(PreTrainedModel):
_supports_flash_attn = True
_supports_sdpa = True
_supports_flex_attn = True

_can_compile_fullgraph = True
_can_compile_fullgraph = False
_supports_attention_backend = True
_can_record_outputs = {
"hidden_states": DeepseekV2DecoderLayer,
2 changes: 2 additions & 0 deletions src/transformers/models/deepseek_v2/modular_deepseek_v2.py
@@ -504,6 +504,8 @@ def __init__(self, config: DeepseekV2Config, layer_idx: int):


class DeepseekV2PreTrainedModel(LlamaPreTrainedModel):
_can_compile_fullgraph = False

def _init_weights(self, module):
LlamaPreTrainedModel._init_weights(module)
if isinstance(module, DeepseekV2MoEGate):
3 changes: 1 addition & 2 deletions src/transformers/models/deepseek_v3/modeling_deepseek_v3.py
@@ -502,8 +502,7 @@ class DeepseekV3PreTrainedModel(PreTrainedModel):
_supports_flash_attn = True
_supports_sdpa = True
_supports_flex_attn = True

_can_compile_fullgraph = True
_can_compile_fullgraph = False
_supports_attention_backend = True
_can_record_outputs = {
"hidden_states": DeepseekV3DecoderLayer,
2 changes: 2 additions & 0 deletions src/transformers/models/deepseek_v3/modular_deepseek_v3.py
@@ -340,6 +340,8 @@ def __init__(self, config: DeepseekV3Config, layer_idx: int):


class DeepseekV3PreTrainedModel(LlamaPreTrainedModel):
_can_compile_fullgraph = False

def _init_weights(self, module):
LlamaPreTrainedModel._init_weights(module)
if isinstance(module, DeepseekV3TopkRouter):
3 changes: 1 addition & 2 deletions src/transformers/models/dots1/modeling_dots1.py
@@ -422,8 +422,7 @@ class Dots1PreTrainedModel(PreTrainedModel):
_supports_flash_attn = True
_supports_sdpa = True
_supports_flex_attn = True

_can_compile_fullgraph = True
_can_compile_fullgraph = False
_supports_attention_backend = True
_can_record_outputs = {
"hidden_states": Dots1DecoderLayer,
15 changes: 12 additions & 3 deletions tests/generation/test_utils.py
@@ -1736,6 +1736,9 @@ def test_generate_from_inputs_embeds_with_static_cache(self):
to verify that the cache length is indeed set correctly and we don't run out of index when slicing the cache.
"""
for model_class in self.all_generative_model_classes:
# Here, we should ideally not skip any model, and test them all. However, some old models cannot correctly
# use a static cache because they don't create the causal masks correctly.
# TODO: cyril -> relax this by adding a `_support_static_cache` attribute
if not model_class._can_compile_fullgraph:
self.skipTest(reason="This model does not support the static cache format")

@@ -1956,6 +1959,9 @@ def test_generate_with_static_cache(self):
"""
set_model_tester_for_less_flaky_test(self)
for model_class in self.all_generative_model_classes:
# Here, we should ideally not skip any model, and test them all. However, some old models cannot correctly
# use a static cache because they don't create the causal masks correctly.
# TODO: cyril -> relax this by adding a `_support_static_cache` attribute
if not model_class._can_compile_fullgraph:
self.skipTest(reason="This model does not support the static cache format")

@@ -2050,7 +2056,7 @@ def test_generate_with_quant_cache(self):
@pytest.mark.generate
@pytest.mark.torch_compile_test
@require_torch_greater_or_equal("2.6") # Uses torch.compiler.set_stance
def test_generate_compile_model_forward(self):
def test_generate_compile_model_forward_fullgraph(self):
"""
Tests that `.generate` is compatible with torch.compile, keeping the same results. Also confirms that
`.forward` called from `.generate` sees no graph breaks or recompilations when compiled.
@@ -2098,7 +2104,7 @@ def test_generate_compile_model_forward(self):
# 3. compilation-specific setup and generation parameterization
torch.compiler.reset() # prevent cached compilation from being used in the test
has_defined_cache_implementation = model.generation_config.cache_implementation is not None
compile_config = CompileConfig(dynamic=False) # Error out on dynamic shapes
compile_config = CompileConfig(fullgraph=True, dynamic=False) # Error out on dynamic shapes
compile_config._compile_all_devices = True # force compilation (e.g. fast CI, CPU)

generation_kwargs = {
@@ -2174,8 +2180,11 @@ def test_generate_compilation_all_outputs(self):
In essence, it's the same as `test_greedy_generate_dict_outputs`, but with automatic compilation triggered.
"""
for model_class in self.all_generative_model_classes:
# Here, we should ideally not skip any model, and test them all. However, some old models cannot correctly
# use a static cache because they don't create the causal masks correctly.
# TODO: cyril -> relax this by adding a `_support_static_cache` attribute
if not model_class._can_compile_fullgraph:
self.skipTest("This model doesn't support compilation without graph breaks")
self.skipTest(reason="This model does not support the static cache format")

config, inputs_dict = self.prepare_config_and_inputs_for_generate()
if self.has_attentions:
19 changes: 0 additions & 19 deletions tests/models/deepseek_v2/test_modeling_deepseek_v2.py
@@ -174,25 +174,6 @@ def _check_past_key_values_for_generate(self, batch_size, decoder_past_key_value
self.assertEqual(layer.keys.shape, expected_key_shape)
self.assertEqual(layer.values.shape, expected_value_shape)

@unittest.skip("Deepseek-V2 uses MLA which has a special head dim and is not compatible with StaticCache shape")
@pytest.mark.torch_compile_test
def test_generate_compilation_all_outputs(self):
pass

@unittest.skip("Deepseek-V2 uses MLA which has a special head dim and is not compatible with StaticCache shape")
@pytest.mark.torch_compile_test
def test_generate_compile_model_forward(self):
pass

@unittest.skip("Deepseek-V2 uses MLA which has a special head dim and is not compatible with StaticCache shape")
def test_generate_from_inputs_embeds_with_static_cache(self):
pass

@unittest.skip("Deepseek-V2 uses MLA which has a special head dim and is not compatible with StaticCache shape")
@pytest.mark.torch_compile_test
def test_generate_with_static_cache(self):
pass

@unittest.skip("Dynamic control flow in MoE")
@pytest.mark.torch_compile_test
def test_torch_compile_for_training(self):
21 changes: 0 additions & 21 deletions tests/models/deepseek_v3/test_modeling_deepseek_v3.py
@@ -285,18 +285,6 @@ def test_contrastive_generate_dict_outputs_use_cache(self):
def test_contrastive_generate_low_memory(self):
pass

@unittest.skip(
"DeepseekV3 has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support."
)
def test_generate_with_static_cache(self):
pass

@unittest.skip(
"DeepseekV3 has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support."
)
def test_generate_from_inputs_embeds_with_static_cache(self):
pass

@unittest.skip(
"DeepseekV3 has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support."
)
@@ -307,15 +295,6 @@ def test_generate_continue_from_inputs_embeds(self):
def test_beam_search_generate_dict_outputs_use_cache(self):
pass

@unittest.skip("Deepseek-V3 uses MLA so it is not compatible with the standard cache format")
def test_generate_compilation_all_outputs(self):
pass

@unittest.skip("Deepseek-V3 uses MLA so it is not compatible with the standard cache format")
@pytest.mark.torch_compile_test
def test_generate_compile_model_forward(self):
pass

@unittest.skip("Deepseek-V3 uses MLA so it is not compatible with the standard cache format")
def test_greedy_generate_dict_outputs_use_cache(self):
pass
17 changes: 0 additions & 17 deletions tests/models/dots1/test_modeling_dots1.py
@@ -87,23 +87,6 @@ class Dots1ModelTest(CausalLMModelTest, unittest.TestCase):
test_pruning = False
model_tester_class = Dots1ModelTester

@unittest.skip("dots.llm1's moe is not compatible `token_indices, weight_indices = torch.where(mask)`.")
def test_generate_with_static_cache(self):
pass

@unittest.skip("dots.llm1's moe is not compatible `token_indices, weight_indices = torch.where(mask)`.")
def test_generate_compilation_all_outputs(self):
pass

@unittest.skip("dots.llm1's moe is not compatible `token_indices, weight_indices = torch.where(mask)`")
@pytest.mark.torch_compile_test
def test_generate_compile_model_forward(self):
pass

@unittest.skip("dots.llm1's moe is not compatible token_indices, weight_indices = torch.where(mask).")
def test_generate_from_inputs_embeds_with_static_cache(self):
pass

@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
2 changes: 1 addition & 1 deletion tests/models/janus/test_modeling_janus.py
@@ -296,7 +296,7 @@ def check_training_gradient_checkpointing(self, gradient_checkpointing_kwargs=No

@unittest.skip("There are recompilations in Janus") # TODO (joao, raushan): fix me
@pytest.mark.torch_compile_test
def test_generate_compile_model_forward(self):
def test_generate_compile_model_forward_fullgraph(self):
pass


7 changes: 0 additions & 7 deletions tests/models/paligemma2/test_modeling_paligemma2.py
@@ -25,7 +25,6 @@
is_torch_available,
)
from transformers.testing_utils import (
is_flaky,
require_torch,
torch_device,
)
@@ -317,12 +316,6 @@ def test_contrastive_generate_low_memory(self):
def test_generate_with_static_cache(self):
pass

@pytest.mark.generate
@pytest.mark.torch_compile_test
@is_flaky
def test_generate_compile_model_forward(self):
super().test_generate_compile_model_forward()

@unittest.skip("Paligemma position ids are 1 indexed")
def test_eager_padding_matches_padding_free_with_position_ids(self):
pass
@@ -255,7 +255,7 @@ def test_generate_compilation_all_outputs(self):
reason="Supported only for text-only inputs (otherwise dynamic control flows for multimodal inputs)"
)
@pytest.mark.torch_compile_test
def test_generate_compile_model_forward(self):
def test_generate_compile_model_forward_fullgraph(self):
pass

@parameterized.expand([("random",), ("same",)])
2 changes: 1 addition & 1 deletion tests/models/qwen2_5_omni/test_modeling_qwen2_5_omni.py
@@ -447,7 +447,7 @@ def test_generate_from_inputs_embeds_with_static_cache(self):
# passing, fix me
@unittest.skip("Cannot handle 4D attention mask")
@pytest.mark.torch_compile_test
def test_generate_compile_model_forward(self):
def test_generate_compile_model_forward_fullgraph(self):
pass

@unittest.skip("Cannot handle 4D attention mask")
2 changes: 1 addition & 1 deletion tests/models/whisper/test_modeling_whisper.py
@@ -1421,7 +1421,7 @@ def test_labels_sequence_max_length_error_after_changing_config(self):
# TODO (joao, eustache): fix me :) The model is not returning a `Cache` by default
@unittest.skip(reason="Whisper's custom generate is not consistent regarding the cache return types")
@pytest.mark.torch_compile_test
def test_generate_compile_model_forward(self):
def test_generate_compile_model_forward_fullgraph(self):
pass

# TODO (joao, eustache): fix me :)