
Conversation

@danielhanchen
Contributor

No description provided.

andrewor14 and others added 30 commits November 19, 2025 23:51
* Enable FP8 + RL training for bf16 models

**Summary:** Enable FP8 + RL training using TorchAO for 1.33x faster training and 42% less model memory usage:
- We quantize the frozen base model weights to fp8 and keep the LoRA adapters in bf16
- We leverage TorchAO's `Float8Tensor`, which calls into fbgemm's fp8 x fp8 rowwise matmul kernel
- For now, we need to do an offline quantization first, because vLLM doesn't support on-the-fly quantization for torchao yet (this is in progress: vllm-project/vllm#26327); a sketch of this offline step follows the usage example below

**Example usage:**
```python
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/Qwen3-8B-Base",
    max_seq_length = 2048,
    load_in_4bit = False,
    fast_inference = True,
    max_lora_rank = 32,
    load_in_fp8 = True,  # set this to True
)

# the rest is the same as before
model = FastLanguageModel.get_peft_model(...)
```
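
Since vLLM can't yet quantize torchao models on the fly, the base weights must be quantized offline before loading. The exact helper Unsloth uses isn't shown in this thread; below is a minimal sketch using TorchAO's public `quantize_` API with per-row FP8 (the output directory and the `safe_serialization=False` choice are illustrative assumptions):

```python
# Sketch only: offline per-row FP8 quantization of the frozen base weights with TorchAO,
# saved so a serving engine can load the pre-quantized checkpoint.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from torchao.quantization import quantize_, Float8DynamicActivationFloat8WeightConfig, PerRow

model_name = "unsloth/Qwen3-8B-Base"
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="cuda")
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Quantize linear weights to fp8 with one scale per row; activations are quantized
# dynamically at runtime, so matmuls can use the fp8 x fp8 rowwise kernel.
quantize_(model, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()))

# torchao tensor subclasses currently need non-safetensors serialization.
model.save_pretrained("Qwen3-8B-Base-FP8-rowwise", safe_serialization=False)
tokenizer.save_pretrained("Qwen3-8B-Base-FP8-rowwise")
```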

**Initial results:**
```
# fp8
{'train_runtime': 1725.4337, 'train_samples_per_second': 0.232, 'train_steps_per_second': 0.058, 'train_loss': 0.00015715716748673002, 'epoch': 0.01}

# bf16
{'train_runtime': 2297.8145, 'train_samples_per_second': 0.174, 'train_steps_per_second': 0.044, 'train_loss': 0.00016081033063528594, 'epoch': 0.01}
```
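
(Sanity check: 2297.8 s / 1725.4 s ≈ 1.33×, matching the headline speedup, while the train losses are essentially identical.)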

<img width="1199" height="448" alt="Screenshot 2025-11-11 at 4 10 50 PM" src="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/user-attachments/assets/b6304afd-89e9-42b1-8064-775807e17b23" />

Test script: https://gist.github.com/andrewor14/5b85119fae46845d07b608d420907423

**Requires:**
- pytorch/ao#3158 (torchao nightly or 0.15.0+)
- unslothai/unsloth-zoo#351

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update utils.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* _get_inference_mode_context_manager

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update utils.py

* Update utils.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Daniel Han <danielhanchen@gmail.com>
* make loading gpt-oss-BF16 faster. Linked to unsloth-zoo PR #314

* fix model loading and clean merged model directory

* revert default quant

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* revert mapper.py

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* Add 128x128 PerBlock FP8 + RL

**Summary:** Following #3440,
this PR extends torchao FP8 + RL support to also handle 128x128
PerBlock granularity (in addition to PerRow).

**Example usage:**

```python
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/Qwen3-8B-Base",
    max_seq_length = 2048,
    load_in_4bit = False,
    fast_inference = True,
    max_lora_rank = 32,
    load_in_fp8 = "block",  # or "row" or True
)
```
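
The thread doesn't spell out how the granularities differ numerically. As a rough illustration (plain PyTorch, not torchao's actual implementation), "row" keeps one FP8 scale per output row of each weight matrix, while "block" keeps one scale per 128x128 tile:

```python
# Illustration only: per-row vs. 128x128 per-block FP8 (e4m3) scale computation.
import torch

FP8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # ~448

def per_row_scales(w: torch.Tensor) -> torch.Tensor:
    # One scale per output row: shape (out_features, 1)
    return w.abs().amax(dim=1, keepdim=True) / FP8_E4M3_MAX

def per_block_scales(w: torch.Tensor, block: int = 128) -> torch.Tensor:
    # One scale per 128x128 tile: shape (out_features // block, in_features // block)
    out_f, in_f = w.shape
    tiles = w.reshape(out_f // block, block, in_f // block, block)
    return tiles.abs().amax(dim=(1, 3)) / FP8_E4M3_MAX

w = torch.randn(4096, 4096, dtype=torch.bfloat16)
print(per_row_scales(w).shape)    # torch.Size([4096, 1])
print(per_block_scales(w).shape)  # torch.Size([32, 32])
```

Finer-grained block scales track local outliers better, at the cost of storing more scales and requiring a matmul kernel that consumes blockwise scales.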

**Initial results:** TBD

**Note:**
- Requires pytorch/ao#3370

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
danielhanchen and others added 9 commits November 30, 2025 23:37
* vllm sampling params fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* do not patch base_trainer

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* separate vllm fixes

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Apply suggestion from @danielhanchen

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Revert "[pre-commit.ci] auto fixes from pre-commit.com hooks"

This reverts commit fbb98c5.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Revert "[pre-commit.ci] auto fixes from pre-commit.com hooks"

This reverts commit c64d5b4.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Revert "[pre-commit.ci] auto fixes from pre-commit.com hooks"

This reverts commit c156545.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Daniel Han <danielhanchen@gmail.com>
* vllm sampling params fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* do not patch base_trainer

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* separate vllm fixes

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fixup deletion

* Fix indentation

* revert to old style

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
@gemini-code-assist
Contributor

Summary of Changes

Hello @danielhanchen, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request encompasses a series of maintenance and compatibility updates across the codebase. It ensures that the project remains aligned with external library changes, particularly concerning vLLM and trl, by introducing specific patches. Additionally, it includes minor code refinements, such as typo corrections and improved debugging output, to enhance overall stability and developer experience. The project's internal version has also been incremented to reflect these changes.

Highlights

  • Dependency Update: The unsloth_zoo dependency has been updated to version 2025.11.6 in pyproject.toml for both huggingface and colab-new configurations.
  • vLLM Compatibility Fix: A new fix, fix_vllm_guided_decoding_params, has been introduced to address a breaking change in vLLM where GuidedDecodingParams was renamed to StructuredOutputsParams. This ensures compatibility with trl, which still expects the old name (a minimal alias sketch follows this list).
  • Version Increment: The internal __version__ string in unsloth/models/_utils.py has been updated to 2025.11.5.
  • Typo Correction: A minor typo from 'Inferene' to 'Inference' has been corrected in comments across multiple model files (Falcon, Gemma, Llama, Mistral, Qwen2, Qwen3, Qwen3-MoE) related to KV Cache and CUDAGraphing.
  • Qwen3-MoE Model Enhancements: Modifications were made to Qwen3MoeSparseMoeBlock_fast_forward to explicitly cast routing_weights and final_X to torch.float32 for improved precision, and residual additions in Qwen3MoeDecoderLayer_fast_forward were simplified from hidden_states = residual + hidden_states to hidden_states += hidden_states.
  • TRL Trainer Patching Improvements: Enhanced debugging messages have been added to unsloth/models/rl.py for trl trainer patching failures. The dtype retrieval for input embeddings was refined, patch_module's overwrite parameter was set to False, and base_trainer is now explicitly excluded from patching.
  • Console Output Formatting: The statistics printout in unsloth/models/vision.py now includes improved formatting for model architecture names, ensuring consistent capitalization (e.g., _VL_, _MoE).
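
The rename itself is straightforward to patch at import time. A minimal sketch of what a shim like fix_vllm_guided_decoding_params could do (the actual Unsloth implementation isn't shown in this thread, so treat this as illustrative):

```python
# Sketch only: re-expose the old vLLM name so trl's import keeps working.
def fix_vllm_guided_decoding_params():
    try:
        import vllm.sampling_params as sampling_params
    except ImportError:
        return  # vLLM not installed; nothing to patch
    if not hasattr(sampling_params, "GuidedDecodingParams") and hasattr(
        sampling_params, "StructuredOutputsParams"
    ):
        # Newer vLLM renamed GuidedDecodingParams to StructuredOutputsParams.
        sampling_params.GuidedDecodingParams = sampling_params.StructuredOutputsParams
```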
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

| Feature | Command | Description |
| --- | --- | --- |
| Code Review | `/gemini review` | Performs a code review for the current pull request in its current state. |
| Pull Request Summary | `/gemini summary` | Provides a summary of the current pull request in its current state. |
| Comment | `@gemini-code-assist` | Responds in comments when explicitly tagged, both in pull request comments and review comments. |
| Help | `/gemini help` | Displays a list of available commands. |

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@gemini-code-assist
Contributor

Summary of Changes

Hello @danielhanchen, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request delivers a set of maintenance and compatibility enhancements across the codebase. Key updates include bumping a core dependency version, implementing a crucial fix for vLLM integration with trl to ensure proper guided decoding, and refining the RL trainer patching mechanism with better error reporting and more precise type handling. Additionally, minor textual corrections and internal version updates contribute to overall code quality and stability.

Highlights

  • Dependency Updates: The unsloth_zoo dependency has been updated to >=2025.11.6 in pyproject.toml for both huggingface and colab-new configurations.
  • vLLM Compatibility Fix: A new fix, fix_vllm_guided_decoding_params, was introduced to address a renaming issue in vLLM where GuidedDecodingParams became StructuredOutputsParams, ensuring compatibility with trl.
  • Version Bump: The internal version of the library has been incremented from 2025.11.4 to 2025.11.5 in unsloth/models/_utils.py.
  • Typo Corrections: A consistent typo, 'Inferene', was corrected to 'Inference' in comments across several model files (falcon_h1, gemma, gemma2, llama, mistral, qwen2, qwen3, qwen3_moe).
  • Qwen3 MoE Enhancements: Explicit torch.float32 dtype casting was added for torch_nn_functional_softmax and final_X initialization, and in-place addition hidden_states += hidden_states was adopted for efficiency in unsloth/models/qwen3_moe.py.
  • RL Trainer Patching Improvements: Enhanced debugging output was added for trl.trainer imports and class lookups, dtype inference was refined to use model.get_input_embeddings().weight.dtype, overwrite was set to False for patch_code, and base_trainer was explicitly excluded from patching in unsloth/models/rl.py.
  • Statistics Output Formatting: The display of arch_name in the statistics output was improved by correctly capitalizing 'VL' and 'MoE' in unsloth/models/vision.py.


@chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.


```diff
     _flag_for_generation = self._flag_for_generation,
 )
-hidden_states = residual + hidden_states
+hidden_states += hidden_states
```


P1: Preserve residual connection in Qwen3Moe attention path

When use_cache is enabled, Qwen3MoeDecoderLayer_fast_forward now does hidden_states += hidden_states instead of adding the saved residual. This doubles the self-attention output and removes the skip connection, so cached generation will produce incorrect activations compared to training/eager mode. It should add the stored residual back in this spot.


```diff
     self.mlp, hidden_states
 )
-hidden_states = residual + hidden_states
+hidden_states += hidden_states
```


P1: Restore MoE residual addition in generation branch

In the same use_cache branch, the MoE block now applies hidden_states += hidden_states instead of combining with the pre-MLP residual. This removes the skip connection and doubles the MoE activations during cached generation, yielding different outputs from the non-cached path. It should add the stored residual instead of self-adding.
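
For reference, a minimal sketch of the intended pattern in the use_cache branch (names are illustrative; the surrounding fast-forward code is assumed):

```python
# Sketch only: each sub-block should add back the residual saved before it ran.
residual = hidden_states                   # save the skip connection
hidden_states = norm(hidden_states)        # input / post-attention layernorm
hidden_states = sub_block(hidden_states)   # self-attention or the MoE MLP
hidden_states += residual                  # add the saved residual, not hidden_states itself
```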


Contributor

@gemini-code-assist bot left a comment


Code Review

This pull request includes several updates, likely part of a nightly build. It bumps the unsloth_zoo dependency version, fixes some typos in comments, and adds a patch for vLLM compatibility. However, I've found two critical bugs in unsloth/models/qwen3_moe.py where residual connections are incorrectly implemented, causing the hidden states to be doubled instead of having the residual added. These need to be fixed to ensure correct model behavior.

```diff
     _flag_for_generation = self._flag_for_generation,
 )
-hidden_states = residual + hidden_states
+hidden_states += hidden_states
```

critical

This line seems to have a typo. It should be hidden_states += residual to add the residual connection back. Currently, it's hidden_states += hidden_states, which doubles the hidden_states tensor. This is likely not the intended behavior for a residual connection.

Suggested change
```diff
-hidden_states += hidden_states
+hidden_states += residual
```

```diff
     self.mlp, hidden_states
 )
-hidden_states = residual + hidden_states
+hidden_states += hidden_states
```

critical

Similar to the previous comment, this line appears to be incorrect for a residual connection. It should probably be hidden_states += residual instead of hidden_states += hidden_states to correctly add the residual.

Suggested change
```diff
-hidden_states += hidden_states
+hidden_states += residual
```

@gemini-code-assist
Contributor

Summary of Changes

Hello @danielhanchen, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request delivers a set of maintenance and compatibility updates across the Unsloth library. It ensures that dependencies are current, addresses a specific compatibility challenge with vLLM and trl by introducing a targeted patch, and includes minor code refinements such as typo corrections and improved error handling in the RL training module. The overall version of the library is also updated.

Highlights

  • Dependency Update: The unsloth_zoo dependency has been updated to >=2025.11.6 in the huggingface and colab-new extra requirements within pyproject.toml.
  • vLLM Compatibility Patch: A new fix, fix_vllm_guided_decoding_params, was introduced to address a renaming issue in vLLM where GuidedDecodingParams became StructuredOutputsParams, ensuring compatibility with trl.
  • Version Bump: The internal version of Unsloth has been incremented from 2025.11.4 to 2025.11.5.
  • Typo Corrections: A consistent typo 'Inferene' was corrected to 'Inference' in comments across multiple model files (falcon_h1, gemma, gemma2, llama, mistral, qwen2, qwen3) related to KV Cache and CUDAGraphing.
  • Qwen3 MoE Model Enhancements: Changes were made to Qwen3MoeSparseMoeBlock_fast_forward to explicitly use torch.float32 for softmax and tensor initialization, and hidden_states = residual + hidden_states was optimized to hidden_states += hidden_states in Qwen3MoeDecoderLayer_fast_forward.
  • Improved RL Trainer Patching: Error logging was enhanced for trl.trainer patching, dtype retrieval was refined to use model.get_input_embeddings().weight.dtype, the overwrite parameter was set to False, and base_trainer is now explicitly excluded from patching.
  • Statistics Output Formatting: The display of arch_name in the statistics output for vision models was improved by standardizing casing for _Vl_ and _Moe.

Contributor

@gemini-code-assist bot left a comment


Code Review

The pull request primarily focuses on version updates, minor bug fixes, and logging improvements across various model implementations. A new vLLM compatibility patch has been added to import_fixes.py to handle GuidedDecodingParams renaming. Several model files (falcon_h1.py, gemma.py, gemma2.py, llama.py, mistral.py, qwen2.py, qwen3.py, qwen3_moe.py) include fixes for a typo in comments. The pyproject.toml file has been updated to reflect a new unsloth_zoo version. Logging in unsloth/models/rl.py has been enhanced to provide more detailed messages during trainer patching. A critical bug was identified in unsloth/models/qwen3_moe.py where residual connections are incorrectly applied.

```diff
     _flag_for_generation = self._flag_for_generation,
 )
-hidden_states = residual + hidden_states
+hidden_states += hidden_states
```

critical

This line appears to be a logical error. hidden_states += hidden_states will double the hidden_states value, effectively computing 2 * hidden_states. Based on the previous line residual = hidden_states, it seems the intention was to add the residual connection, i.e., hidden_states = residual + hidden_states or hidden_states += residual.

Suggested change
```diff
-hidden_states += hidden_states
+hidden_states += residual
```

```diff
     self.mlp, hidden_states
 )
-hidden_states = residual + hidden_states
+hidden_states += hidden_states
```

critical

This line appears to be a logical error. hidden_states += hidden_states will double the hidden_states value, effectively computing 2 * hidden_states. Based on the previous line residual = hidden_states, it seems the intention was to add the residual connection, i.e., hidden_states = residual + hidden_states or hidden_states += residual.

Suggested change
```diff
-hidden_states += hidden_states
+hidden_states += residual
```

Contributor

@gemini-code-assist bot left a comment


Code Review

This pull request includes nightly updates, primarily version bumps and a new fix for vLLM guided decoding. While most changes are improvements, I've identified two critical bugs in unsloth/models/qwen3_moe.py related to incorrect residual connections. These have been commented on with suggested fixes. Other changes include better debugging logs and minor code refinements.

```diff
     _flag_for_generation = self._flag_for_generation,
 )
-hidden_states = residual + hidden_states
+hidden_states += hidden_states
```

critical

This appears to be a bug in the residual connection. The code was changed from hidden_states = residual + hidden_states to hidden_states += hidden_states, which is equivalent to hidden_states = 2 * hidden_states. This doubles the attention output instead of adding the residual. It should be hidden_states += residual.

Suggested change
```diff
-hidden_states += hidden_states
+hidden_states += residual
```

danielhanchen and others added 9 commits December 1, 2025 07:19
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
@danielhanchen merged commit 1679dde into main on Dec 1, 2025
2 of 3 checks passed