
Conversation

@Datta0 (Collaborator) commented Dec 1, 2025

vLLM deprecated and has now removed GuidedDecodingParams. This PR patches that by aliasing the old name to StructuredOutputParams.

Please also check unslothai/unsloth-zoo#367.
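
For context, the fix is essentially an import-time alias so that trl's old import keeps working. A minimal sketch of the idea (the replacement class name has varied across vLLM versions, so it is probed defensively here rather than assumed):

```python
import vllm.sampling_params as sampling_params

# Older vLLM builds still ship GuidedDecodingParams; nothing to do in that case.
if not hasattr(sampling_params, "GuidedDecodingParams"):
    # Newer vLLM renamed the class as part of the structured-outputs work.
    # Probe for either spelling instead of hard-coding one.
    replacement = (
        getattr(sampling_params, "StructuredOutputsParams", None)
        or getattr(sampling_params, "StructuredOutputParams", None)
    )
    if replacement is not None:
        # Re-expose the old name so downstream
        # `from vllm.sampling_params import GuidedDecodingParams` keeps working.
        sampling_params.GuidedDecodingParams = replacement
```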

@Datta0 Datta0 changed the base branch from main to nightly December 1, 2025 10:48
@gemini-code-assist (Contributor)

Summary of Changes

Hello @Datta0, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request focuses on improving compatibility with external libraries, primarily vLLM, by adapting to recent API changes. It also includes several minor code quality enhancements, such as better error reporting during patching processes, specific numerical stability adjustments for the Qwen3 MoE model, and small textual corrections across the codebase. The overall goal is to ensure robust integration and maintain code clarity.

Highlights

  • vLLM Compatibility Fix: Updated the vLLM integration to address the deprecation and removal of GuidedDecodingParams, now dynamically aliasing it to StructuredOutputParams to maintain compatibility with trl.
  • Code Refactoring and Error Reporting: Refactored vLLM related patching functions into a single patch_vllm_imports and enhanced error reporting in trl patching by adding print statements for failed imports or class detections.
  • Qwen3 MoE Model Adjustments: Applied specific dtype adjustments for torch_nn_functional_softmax and final_X tensor initialization, and optimized an in-place addition operation in the Qwen3MoeDecoderLayer_fast_forward (see the sketch after this list).
  • Minor Typo Corrections: Corrected a recurring typo from 'Inferene' to 'Inference' in comments across several model files (falcon_h1, gemma, gemma2, llama, mistral, qwen2, qwen3).
  • Improved Statistics Formatting: Enhanced the display of model architecture names in the statistics string within vision.py for better readability, specifically for 'VL' and 'MoE' components.
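
As a rough illustration of the numerical-stability point in the Qwen3 MoE highlight above: MoE routers typically compute their softmax in float32 even when the model runs in half precision, then cast back. A minimal sketch of that pattern (the function and variable names are illustrative, not Unsloth's actual code):

```python
import torch
import torch.nn.functional as F

def route_tokens(router_logits: torch.Tensor, top_k: int = 2):
    # Softmax in float32 for numerical stability, even if the logits
    # arrive in float16/bfloat16.
    routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float32)
    routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
    # Renormalise the selected weights, then cast back to the model dtype
    # before they scale the expert outputs.
    routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
    return routing_weights.to(router_logits.dtype), selected_experts

# Usage: logits for 6 tokens over 8 experts, in bfloat16.
weights, experts = route_tokens(torch.randn(6, 8, dtype=torch.bfloat16))
```
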
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

| Feature | Command | Description |
| --- | --- | --- |
| Code Review | /gemini review | Performs a code review for the current pull request in its current state. |
| Pull Request Summary | /gemini summary | Provides a summary of the current pull request in its current state. |
| Comment | @gemini-code-assist | Responds in comments when explicitly tagged, both in pull request comments and review comments. |
| Help | /gemini help | Displays a list of available commands. |

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double-check it and use code with caution.


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request correctly addresses a breaking change in vllm by patching the deprecated GuidedDecodingParams. The refactoring of vLLM-related fixes into a single function is a good structural improvement. The additional logging in rl.py will be helpful for debugging. I've provided a few suggestions to further improve the code, including correcting a link in a comment and improving exception handling for the new log messages.
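
As a rough sketch of that exception-handling suggestion (the imported class, the module path, and the UNSLOTH_ENABLE_LOGGING environment lookup below are illustrative assumptions, not the actual patch): catching the specific failure modes keeps the new log messages from silently swallowing unrelated errors.

```python
import os

UNSLOTH_ENABLE_LOGGING = os.environ.get("UNSLOTH_ENABLE_LOGGING", "0") == "1"

try:
    # Example import target only; the real patch inspects several trl classes.
    from trl.trainer.grpo_trainer import GRPOTrainer
except (ImportError, AttributeError) as e:
    # Narrow exception types: an API change is logged, while genuinely
    # unexpected errors still propagate instead of being hidden.
    if UNSLOTH_ENABLE_LOGGING:
        print(f"Unsloth: Failed to import trl class, error = {e}")
```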

@chatgpt-codex-connector

💡 Codex Review

    _flag_for_generation = self._flag_for_generation,
)
hidden_states += hidden_states

P1: Restore residual add in cached self-attention

When use_cache with _flag_for_generation is set (the standard generation path), the decoder now updates hidden_states with hidden_states += hidden_states instead of adding back the saved residual. This doubles the self-attention output and completely drops the skip connection, so inference with cache will emit incorrect activations. Please add the residual back (as in the uncached branch) rather than adding the tensor to itself.


hidden_states, router_logits = Qwen3MoeSparseMoeBlock_fast_forward(
    self.mlp, hidden_states
)
hidden_states += hidden_states

P1: Fix residual after MoE block in cached path

In the cached generation branch, the MoE block output is accumulated with hidden_states += hidden_states, again ignoring the stored residual. This removes the intended skip connection and scales the MoE output by two for any cached generation call, leading to wrong logits. It should mirror the uncached path and add the residual instead.
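
A minimal standalone sketch of the intended pattern from both findings above (the helper and variable names are illustrative, not the repository's decoder code):

```python
import torch

def with_residual(sub_block, hidden_states: torch.Tensor) -> torch.Tensor:
    # Save the input, apply the sub-block (self-attention or the MoE MLP),
    # then add the saved residual back. The buggy cached path effectively did
    # `out + out`, doubling the sub-block output and dropping the skip connection.
    residual = hidden_states
    out = sub_block(hidden_states)
    return residual + out

# Usage with a stand-in sub-block:
x = torch.randn(1, 4, 8)
y = with_residual(torch.nn.Linear(8, 8), x)
```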


try:
    from vllm.sampling_params import GuidedDecodingParams
except ImportError:
    vllm.sampling_params.GuidedDecodingParams = (
        vllm.sampling_params.StructuredOutputsParams

P1: Alias GuidedDecodingParams to correct vLLM class

The fallback for removed GuidedDecodingParams assigns it to vllm.sampling_params.StructuredOutputsParams, but the vLLM rename in PR #22772 introduced StructuredOutputParams (singular). On vLLM versions where GuidedDecodingParams is gone, this attribute access will raise an AttributeError and importing Unsloth will fail. Point the alias to the actual renamed class so the patch works on new vLLM releases.

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

@Datta0 Datta0 changed the title [FIX] Vllm guided decoding params [WIP][FIX] Vllm guided decoding params Dec 1, 2025
@Datta0 Datta0 changed the title [WIP][FIX] Vllm guided decoding params [FIX] Vllm guided decoding params Dec 1, 2025
except Exception as e:
if UNSLOTH_ENABLE_LOGGING:
print(f"Unsloth: Failed patching vLLM with error = {str(e)}")
text = text.replace(

Indentation issue

@danielhanchen danielhanchen changed the base branch from nightly to main December 1, 2025 13:30
@danielhanchen danielhanchen changed the base branch from main to nightly December 1, 2025 13:32
@danielhanchen danielhanchen merged commit 6c47dc5 into unslothai:nightly Dec 1, 2025
1 check was pending
danielhanchen added a commit that referenced this pull request Dec 1, 2025
@Datta0 Datta0 mentioned this pull request Dec 1, 2025