Add TorchAO quantization tests with FP16 models and serialization workarounds #3269
Add TorchAO Quantization Testing with Required Workarounds
Overview
This PR adds comprehensive testing for TorchAO quantization functionality, covering save/load operations and inference. Compatibility issues between TorchAO, PyTorch 2.7+, and the existing quantization methods made several workarounds necessary; each is documented below.
Changes
- `test_save_and_inference_torchao()` function to test TorchAO quantization end-to-end
- Use the `-s` flag to show model output
- `fp16_model_tokenizer` fixture for TorchAO-specific testing

Required Workarounds and Their Reasons
1. FP16 Model Loading Required
Issue: TorchAO quantization cannot be applied to models that are already quantized with BitsAndBytes (4-bit/8-bit).
Error encountered:
Solution: Created a separate `fp16_model_tokenizer` fixture that loads models with `load_in_4bit=False`, ensuring the base model is in FP16/BF16 format before TorchAO quantization is applied.

Code:
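The original snippet isn't reproduced here; the following is a minimal sketch of such a fixture, assuming pytest and transformers. The model name and fixture scope are illustrative, not taken from the PR.

```python
import pytest
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_NAME = "facebook/opt-125m"  # illustrative; the PR's actual model may differ

@pytest.fixture(scope="module")
def fp16_model_tokenizer():
    # Load the base model in half precision with 4-bit loading explicitly
    # disabled, so TorchAO quantization is applied to unquantized weights.
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.float16,
        load_in_4bit=False,  # per the PR description; newer transformers express this via BitsAndBytesConfig
        device_map="auto",
    )
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    return model, tokenizer
```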
2. Safe Globals Context Manager Required
Issue: PyTorch 2.6+ changed the default `weights_only` parameter in `torch.load()` from `False` to `True` for security reasons. TorchAO quantized models contain custom serialized objects (including `getattr` operations) that are blocked by this security feature.

Error encountered:
Solution: Use the `torch.serialization.safe_globals([getattr])` context manager when loading TorchAO quantized models.

Code:
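A minimal sketch of the loading path, assuming the checkpoint was saved with pickle-based serialization (the path is illustrative):

```python
import torch
from transformers import AutoModelForCausalLM

# Under PyTorch 2.6+, torch.load() defaults to weights_only=True, and the
# TorchAO serialized state references `getattr`, so it must be allowlisted
# while the checkpoint is being loaded.
with torch.serialization.safe_globals([getattr]):
    model = AutoModelForCausalLM.from_pretrained(
        "path/to/torchao-checkpoint",  # illustrative path
        device_map="auto",
    )
```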
Reference: This approach is documented in both the PyTorch serialization docs and the TorchAO integration guide.
3. Cache Disabled for Generation
Issue: TorchAO quantized models use `StaticLayer` objects that don't implement the `max_batch_size` property, which appears to be required by PyTorch's KV cache system.

Error encountered:
Solution: Disable the KV cache during generation with `use_cache=False`. This is a known limitation when using TorchAO quantization with transformers' generation utilities.

Code:
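A sketch of the generation call, continuing from the fixture above; the prompt and token count are illustrative:

```python
# `model` and `tokenizer` come from the fixture sketched earlier, with
# TorchAO quantization already applied.
inputs = tokenizer("Hello, world", return_tensors="pt").to(model.device)
output_ids = model.generate(
    **inputs,
    max_new_tokens=20,
    use_cache=False,  # skip the KV cache path that breaks under TorchAO
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```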
Testing
The tests verify:
- TorchAO quantization can be applied to the FP16 base model
- the quantized model can be saved and reloaded
- the reloaded model produces output during inference
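For reference, a hedged sketch of how the three workarounds compose into an end-to-end test; the quantization scheme (int8 weight-only), prompt, and save/reload details are assumptions, not taken from the PR:

```python
import tempfile

import torch
from torchao.quantization import int8_weight_only, quantize_
from transformers import AutoModelForCausalLM

def test_save_and_inference_torchao(fp16_model_tokenizer):
    model, tokenizer = fp16_model_tokenizer

    # Workaround 1: start from an FP16/BF16 model, then quantize in place
    # (int8 weight-only is an assumed scheme for illustration).
    quantize_(model, int8_weight_only())

    with tempfile.TemporaryDirectory() as tmpdir:
        # TorchAO tensor subclasses are not safetensors-compatible, so
        # fall back to pickle-based serialization when saving.
        model.save_pretrained(tmpdir, safe_serialization=False)

        # Workaround 2: allowlist `getattr` for weights_only loading.
        with torch.serialization.safe_globals([getattr]):
            reloaded = AutoModelForCausalLM.from_pretrained(tmpdir, device_map="auto")

    inputs = tokenizer("Hello", return_tensors="pt").to(reloaded.device)
    # Workaround 3: disable the KV cache during generation.
    out = reloaded.generate(**inputs, max_new_tokens=10, use_cache=False)
    assert tokenizer.decode(out[0], skip_special_tokens=True)
```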