14 changes: 14 additions & 0 deletions src/transformers/integrations/ggml.py
@@ -250,6 +250,18 @@
"attention.sliding_window": "sliding_window",
"vocab_size": "vocab_size",
},
"deci": {
"context_length": "max_position_embeddings",
"block_count": "num_hidden_layers",
"feed_forward_length": "intermediate_size",
"embedding_length": "hidden_size",
"rope.dimension_count": None,
"rope.freq_base": "rope_theta",
"attention.head_count": "num_attention_heads",
"attention.head_count_kv": "num_key_value_heads",
"attention.layer_norm_rms_epsilon": "rms_norm_eps",
"vocab_size": "vocab_size",
},
}

GGUF_TOKENIZER_MAPPING = {
@@ -716,6 +728,8 @@ def converted(self) -> Tokenizer:
"nemotron": GGUFGPTConverter,
"gemma2": GGUFGemmaConverter,
"gemma3_text": GGUFGemmaConverter,
"deci": GGUFLlamaConverter,
"decilm": GGUFLlamaConverter,
}


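For context on how the new mapping block is consumed: when transformers loads a GGUF checkpoint, metadata keys arrive prefixed with the architecture name (e.g. "deci.embedding_length") and are renamed to config attributes via GGUF_CONFIG_MAPPING, with keys mapped to None dropped. A minimal sketch of that renaming under those assumptions — map_deci_config is a hypothetical helper for illustration, not the actual loading path in transformers:

from transformers.integrations.ggml import GGUF_CONFIG_MAPPING

def map_deci_config(gguf_metadata: dict) -> dict:
    # Hypothetical illustration of the renaming step, not the real loader.
    mapping = GGUF_CONFIG_MAPPING["deci"]
    config = {}
    for gguf_key, value in gguf_metadata.items():
        # GGUF keys are architecture-prefixed, e.g. "deci.rope.freq_base".
        target = mapping.get(gguf_key.removeprefix("deci."))
        if target is not None:  # keys mapped to None, like "rope.dimension_count", are skipped
            config[target] = value
    return config

# {"deci.embedding_length": 4096, "deci.rope.dimension_count": 128} -> {"hidden_size": 4096}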
83 changes: 83 additions & 0 deletions tests/quantization/ggml/test_ggml.py
@@ -335,6 +335,10 @@ class GgufModelTests(unittest.TestCase):
    q4_0_gemma3_qat_model_id = "gemma-3-1b-it-q4_0.gguf"
    bf16_gemma3_text_model_id = "gemma-3-1b-it-BF16.gguf"
    bf16_gemma3_vision_model_id = "gemma-3-4b-it-BF16.gguf"
    deci_original_model_id = "Deci/DeciLM-7B"
    deci_model_id = "Deci/DeciLM-7B-instruct-GGUF"
    q8_0_deci_model_id = "decilm-7b-uniform-gqa-q8_0.gguf"
    fp16_deci_model_id = "decilm-7b-uniform-gqa-f16.gguf"
    q8_0_qwen3_model_id = "Qwen3-0.6B-Q8_0.gguf"
    q4_k_m_qwen3moe_model_id = "Qwen3-30B-A3B-Q4_K_M.gguf"

@@ -960,6 +964,85 @@ def test_gemma3_vision_weights_conversion_bf16(self):
            else:
                raise ValueError(f"Layer {layer_name} is not present in the GGUF model")

    def test_deci_q8_0(self):
        """Test DeciLM model loading and inference with Q8_0 quantization."""
        tokenizer = AutoTokenizer.from_pretrained(self.deci_model_id, gguf_file=self.q8_0_deci_model_id)
        model = AutoModelForCausalLM.from_pretrained(
            self.deci_model_id,
            gguf_file=self.q8_0_deci_model_id,
            device_map="auto",
            torch_dtype=torch.float16,
        )

        text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
        out = model.generate(**text, max_new_tokens=10)

        EXPECTED_TEXT = "Hello, I am a language model developed"
        self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)

    def test_deci_weights_conversion_fp16(self):
        """Test that GGUF Deci model weights match the original model weights."""
        original_model = AutoModelForCausalLM.from_pretrained(
            self.deci_original_model_id,
            torch_dtype=torch.float16,
            trust_remote_code=True,
            device_map="auto",
        )

        # An FP16 GGUF export is required for an exact comparison against the original weights
        converted_model = AutoModelForCausalLM.from_pretrained(
            self.deci_model_id,
            gguf_file=self.fp16_deci_model_id,
            torch_dtype=torch.float16,
            device_map="auto",
        )

        converted_state_dict = converted_model.state_dict()
        original_state_dict = original_model.state_dict()

        for layer_name, original_params in original_state_dict.items():
            if layer_name in converted_state_dict:
                self.assertTrue(original_params.shape == converted_state_dict[layer_name].shape)
                torch.testing.assert_close(original_params, converted_state_dict[layer_name])
            else:
                raise ValueError(f"Layer {layer_name} is not present in the GGUF model")

    def test_deci_config_mapping(self):
        """Test that Deci GGUF config mapping is correctly applied."""
        from transformers.integrations.ggml import GGUF_CONFIG_MAPPING

        self.assertIn("deci", GGUF_CONFIG_MAPPING)

        deci_mapping = GGUF_CONFIG_MAPPING["deci"]

        expected_mappings = {
            "context_length": "max_position_embeddings",
            "block_count": "num_hidden_layers",
            "feed_forward_length": "intermediate_size",
            "embedding_length": "hidden_size",
            "rope.freq_base": "rope_theta",
            "attention.head_count": "num_attention_heads",
            "attention.head_count_kv": "num_key_value_heads",
            "attention.layer_norm_rms_epsilon": "rms_norm_eps",
            "vocab_size": "vocab_size",
        }

        for gguf_key, transformers_key in expected_mappings.items():
            self.assertEqual(deci_mapping[gguf_key], transformers_key)

        self.assertIsNone(deci_mapping["rope.dimension_count"])

    def test_deci_architecture_mapping(self):
        """Test that Deci architectures are mapped to GGUFLlamaConverter."""
        from transformers.integrations.ggml import GGUF_TO_FAST_CONVERTERS, GGUFLlamaConverter

        self.assertIn("deci", GGUF_TO_FAST_CONVERTERS)
        self.assertIn("decilm", GGUF_TO_FAST_CONVERTERS)

        self.assertEqual(GGUF_TO_FAST_CONVERTERS["deci"], GGUFLlamaConverter)
        self.assertEqual(GGUF_TO_FAST_CONVERTERS["decilm"], GGUFLlamaConverter)

    @require_read_token
    @unittest.skipUnless(is_gguf_available("0.16.0"), "test requires gguf version >= 0.16.0")
    def test_qwen3_q8_0(self):
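Taken together with the Q8_0 test above, end-to-end usage of the new DeciLM GGUF support looks like the following minimal sketch, assembled from the repository and file names in the test constants (the generated text will vary with quantization and decoding settings):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "Deci/DeciLM-7B-instruct-GGUF"
gguf_file = "decilm-7b-uniform-gqa-q8_0.gguf"

tokenizer = AutoTokenizer.from_pretrained(repo_id, gguf_file=gguf_file)
model = AutoModelForCausalLM.from_pretrained(
    repo_id,
    gguf_file=gguf_file,
    device_map="auto",
    torch_dtype=torch.float16,
)

inputs = tokenizer("Hello, I am a language model", return_tensors="pt").to(model.device)
out = model.generate(**inputs, max_new_tokens=10)
print(tokenizer.decode(out[0], skip_special_tokens=True))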