@rolandtannous (Collaborator)

Add TorchAO Quantization Testing with Required Workarounds

Overview

This PR adds comprehensive testing for TorchAO quantization, covering save/load operations and inference. Because of compatibility issues between TorchAO, PyTorch 2.7+, and the existing quantization methods, several workarounds were necessary.

Changes

  • Added a test_save_and_inference_torchao() function to test TorchAO quantization end to end. Run pytest with the -s flag to show the model output.
  • Added a separate fp16_model_tokenizer fixture for TorchAO-specific testing (see the fixture sketch after this list)
  • Implemented workarounds for known compatibility issues
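
For reference, the fp16_model_tokenizer fixture is essentially a thin wrapper around the FP16 load described in workaround 1 below. A minimal sketch (the model name and fixture scope here are illustrative, not necessarily what the test file uses):

import pytest
from unsloth import FastModel

@pytest.fixture(scope="module")
def fp16_model_tokenizer():
    # Load the base model without BitsAndBytes so TorchAO can quantize it later
    model, tokenizer = FastModel.from_pretrained(
        "unsloth/Llama-3.2-1B-Instruct",  # illustrative model name
        load_in_4bit=False,
    )
    return model, tokenizer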

Required Workarounds and Their Reasons

1. FP16 Model Loading Required

Issue: TorchAO quantization cannot be applied to models that are already quantized with BitsAndBytes (4-bit/8-bit).

Error encountered:

TypeError: Cannot apply TorchAO quantization to already quantized model

Solution: Created a separate fp16_model_tokenizer fixture that loads the model with load_in_4bit=False, so the base model is in FP16/BF16 format before TorchAO quantization is applied.

Code:

model, tokenizer = FastModel.from_pretrained(
    model_name,
    load_in_4bit=False,  # Essential: No BnB quantization for TorchAO
)
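
With the FP16 model in hand, TorchAO quantization can then be applied to it. A minimal sketch using torchao's quantize_ API directly (the test may instead go through the transformers/Unsloth integration; int8_weight_only is just one example config):

from torchao.quantization import quantize_, int8_weight_only

# Quantize the model's linear layers in place with int8 weight-only quantization
quantize_(model, int8_weight_only())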

2. Safe Globals Context Manager Required

Issue: PyTorch 2.6+ changed the default weights_only parameter of torch.load() from False to True for security reasons. TorchAO-quantized checkpoints contain custom serialized objects (including getattr calls) that the weights-only unpickler blocks.

Error encountered:

_pickle.UnpicklingError: Weights only load failed. 
WeightsUnpickler error: Unsupported global: GLOBAL getattr was not an allowed global by default. 
Please use `torch.serialization.add_safe_globals([getattr])` or the 
`torch.serialization.safe_globals([getattr])` context manager to allowlist this global

Solution: Use the torch.serialization.safe_globals([getattr]) context manager when loading TorchAO-quantized models. This is the approach recommended by both the PyTorch and TorchAO documentation.

Code:

import torch.serialization
with torch.serialization.safe_globals([getattr]):
    loaded_model, loaded_tokenizer = FastModel.from_pretrained(torchao_save_path, ...)

Reference: This is documented in both the PyTorch serialization docs and the TorchAO integration guide.
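
As the error message notes, torch.serialization.add_safe_globals([getattr]) is a process-wide alternative to the context manager; a sketch of that variant (the context manager was preferred here because it keeps the allowlisting scoped to the load):

import torch.serialization

# Allowlist getattr for every subsequent torch.load in this process
torch.serialization.add_safe_globals([getattr])
loaded_model, loaded_tokenizer = FastModel.from_pretrained(torchao_save_path, ...)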

3. Cache Disabled for Generation

Issue: TorchAO quantized models use StaticLayer cache objects that don't implement the max_batch_size property the KV cache system expects during generation.

Error encountered:

AttributeError: 'StaticLayer' object has no attribute 'max_batch_size'
    values = [layer.max_batch_size for layer in self.layers]
             ^^^^^^^^^^^^^^^^^^^^

Solution: Disable KV cache during generation with use_cache=False. This is a known limitation when using TorchAO quantization with transformers' generation utilities.

Code:

outputs = loaded_model.generate(
    input_ids=inputs,
    use_cache=False,  # Required: StaticLayer compatibility issue
    max_new_tokens=64,
    # ... other params
)
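
Decoding the result then gives the text checked in the Testing section below; this is also what the -s flag surfaces in the pytest output (a sketch):

# Decode the generated tokens and print them (visible when pytest runs with -s)
text = loaded_tokenizer.decode(outputs[0], skip_special_tokens=True)
print(text)
assert len(text.strip()) > 0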

Testing

The tests verify:

  1. TorchAO models can be saved successfully
  2. TorchAO models can be loaded with proper workarounds
  3. TorchAO models can perform inference and generate text
  4. Generated text is coherent and non-empty
  5. File size reduction is achieved compared to the FP16 model (see the size-check sketch below)
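
For the size-reduction check in point 5, a sketch of how the comparison could be done (the directory-walk helper and the fp16_save_path / torchao_save_path names are assumptions, not the exact code in the test):

import os

def dir_size_bytes(path):
    # Total size of all files under a saved-model directory
    return sum(
        os.path.getsize(os.path.join(root, f))
        for root, _, files in os.walk(path)
        for f in files
    )

assert dir_size_bytes(torchao_save_path) < dir_size_bytes(fp16_save_path)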

@danielhanchen merged commit 551bc65 into unslothai:main on Sep 5, 2025