Conversation

andrewor14 (Contributor) commented on Aug 29, 2025:

Summary: Following #2976, which adds support for QAT + LoRA, this PR adds support for QAT during full fine-tuning. See the [torchao QAT README](https://github.com/pytorch/ao/blob/main/torchao/quantization/qat/README.md) for more details.

The QAT schemes currently supported are listed below; a rough sketch of the fake-quantization idea behind them follows the list.

  • fp8-int4, targeting the torch.ops.fbgemm.f8i4bf16_shuffled kernel
  • fp8-fp8, targeting the torch.ops.fbgemm.f8f8bf16_rowwise kernel
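
To make that concrete, here is a minimal, self-contained sketch of the straight-through fake-quantization idea QAT relies on: the forward pass sees weights rounded onto a low-precision grid, while gradients flow through the full-precision weights so fine-tuning can compensate for the rounding error. This is an illustration only, not the torchao or unsloth implementation; it fakes a symmetric int4 weight grid rather than the fp8-int4/fp8-fp8 schemes above, and FakeQuantLinear and its scaling are hypothetical.

```python
import torch
import torch.nn as nn

class FakeQuantLinear(nn.Linear):
    """Linear layer whose forward pass sees int4-rounded weights (illustration only)."""

    def forward(self, x):
        w = self.weight
        # Symmetric per-output-channel scale mapping weights onto an int4 grid.
        scale = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 7.0
        w_q = torch.clamp(torch.round(w / scale), -8, 7) * scale
        # Straight-through estimator: the forward pass uses the rounded weights,
        # but the backward pass treats rounding as identity so gradients reach w.
        w_ste = w + (w_q - w).detach()
        return nn.functional.linear(x, w_ste, self.bias)

layer = FakeQuantLinear(16, 16, bias=False)
out = layer(torch.randn(2, 16))
out.sum().backward()
print(layer.weight.grad.shape)  # gradients flow despite the rounding in forward
```

Because the optimizer keeps updating the high-precision weights, the model settles into regions that survive the real post-training quantization step, which is what the perplexity comparison below is measuring.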

Test Plan: https://gist.github.com/andrewor14/048b5c1bd01b7fa23c53913856a8ef9f

Full fine-tuning Llama3.1-8B with and without QAT on yahma/alpaca-cleaned for 1 epoch:

  • Batch size = 16 (no grad accum)
  • Learning rate = 4e-5
  • Quantization scheme = fp8-int4
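
For reference, a run like the one above might be configured roughly as follows. This is a sketch only: the qat_scheme argument and its placement in from_pretrained are assumptions based on this PR and #2976 rather than documented API, the prompt formatting is a guess at the usual Alpaca template, and exact unsloth/trl signatures are version-dependent.

```python
# Hypothetical sketch of the full fine-tuning + QAT run described above.
# qat_scheme (and its placement) is an assumption, not confirmed API.
from unsloth import FastLanguageModel
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer

model, tokenizer = FastLanguageModel.from_pretrained(
    "unsloth/Meta-Llama-3.1-8B",
    full_finetuning=True,    # full fine-tuning rather than LoRA
    qat_scheme="fp8-int4",   # assumption: scheme name as listed above
)

def to_text(example):
    # Assumed Alpaca-style prompt formatting for yahma/alpaca-cleaned.
    return {"text": f"### Instruction:\n{example['instruction']}\n\n"
                    f"### Input:\n{example['input']}\n\n"
                    f"### Response:\n{example['output']}"}

dataset = load_dataset("yahma/alpaca-cleaned", split="train").map(to_text)

trainer = SFTTrainer(
    model=model,
    processing_class=tokenizer,
    train_dataset=dataset,
    args=SFTConfig(
        per_device_train_batch_size=16,  # batch size 16, no grad accum
        gradient_accumulation_steps=1,
        learning_rate=4e-5,
        num_train_epochs=1,
        bf16=True,
    ),
)
trainer.train()
```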

Wikitext perplexity:

  • QAT improved perplexity by 19.2% compared to regular fine-tuning
  • QAT's int4 quantized model even outperformed the bf16 baseline
  • Regular int4 quantized model (without QAT) was significantly worse than the bf16 baseline

```
==> unsloth_model_full_baseline_output/eval_float.log <==
|        |       |none  |     0|word_perplexity|↓  |9.8446|±  |   N/A|

==> unsloth_model_full_baseline_output/eval_quantized.log <==
|        |       |none  |     0|word_perplexity|↓  |11.4595|±  |   N/A|

==> unsloth_model_full_qat_fp8-int4_output/eval_quantized.log <==
|        |       |none  |     0|word_perplexity|↓  |9.2336|±  |   N/A|
```
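
The word_perplexity rows above have the shape of lm-evaluation-harness wikitext output, so a comparable number could be reproduced roughly as below. Treating these logs as lm-eval output is an assumption, the checkpoint path is a placeholder taken from the log names, and the int4 quantization step behind eval_quantized.log is not shown here.

```python
# Sketch: measure wikitext word perplexity with lm-evaluation-harness.
# The checkpoint path is a placeholder; quantization before eval is omitted.
import lm_eval

results = lm_eval.simple_evaluate(
    model="hf",
    model_args="pretrained=./unsloth_model_full_qat_fp8-int4_output,dtype=bfloat16",
    tasks=["wikitext"],
)
print(results["results"]["wikitext"]["word_perplexity,none"])
```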

Fibonacci test:

  • Both bf16 baseline and int4 quantized models correctly identified 13 as the next number
  • QAT quantized model was more succinct in its response
  • No substantial differences here

```
### Instruction:
Continue the fibonnaci sequence.

### Input:
1, 1, 2, 3, 5, 8

==> unsloth_model_full_baseline_output/eval_float.log <==
### Response:
The next number in the Fibonacci sequence is 13.<|end_of_text|>

==> unsloth_model_full_baseline_output/eval_quantized.log <==
### Response:
The next number in the Fibonacci sequence is 13.<|end_of_text|>

==> unsloth_model_full_qat_fp8-int4_output/eval_quantized.log <==
### Response:
13<|end_of_text|>
```

jerryzh168 (Contributor) commented:

@andrewor14 it's reversed I think, fp8-fp8 is targeting torch.ops.fbgemm.f8f8bf16_rowwise

andrewor14 (Contributor, Author) replied:

> @andrewor14 it's reversed I think, fp8-fp8 is targeting torch.ops.fbgemm.f8f8bf16_rowwise

thanks, fixed

danielhanchen (Contributor) left a review:

Nice work! Some small changes

andrewor14 (Contributor, Author) replied:

Thanks, just fixed

danielhanchen merged commit 3541f6e into unslothai:main on Sep 8, 2025.