vllm.model_executor.models.qwen3_omni_moe_thinker

Inference-only Qwen3-Omni-Moe model (thinker part).

Qwen3OmniMoeThinkerDummyInputsBuilder module-attribute

Qwen3OmniMoeThinkerDummyInputsBuilder = (
    Qwen2_5OmniThinkerDummyInputsBuilder
)

logger module-attribute

logger = init_logger(__name__)

Qwen3MoeLLMForCausalLM

Bases: Qwen3MoeForCausalLM

Source code in vllm/model_executor/models/qwen3_omni_moe_thinker.py
class Qwen3MoeLLMForCausalLM(Qwen3MoeForCausalLM):
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super(Qwen3MoeForCausalLM, self).__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config
        self.model = Qwen3MoeLLMModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )
        self.lm_head = ParallelLMHead(
            config.vocab_size, config.hidden_size, quant_config=quant_config
        )
        if self.config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

config instance-attribute

config = config

lm_head instance-attribute

lm_head = ParallelLMHead(
    vocab_size, hidden_size, quant_config=quant_config
)

logits_processor instance-attribute

logits_processor = LogitsProcessor(vocab_size)

make_empty_intermediate_tensors instance-attribute

make_empty_intermediate_tensors = (
    make_empty_intermediate_tensors
)

model instance-attribute

model = Qwen3MoeLLMModel(
    vllm_config=vllm_config,
    prefix=maybe_prefix(prefix, "model"),
)

quant_config instance-attribute

quant_config = quant_config

__init__

__init__(*, vllm_config: VllmConfig, prefix: str = '')
Source code in vllm/model_executor/models/qwen3_omni_moe_thinker.py
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
    super(Qwen3MoeForCausalLM, self).__init__()
    config = vllm_config.model_config.hf_config
    quant_config = vllm_config.quant_config
    self.config = config
    self.quant_config = quant_config
    self.model = Qwen3MoeLLMModel(
        vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
    )
    self.lm_head = ParallelLMHead(
        config.vocab_size, config.hidden_size, quant_config=quant_config
    )
    if self.config.tie_word_embeddings:
        self.lm_head.weight = self.model.embed_tokens.weight
    self.logits_processor = LogitsProcessor(config.vocab_size)
    self.make_empty_intermediate_tensors = (
        self.model.make_empty_intermediate_tensors
    )

Qwen3MoeLLMModel

Bases: Qwen3MoeModel

Source code in vllm/model_executor/models/qwen3_omni_moe_thinker.py
@support_torch_compile(
    dynamic_arg_dims={
        "input_ids": 0,
        "positions": -1,
        "intermediate_tensors": 0,
        "inputs_embeds": 0,
        "deepstack_input_embeds": 0,
    }
)
class Qwen3MoeLLMModel(Qwen3MoeModel):
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config, prefix=prefix)

        self.deepstack_multiscale_layer_start = 1

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        deepstack_input_embeds: Optional[IntermediateTensors] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        for layer_idx, layer in enumerate(
            self.layers[self.start_layer : self.end_layer]
        ):
            layer_idx = layer_idx + self.start_layer

            hidden_states, residual = layer(
                positions,
                hidden_states,
                residual,
            )

            if deepstack_input_embeds is not None and layer_idx in range(
                0, len(deepstack_input_embeds)
            ):
                hidden_states = (
                    hidden_states
                    + deepstack_input_embeds[f"deepstack_input_embeds_{layer_idx}"]
                )

        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

deepstack_multiscale_layer_start instance-attribute

deepstack_multiscale_layer_start = 1

__init__

__init__(*, vllm_config: VllmConfig, prefix: str = '')
Source code in vllm/model_executor/models/qwen3_omni_moe_thinker.py
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
    super().__init__(vllm_config=vllm_config, prefix=prefix)

    self.deepstack_multiscale_layer_start = 1

forward

forward(
    input_ids: Tensor,
    positions: Tensor,
    intermediate_tensors: Optional[
        IntermediateTensors
    ] = None,
    inputs_embeds: Optional[Tensor] = None,
    deepstack_input_embeds: Optional[
        IntermediateTensors
    ] = None,
) -> Union[Tensor, IntermediateTensors]
Source code in vllm/model_executor/models/qwen3_omni_moe_thinker.py
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    deepstack_input_embeds: Optional[IntermediateTensors] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
    if get_pp_group().is_first_rank:
        if inputs_embeds is not None:
            hidden_states = inputs_embeds
        else:
            hidden_states = self.get_input_embeddings(input_ids)
        residual = None
    else:
        assert intermediate_tensors is not None
        hidden_states = intermediate_tensors["hidden_states"]
        residual = intermediate_tensors["residual"]
    for layer_idx, layer in enumerate(
        self.layers[self.start_layer : self.end_layer]
    ):
        layer_idx = layer_idx + self.start_layer

        hidden_states, residual = layer(
            positions,
            hidden_states,
            residual,
        )

        if deepstack_input_embeds is not None and layer_idx in range(
            0, len(deepstack_input_embeds)
        ):
            hidden_states = (
                hidden_states
                + deepstack_input_embeds[f"deepstack_input_embeds_{layer_idx}"]
            )

    if not get_pp_group().is_last_rank:
        return IntermediateTensors(
            {"hidden_states": hidden_states, "residual": residual}
        )
    hidden_states, _ = self.norm(hidden_states, residual)
    return hidden_states
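
The deepstack mechanism above adds one extra embedding tensor per early decoder layer. A minimal sketch of how the deepstack_input_embeds_{idx} keys are consumed (made-up shapes and layer count; the real layers come from the pipeline-parallel slice):

import torch

num_tokens, hidden_size, num_levels, num_layers = 8, 16, 3, 5
hidden_states = torch.zeros(num_tokens, hidden_size)
deepstack_input_embeds = {
    f"deepstack_input_embeds_{idx}": torch.randn(num_tokens, hidden_size)
    for idx in range(num_levels)
}

for layer_idx in range(num_layers):
    # ... hidden_states, residual = layer(positions, hidden_states, residual)
    # Only the first `num_levels` layers receive a multi-scale visual addend.
    if layer_idx < num_levels:
        hidden_states = (
            hidden_states
            + deepstack_input_embeds[f"deepstack_input_embeds_{layer_idx}"]
        )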

Qwen3OmniMoeConditionalGenerationMixin

Bases: Qwen2_5OmniConditionalGenerationMixin

Source code in vllm/model_executor/models/qwen3_omni_moe_thinker.py
class Qwen3OmniMoeConditionalGenerationMixin(Qwen2_5OmniConditionalGenerationMixin):
    def _validate_and_reshape_mm_tensor(
        self, mm_input: object, name: str, dim: int = 0
    ) -> torch.Tensor:
        if not isinstance(mm_input, (torch.Tensor, list)):
            raise ValueError(f"Incorrect type of {name}. Got type: {type(mm_input)}")
        if name == "feature_attention_mask":
            dim = -1
        if isinstance(mm_input, torch.Tensor):
            return torch.concat(list(mm_input), dim=dim)
        else:
            if isinstance(mm_input[0], list):
                return torch.concat(
                    [torch.concat(mm_input[i], dim=dim) for i in range(len(mm_input))],
                    dim=dim,
                )
            else:
                return torch.concat(mm_input, dim=dim)

    def _get_feat_extract_output_lengths(
        self, input_lengths: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        input_lengths_leave = input_lengths % 100
        feat_lengths = (input_lengths_leave - 1) // 2 + 1
        output_lengths = (
            ((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
        )
        return output_lengths, output_lengths

    def _process_audio_input(
        self,
        audio_input: Qwen2AudioFeatureInputs,
        audio_hashes: list[str] = None,
        cached_audio_features: torch.Tensor = None,
    ) -> torch.Tensor:
        input_features = audio_input["input_features"]
        audio_feature_lengths = audio_input["audio_feature_lengths"]

        if input_features.ndim == 3:
            assert input_features.shape[0] == 1
            input_features = input_features.squeeze(0)

        if not isinstance(audio_feature_lengths, torch.Tensor):
            audio_feature_lengths = torch.cat(audio_feature_lengths)
        if audio_feature_lengths.ndim == 2:
            audio_feature_lengths = audio_feature_lengths.reshape(-1)

        audio_feat_lengths, audio_output_lengths = (
            self._get_feat_extract_output_lengths(audio_feature_lengths)
        )

        audio_outputs = self.audio_tower(
            input_features.to(self.audio_tower.dtype),
            feature_lens=audio_feature_lengths,
            aftercnn_lens=audio_feat_lengths,
        )
        audio_features = audio_outputs.last_hidden_state
        return audio_features.split(audio_output_lengths.tolist())

_get_feat_extract_output_lengths

_get_feat_extract_output_lengths(
    input_lengths: Tensor,
) -> tuple[Tensor, Tensor]
Source code in vllm/model_executor/models/qwen3_omni_moe_thinker.py
def _get_feat_extract_output_lengths(
    self, input_lengths: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    input_lengths_leave = input_lengths % 100
    feat_lengths = (input_lengths_leave - 1) // 2 + 1
    output_lengths = (
        ((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
    )
    return output_lengths, output_lengths
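
A worked example of the length formula above, assuming illustrative mel-frame counts of 80, 240 and 360:

import torch

input_lengths = torch.tensor([80, 240, 360])
leave = input_lengths % 100                 # tensor([80, 40, 60])
feat_lengths = (leave - 1) // 2 + 1         # tensor([40, 20, 30])
output_lengths = (
    ((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
)
print(output_lengths)                       # tensor([10, 31, 47])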

_process_audio_input

_process_audio_input(
    audio_input: Qwen2AudioFeatureInputs,
    audio_hashes: list[str] = None,
    cached_audio_features: Tensor = None,
) -> Tensor
Source code in vllm/model_executor/models/qwen3_omni_moe_thinker.py
def _process_audio_input(
    self,
    audio_input: Qwen2AudioFeatureInputs,
    audio_hashes: list[str] = None,
    cached_audio_features: torch.Tensor = None,
) -> torch.Tensor:
    input_features = audio_input["input_features"]
    audio_feature_lengths = audio_input["audio_feature_lengths"]

    if input_features.ndim == 3:
        assert input_features.shape[0] == 1
        input_features = input_features.squeeze(0)

    if not isinstance(audio_feature_lengths, torch.Tensor):
        audio_feature_lengths = torch.cat(audio_feature_lengths)
    if audio_feature_lengths.ndim == 2:
        audio_feature_lengths = audio_feature_lengths.reshape(-1)

    audio_feat_lengths, audio_output_lengths = (
        self._get_feat_extract_output_lengths(audio_feature_lengths)
    )

    audio_outputs = self.audio_tower(
        input_features.to(self.audio_tower.dtype),
        feature_lens=audio_feature_lengths,
        aftercnn_lens=audio_feat_lengths,
    )
    audio_features = audio_outputs.last_hidden_state
    return audio_features.split(audio_output_lengths.tolist())

_validate_and_reshape_mm_tensor

_validate_and_reshape_mm_tensor(
    mm_input: object, name: str, dim: int = 0
) -> Tensor
Source code in vllm/model_executor/models/qwen3_omni_moe_thinker.py
def _validate_and_reshape_mm_tensor(
    self, mm_input: object, name: str, dim: int = 0
) -> torch.Tensor:
    if not isinstance(mm_input, (torch.Tensor, list)):
        raise ValueError(f"Incorrect type of {name}. Got type: {type(mm_input)}")
    if name == "feature_attention_mask":
        dim = -1
    if isinstance(mm_input, torch.Tensor):
        return torch.concat(list(mm_input), dim=dim)
    else:
        if isinstance(mm_input[0], list):
            return torch.concat(
                [torch.concat(mm_input[i], dim=dim) for i in range(len(mm_input))],
                dim=dim,
            )
        else:
            return torch.concat(mm_input, dim=dim)
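
A small illustration of the two concatenation paths (hypothetical shapes): ordinary inputs are concatenated along dim=0, while feature_attention_mask is concatenated along the last dimension:

import torch

# Two image items with different numbers of patches.
pixel_values = [torch.randn(4, 1176), torch.randn(6, 1176)]
batched = torch.concat(pixel_values, dim=0)      # shape (10, 1176)

# feature_attention_mask is stacked along the last dimension instead.
masks = [torch.ones(1, 300), torch.ones(1, 200)]
mask = torch.concat(masks, dim=-1)               # shape (1, 500)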

Qwen3OmniMoeThinkerForConditionalGeneration

Bases: Module, SupportsMultiModal, SupportsPP, Qwen3OmniMoeConditionalGenerationMixin

Source code in vllm/model_executor/models/qwen3_omni_moe_thinker.py
@MULTIMODAL_REGISTRY.register_processor(
    Qwen3OmniMoeThinkerMultiModalProcessor,
    info=Qwen3OmniMoeThinkerProcessingInfo,
    dummy_inputs=Qwen3OmniMoeThinkerDummyInputsBuilder,
)
class Qwen3OmniMoeThinkerForConditionalGeneration(
    nn.Module,
    SupportsMultiModal,
    SupportsPP,
    Qwen3OmniMoeConditionalGenerationMixin,
):
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "thinker.lm_head.": "language_model.lm_head.",
            "thinker.model.": "language_model.model.",
            "thinker.": "",
        }
    )

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
        if modality.startswith("image"):
            return "<|vision_start|><|image_pad|><|vision_end|>"
        if modality.startswith("video"):
            return "<|vision_start|><|video_pad|><|vision_end|>"
        if modality.startswith("audio"):
            return "<|audio_start|><|audio_pad|><|audio_end|>"

        raise ValueError("Only image, video or audio modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        thinker_config: Qwen3OmniMoeThinkerConfig = (
            vllm_config.model_config.hf_config.thinker_config
        )
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = thinker_config
        self.multimodal_config = multimodal_config

        # force "use_flash_attention_2=True" to audio tower to align
        # the results.
        if flash_attn is not None:
            audio_config = thinker_config.audio_config
            audio_config._attn_implementation_autoset = True
            audio_config._attn_implementation = "flash_attention_2"
        else:
            logger.warning(
                "flash_attn is not available, the model may not yield the "
                "exactly same result as the transformers implementation "
                "in the audio tower part."
            )

        self.audio_tower = Qwen3OmniMoeAudioEncoder(thinker_config.audio_config)

        self.visual = Qwen3Omni_VisionTransformer(
            vision_config=thinker_config.vision_config,
            norm_eps=getattr(thinker_config.text_config, "rms_norm_eps", 1e-6),
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "visual"),
        )
        self.quant_config = quant_config

        self.language_model = Qwen3MoeLLMForCausalLM(
            vllm_config=vllm_config.with_hf_config(
                thinker_config.text_config, architectures=["Qwen3MoeForCausalLM"]
            ),
            prefix=maybe_prefix(prefix, "language_model"),
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

        self.use_deepstack = hasattr(
            thinker_config.vision_config, "deepstack_visual_indexes"
        )
        self.deepstack_num_level = (
            len(thinker_config.vision_config.deepstack_visual_indexes)
            if self.use_deepstack
            else 0
        )
        # register buffer for deepstack
        self.deepstack_input_embeds = (
            [
                torch.zeros(
                    vllm_config.scheduler_config.max_num_batched_tokens,
                    thinker_config.text_config.hidden_size,
                )
                for _ in range(self.deepstack_num_level)
            ]
            if self.use_deepstack
            else None
        )
        self.visual_dim = thinker_config.vision_config.out_hidden_size
        self.multiscale_dim = self.visual_dim * self.deepstack_num_level

    def _get_deepstack_input_embeds(self, num_tokens: int) -> IntermediateTensors:
        # get deepstack_input_embeds from buffer, and clear the buffer
        return IntermediateTensors(
            {
                f"deepstack_input_embeds_{idx}": self.deepstack_input_embeds[idx][
                    :num_tokens
                ]
                for idx in range(self.deepstack_num_level)
            }
        )

    def _set_deepstack_input_embeds(self, deepstack_input_embeds: torch.Tensor) -> None:
        # set deepstack_input_embeds to buffer
        num_tokens = deepstack_input_embeds.size(1)
        if num_tokens > self.deepstack_input_embeds[0].size(0):
            self.deepstack_input_embeds = [
                torch.zeros(
                    num_tokens,
                    self.config.text_config.hidden_size,
                    device=self.deepstack_input_embeds[0].device,
                    dtype=self.deepstack_input_embeds[0].dtype,
                )
                for _ in range(self.deepstack_num_level)
            ]
        for idx in range(self.deepstack_num_level):
            self.deepstack_input_embeds[idx][:num_tokens].copy_(
                deepstack_input_embeds[idx]
            )

    def _clear_deepstack_input_embeds(self, num_tokens: int) -> None:
        # clear deepstack_input_embeds in buffer
        if num_tokens > 0:
            for idx in range(self.deepstack_num_level):
                self.deepstack_input_embeds[idx][:num_tokens].zero_()

    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
        mm_input_by_modality = {}

        # Preserve the order of the modalities (if there are multiple of
        # them) based on the order of kwargs.
        for input_key in kwargs:
            if (
                input_key in ("pixel_values", "image_embeds")
                and "image" not in mm_input_by_modality
            ):
                mm_input_by_modality["image"] = self._parse_and_validate_image_input(
                    **kwargs
                )
            if (
                input_key in ("pixel_values_videos", "video_embeds")
                and "video" not in mm_input_by_modality
            ):
                mm_input_by_modality["video"] = self._parse_and_validate_video_input(
                    **kwargs
                )
            if (
                input_key in ("input_audio_features")
                and "audio" not in mm_input_by_modality
            ):
                mm_input_by_modality["audio"] = self._parse_and_validate_audio_input(
                    **kwargs
                )
        return mm_input_by_modality

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def get_multimodal_embeddings(
        self, **kwargs: object
    ) -> Optional[MultiModalEmbeddings]:
        mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs)
        if not mm_input_by_modality:
            return []

        # The resulting multimodal_embeddings is a tuple of tensors, with each
        # tensor corresponding to a multimodal data item (image, video or audio).
        multimodal_embeddings: tuple[torch.Tensor, ...] = ()

        # NOTE: It is important to iterate over the keys in this dictionary
        # to preserve the order of the modalities.
        for modality in mm_input_by_modality:
            multimodal_input = mm_input_by_modality[modality]
            if modality == "image":
                vision_embeddings = self._process_image_input(multimodal_input)
                multimodal_embeddings += vision_embeddings
            if modality == "video":
                video_embeddings = self._process_video_input(multimodal_input)
                multimodal_embeddings += video_embeddings
            if modality == "audio":
                audio_embeddings = self._process_audio_input(multimodal_input)
                multimodal_embeddings += audio_embeddings
        return multimodal_embeddings

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
        *,
        is_multimodal: Optional[torch.Tensor] = None,
        handle_oov_mm_token: bool = False,
    ) -> torch.Tensor:
        inputs_embeds = self._get_text_embeddings(
            input_ids,
            self.language_model.get_input_embeddings,
            is_multimodal=is_multimodal,
            handle_oov_mm_token=handle_oov_mm_token,
        )

        if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
            return inputs_embeds

        if is_multimodal is None:
            raise ValueError(
                "`get_input_embeddings` now requires `is_multimodal` arg, "
                "please update your model runner according to "
                "https://github.com/vllm-project/vllm/pull/16229."
            )

        deepstack_input_embeds = None
        # TODO (ywang96): support overlapping modality embeddings so that
        # `use_audio_in_video` will work on V1.
        # split the feat dim to obtain multi-scale visual feature
        if self.visual.deepstack_visual_indexes is not None:
            multiscale_len = len(self.visual.deepstack_visual_indexes)
            multimodal_embeddings_multiscale = []
            for index, embeddings in enumerate(multimodal_embeddings):
                if embeddings.shape[-1] != self.config.text_config.hidden_size:
                    visual_dim = embeddings.shape[-1] // (multiscale_len + 1)
                    main_dim = visual_dim
                    multi_dim = visual_dim * multiscale_len
                    embeddings_main, embeddings_multiscale = torch.split(
                        embeddings, [main_dim, multi_dim], dim=-1
                    )
                    multimodal_embeddings[index] = embeddings_main
                    multimodal_embeddings_multiscale.append(embeddings_multiscale)

            # NOTE: This branch should only be triggered for image/video,
            # but not audio-only inputs
            if len(multimodal_embeddings_multiscale) > 0:
                deepstack_input_embeds = inputs_embeds.new_zeros(
                    inputs_embeds.size(0), multiscale_len * inputs_embeds.size(1)
                )
                deepstack_input_embeds = _merge_multimodal_embeddings(
                    inputs_embeds=deepstack_input_embeds,
                    multimodal_embeddings=multimodal_embeddings_multiscale,
                    is_multimodal=is_multimodal,
                )
                deepstack_input_embeds = (
                    deepstack_input_embeds.view(
                        inputs_embeds.shape[0], multiscale_len, visual_dim
                    )
                    .permute(1, 0, 2)
                    .contiguous()
                )
                self._set_deepstack_input_embeds(deepstack_input_embeds)

        inputs_embeds = _merge_multimodal_embeddings(
            inputs_embeds=inputs_embeds,
            multimodal_embeddings=multimodal_embeddings,
            is_multimodal=is_multimodal,
        )

        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if intermediate_tensors is not None:
            inputs_embeds = None

        if (
            self.use_deepstack
            and inputs_embeds is not None
            and get_pp_group().is_first_rank
        ):
            deepstack_input_embeds = self._get_deepstack_input_embeds(
                inputs_embeds.size(0)
            )
        else:
            deepstack_input_embeds = None

        hidden_states = self.language_model.model(
            input_ids,
            positions,
            intermediate_tensors,
            inputs_embeds=inputs_embeds,
            # args for deepstack
            deepstack_input_embeds=deepstack_input_embeds,
        )

        if inputs_embeds is not None and get_pp_group().is_first_rank:
            self._clear_deepstack_input_embeds(inputs_embeds.size(0))

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=["talker.", "code2wav."],
        )
        loaded_weights = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

        return loaded_weights

audio_tower instance-attribute

audio_tower = Qwen3OmniMoeAudioEncoder(audio_config)

config instance-attribute

config = thinker_config

deepstack_input_embeds instance-attribute

deepstack_input_embeds = (
    [
        (zeros(max_num_batched_tokens, hidden_size))
        for _ in (range(deepstack_num_level))
    ]
    if use_deepstack
    else None
)

deepstack_num_level instance-attribute

deepstack_num_level = (
    len(deepstack_visual_indexes) if use_deepstack else 0
)

hf_to_vllm_mapper class-attribute instance-attribute

hf_to_vllm_mapper = WeightsMapper(
    orig_to_new_prefix={
        "thinker.lm_head.": "language_model.lm_head.",
        "thinker.model.": "language_model.model.",
        "thinker.": "",
    }
)
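
The mapping above rewrites HuggingFace checkpoint prefixes into vLLM module paths. An illustration of its effect on a few weight names (plain-string sketch with illustrative names; the actual remapping is applied by WeightsMapper inside load_weights):

orig_to_new_prefix = {
    "thinker.lm_head.": "language_model.lm_head.",
    "thinker.model.": "language_model.model.",
    "thinker.": "",
}

def remap(name: str) -> str:
    for old, new in orig_to_new_prefix.items():
        if name.startswith(old):
            return new + name[len(old):]
    return name

print(remap("thinker.model.layers.0.mlp.experts.0.gate_proj.weight"))
# language_model.model.layers.0.mlp.experts.0.gate_proj.weight
print(remap("thinker.visual.blocks.0.attn.qkv.weight"))
# visual.blocks.0.attn.qkv.weight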

language_model instance-attribute

language_model = Qwen3MoeLLMForCausalLM(
    vllm_config=with_hf_config(
        text_config, architectures=["Qwen3MoeForCausalLM"]
    ),
    prefix=maybe_prefix(prefix, "language_model"),
)

make_empty_intermediate_tensors instance-attribute

make_empty_intermediate_tensors = (
    make_empty_intermediate_tensors
)

multimodal_config instance-attribute

multimodal_config = multimodal_config

multiscale_dim instance-attribute

multiscale_dim = visual_dim * deepstack_num_level

quant_config instance-attribute

quant_config = quant_config

use_deepstack instance-attribute

use_deepstack = hasattr(
    vision_config, "deepstack_visual_indexes"
)

visual instance-attribute

visual = Qwen3Omni_VisionTransformer(
    vision_config=vision_config,
    norm_eps=getattr(text_config, "rms_norm_eps", 1e-06),
    quant_config=quant_config,
    prefix=maybe_prefix(prefix, "visual"),
)

visual_dim instance-attribute

visual_dim = out_hidden_size

__init__

__init__(*, vllm_config: VllmConfig, prefix: str = '')
Source code in vllm/model_executor/models/qwen3_omni_moe_thinker.py
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
    super().__init__()
    thinker_config: Qwen3OmniMoeThinkerConfig = (
        vllm_config.model_config.hf_config.thinker_config
    )
    quant_config = vllm_config.quant_config
    multimodal_config = vllm_config.model_config.multimodal_config
    self.config = thinker_config
    self.multimodal_config = multimodal_config

    # force "use_flash_attention_2=True" to audio tower to align
    # the results.
    if flash_attn is not None:
        audio_config = thinker_config.audio_config
        audio_config._attn_implementation_autoset = True
        audio_config._attn_implementation = "flash_attention_2"
    else:
        logger.warning(
            "flash_attn is not available, the model may not yield the "
            "exactly same result as the transformers implementation "
            "in the audio tower part."
        )

    self.audio_tower = Qwen3OmniMoeAudioEncoder(thinker_config.audio_config)

    self.visual = Qwen3Omni_VisionTransformer(
        vision_config=thinker_config.vision_config,
        norm_eps=getattr(thinker_config.text_config, "rms_norm_eps", 1e-6),
        quant_config=quant_config,
        prefix=maybe_prefix(prefix, "visual"),
    )
    self.quant_config = quant_config

    self.language_model = Qwen3MoeLLMForCausalLM(
        vllm_config=vllm_config.with_hf_config(
            thinker_config.text_config, architectures=["Qwen3MoeForCausalLM"]
        ),
        prefix=maybe_prefix(prefix, "language_model"),
    )

    self.make_empty_intermediate_tensors = (
        self.language_model.make_empty_intermediate_tensors
    )

    self.use_deepstack = hasattr(
        thinker_config.vision_config, "deepstack_visual_indexes"
    )
    self.deepstack_num_level = (
        len(thinker_config.vision_config.deepstack_visual_indexes)
        if self.use_deepstack
        else 0
    )
    # register buffer for deepstack
    self.deepstack_input_embeds = (
        [
            torch.zeros(
                vllm_config.scheduler_config.max_num_batched_tokens,
                thinker_config.text_config.hidden_size,
            )
            for _ in range(self.deepstack_num_level)
        ]
        if self.use_deepstack
        else None
    )
    self.visual_dim = thinker_config.vision_config.out_hidden_size
    self.multiscale_dim = self.visual_dim * self.deepstack_num_level

_clear_deepstack_input_embeds

_clear_deepstack_input_embeds(num_tokens: int) -> None
Source code in vllm/model_executor/models/qwen3_omni_moe_thinker.py
def _clear_deepstack_input_embeds(self, num_tokens: int) -> None:
    # clear deepstack_input_embeds in buffer
    if num_tokens > 0:
        for idx in range(self.deepstack_num_level):
            self.deepstack_input_embeds[idx][:num_tokens].zero_()

_get_deepstack_input_embeds

_get_deepstack_input_embeds(
    num_tokens: int,
) -> IntermediateTensors
Source code in vllm/model_executor/models/qwen3_omni_moe_thinker.py
def _get_deepstack_input_embeds(self, num_tokens: int) -> IntermediateTensors:
    # get deepstack_input_embeds from buffer, and clear the buffer
    return IntermediateTensors(
        {
            f"deepstack_input_embeds_{idx}": self.deepstack_input_embeds[idx][
                :num_tokens
            ]
            for idx in range(self.deepstack_num_level)
        }
    )

_parse_and_validate_multimodal_inputs

_parse_and_validate_multimodal_inputs(
    **kwargs: object,
) -> dict
Source code in vllm/model_executor/models/qwen3_omni_moe_thinker.py
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
    mm_input_by_modality = {}

    # Preserve the order of the modalities (if there are multiple of
    # them) based on the order of kwargs.
    for input_key in kwargs:
        if (
            input_key in ("pixel_values", "image_embeds")
            and "image" not in mm_input_by_modality
        ):
            mm_input_by_modality["image"] = self._parse_and_validate_image_input(
                **kwargs
            )
        if (
            input_key in ("pixel_values_videos", "video_embeds")
            and "video" not in mm_input_by_modality
        ):
            mm_input_by_modality["video"] = self._parse_and_validate_video_input(
                **kwargs
            )
        if (
            input_key in ("input_audio_features")
            and "audio" not in mm_input_by_modality
        ):
            mm_input_by_modality["audio"] = self._parse_and_validate_audio_input(
                **kwargs
            )
    return mm_input_by_modality

_set_deepstack_input_embeds

_set_deepstack_input_embeds(
    deepstack_input_embeds: Tensor,
) -> None
Source code in vllm/model_executor/models/qwen3_omni_moe_thinker.py
def _set_deepstack_input_embeds(self, deepstack_input_embeds: torch.Tensor) -> None:
    # set deepstack_input_embeds to buffer
    num_tokens = deepstack_input_embeds.size(1)
    if num_tokens > self.deepstack_input_embeds[0].size(0):
        self.deepstack_input_embeds = [
            torch.zeros(
                num_tokens,
                self.config.text_config.hidden_size,
                device=self.deepstack_input_embeds[0].device,
                dtype=self.deepstack_input_embeds[0].dtype,
            )
            for _ in range(self.deepstack_num_level)
        ]
    for idx in range(self.deepstack_num_level):
        self.deepstack_input_embeds[idx][:num_tokens].copy_(
            deepstack_input_embeds[idx]
        )
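
A minimal sketch of the buffer round trip performed by _set_deepstack_input_embeds, _get_deepstack_input_embeds and _clear_deepstack_input_embeds (assumed sizes; the real buffer is sized to max_num_batched_tokens):

import torch

num_levels, num_tokens, hidden_size = 3, 5, 16
buffer = [torch.zeros(32, hidden_size) for _ in range(num_levels)]

# set: stage the (num_levels, num_tokens, hidden_size) multi-scale embeddings
multiscale = torch.randn(num_levels, num_tokens, hidden_size)
for idx in range(num_levels):
    buffer[idx][:num_tokens].copy_(multiscale[idx])

# get: expose per-level views under the deepstack_input_embeds_{idx} keys
views = {
    f"deepstack_input_embeds_{idx}": buffer[idx][:num_tokens]
    for idx in range(num_levels)
}

# clear: zero the used slice once the forward pass has consumed it
for idx in range(num_levels):
    buffer[idx][:num_tokens].zero_()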

compute_logits

compute_logits(hidden_states: Tensor) -> Optional[Tensor]
Source code in vllm/model_executor/models/qwen3_omni_moe_thinker.py
def compute_logits(
    self,
    hidden_states: torch.Tensor,
) -> Optional[torch.Tensor]:
    return self.language_model.compute_logits(hidden_states)

forward

forward(
    input_ids: Tensor,
    positions: Tensor,
    intermediate_tensors: Optional[
        IntermediateTensors
    ] = None,
    inputs_embeds: Optional[Tensor] = None,
    **kwargs: object,
) -> Union[Tensor, IntermediateTensors]
Source code in vllm/model_executor/models/qwen3_omni_moe_thinker.py
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    **kwargs: object,
) -> Union[torch.Tensor, IntermediateTensors]:
    if intermediate_tensors is not None:
        inputs_embeds = None

    if (
        self.use_deepstack
        and inputs_embeds is not None
        and get_pp_group().is_first_rank
    ):
        deepstack_input_embeds = self._get_deepstack_input_embeds(
            inputs_embeds.size(0)
        )
    else:
        deepstack_input_embeds = None

    hidden_states = self.language_model.model(
        input_ids,
        positions,
        intermediate_tensors,
        inputs_embeds=inputs_embeds,
        # args for deepstack
        deepstack_input_embeds=deepstack_input_embeds,
    )

    if inputs_embeds is not None and get_pp_group().is_first_rank:
        self._clear_deepstack_input_embeds(inputs_embeds.size(0))

    return hidden_states

get_input_embeddings

get_input_embeddings(
    input_ids: Tensor,
    multimodal_embeddings: Optional[
        MultiModalEmbeddings
    ] = None,
    *,
    is_multimodal: Optional[Tensor] = None,
    handle_oov_mm_token: bool = False,
) -> Tensor
Source code in vllm/model_executor/models/qwen3_omni_moe_thinker.py
def get_input_embeddings(
    self,
    input_ids: torch.Tensor,
    multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    *,
    is_multimodal: Optional[torch.Tensor] = None,
    handle_oov_mm_token: bool = False,
) -> torch.Tensor:
    inputs_embeds = self._get_text_embeddings(
        input_ids,
        self.language_model.get_input_embeddings,
        is_multimodal=is_multimodal,
        handle_oov_mm_token=handle_oov_mm_token,
    )

    if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
        return inputs_embeds

    if is_multimodal is None:
        raise ValueError(
            "`get_input_embeddings` now requires `is_multimodal` arg, "
            "please update your model runner according to "
            "https://github.com/vllm-project/vllm/pull/16229."
        )

    deepstack_input_embeds = None
    # TODO (ywang96): support overlapping modality embeddings so that
    # `use_audio_in_video` will work on V1.
    # split the feat dim to obtain multi-scale visual feature
    if self.visual.deepstack_visual_indexes is not None:
        multiscale_len = len(self.visual.deepstack_visual_indexes)
        multimodal_embeddings_multiscale = []
        for index, embeddings in enumerate(multimodal_embeddings):
            if embeddings.shape[-1] != self.config.text_config.hidden_size:
                visual_dim = embeddings.shape[-1] // (multiscale_len + 1)
                main_dim = visual_dim
                multi_dim = visual_dim * multiscale_len
                embeddings_main, embeddings_multiscale = torch.split(
                    embeddings, [main_dim, multi_dim], dim=-1
                )
                multimodal_embeddings[index] = embeddings_main
                multimodal_embeddings_multiscale.append(embeddings_multiscale)

        # NOTE: This branch should only be triggered for image/video,
        # but not audio-only inputs
        if len(multimodal_embeddings_multiscale) > 0:
            deepstack_input_embeds = inputs_embeds.new_zeros(
                inputs_embeds.size(0), multiscale_len * inputs_embeds.size(1)
            )
            deepstack_input_embeds = _merge_multimodal_embeddings(
                inputs_embeds=deepstack_input_embeds,
                multimodal_embeddings=multimodal_embeddings_multiscale,
                is_multimodal=is_multimodal,
            )
            deepstack_input_embeds = (
                deepstack_input_embeds.view(
                    inputs_embeds.shape[0], multiscale_len, visual_dim
                )
                .permute(1, 0, 2)
                .contiguous()
            )
            self._set_deepstack_input_embeds(deepstack_input_embeds)

    inputs_embeds = _merge_multimodal_embeddings(
        inputs_embeds=inputs_embeds,
        multimodal_embeddings=multimodal_embeddings,
        is_multimodal=is_multimodal,
    )

    return inputs_embeds
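
A sketch (made-up sizes) of the split performed above when deepstack is enabled: each visual embedding of width visual_dim * (multiscale_len + 1) is divided into the main embedding and the stacked multi-scale part:

import torch

visual_dim, multiscale_len, num_tokens = 8, 3, 4
fused = torch.randn(num_tokens, visual_dim * (multiscale_len + 1))

embeddings_main, embeddings_multiscale = torch.split(
    fused, [visual_dim, visual_dim * multiscale_len], dim=-1
)
# embeddings_main:       (4, 8)  -> merged into inputs_embeds at multimodal positions
# embeddings_multiscale: (4, 24) -> viewed as (num_tokens, multiscale_len, visual_dim)
#                                   and permuted to (multiscale_len, num_tokens, visual_dim)
stacked = embeddings_multiscale.view(num_tokens, multiscale_len, visual_dim)
stacked = stacked.permute(1, 0, 2).contiguous()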

get_language_model

get_language_model() -> Module
Source code in vllm/model_executor/models/qwen3_omni_moe_thinker.py
def get_language_model(self) -> torch.nn.Module:
    return self.language_model

get_multimodal_embeddings

get_multimodal_embeddings(
    **kwargs: object,
) -> Optional[MultiModalEmbeddings]
Source code in vllm/model_executor/models/qwen3_omni_moe_thinker.py
def get_multimodal_embeddings(
    self, **kwargs: object
) -> Optional[MultiModalEmbeddings]:
    mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs)
    if not mm_input_by_modality:
        return []

    # The resulting multimodal_embeddings is a tuple of tensors, with each
    # tensor corresponding to a multimodal data item (image, video or audio).
    multimodal_embeddings: tuple[torch.Tensor, ...] = ()

    # NOTE: It is important to iterate over the keys in this dictionary
    # to preserve the order of the modalities.
    for modality in mm_input_by_modality:
        multimodal_input = mm_input_by_modality[modality]
        if modality == "image":
            vision_embeddings = self._process_image_input(multimodal_input)
            multimodal_embeddings += vision_embeddings
        if modality == "video":
            video_embeddings = self._process_video_input(multimodal_input)
            multimodal_embeddings += video_embeddings
        if modality == "audio":
            audio_embeddings = self._process_audio_input(multimodal_input)
            multimodal_embeddings += audio_embeddings
    return multimodal_embeddings

get_placeholder_str classmethod

get_placeholder_str(modality: str, i: int) -> Optional[str]
Source code in vllm/model_executor/models/qwen3_omni_moe_thinker.py
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
    if modality.startswith("image"):
        return "<|vision_start|><|image_pad|><|vision_end|>"
    if modality.startswith("video"):
        return "<|vision_start|><|video_pad|><|vision_end|>"
    if modality.startswith("audio"):
        return "<|audio_start|><|audio_pad|><|audio_end|>"

    raise ValueError("Only image, video or audio modality is supported")
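
A usage sketch: the placeholder strings are inserted into the text prompt where the multimodal items should appear (the surrounding chat markup below is illustrative only):

image_placeholder = Qwen3OmniMoeThinkerForConditionalGeneration.get_placeholder_str(
    "image", 0
)
prompt = (
    "<|im_start|>user\n"
    f"{image_placeholder}Describe this image.<|im_end|>\n"
    "<|im_start|>assistant\n"
)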

load_weights

load_weights(
    weights: Iterable[tuple[str, Tensor]],
) -> set[str]
Source code in vllm/model_executor/models/qwen3_omni_moe_thinker.py
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
    loader = AutoWeightsLoader(
        self,
        skip_prefixes=["talker.", "code2wav."],
    )
    loaded_weights = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

    return loaded_weights

Qwen3OmniMoeThinkerMultiModalProcessor

Bases: Qwen2_5OmniThinkerMultiModalProcessor

Source code in vllm/model_executor/models/qwen3_omni_moe_thinker.py
class Qwen3OmniMoeThinkerMultiModalProcessor(
    Qwen2_5OmniThinkerMultiModalProcessor,
):
    def _get_feat_extract_output_lengths(
        self, input_lengths: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        input_lengths_leave = input_lengths % 100
        feat_lengths = (input_lengths_leave - 1) // 2 + 1
        output_lengths = (
            ((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
        )
        return feat_lengths, output_lengths

    def _maybe_apply_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        prompt_ids: list[int],
        mm_kwargs: MultiModalKwargsItems,
        mm_prompt_updates: MultiModalPromptUpdates,
        is_update_applied: bool,
    ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]:
        """
        Qwen3-Omni reimplements this function to handle `use_audio_in_video`.
        """
        mm_item_counts = mm_items.get_all_counts()
        self._validate_mm_kwargs(mm_kwargs, mm_item_counts)

        use_audio_in_video = (
            all(item["use_audio_in_video"].data for item in mm_kwargs["video"])
            if "video" in mm_kwargs
            else False
        )

        if use_audio_in_video and "video" in mm_item_counts:
            assert "audio" in mm_item_counts
            mm_item_counts["audio"] -= mm_item_counts["video"]

        if is_update_applied:
            prompt_ids = self._get_raw_input_ids(prompt_ids, use_audio_in_video)

        (
            prompt_ids,
            mm_placeholders,
        ) = self._apply_prompt_updates(
            prompt_ids,
            mm_prompt_updates,
        )
        self._validate_mm_placeholders(mm_placeholders, mm_item_counts)

        return prompt_ids, mm_placeholders

    def get_updates_use_audio_in_video(
        self,
        thinker_config: PretrainedConfig,
        audio_len: int,
        video_grid_thw: Union[list[int], torch.Tensor],
        video_second_per_grid_t: float,
    ) -> list[int]:
        shift = 0
        audio_token_id = thinker_config.audio_token_id
        video_token_id = thinker_config.video_token_id
        audio_start_token_id = thinker_config.audio_start_token_id
        audio_end_token_id = thinker_config.audio_end_token_id
        spatial_merge_size = thinker_config.vision_config.spatial_merge_size
        position_id_per_seconds = thinker_config.position_id_per_seconds
        audio_token_indices = np.arange(next(iter([audio_len])))
        curr_video_grid_thw = next(iter([video_grid_thw]))
        height = curr_video_grid_thw[1] // spatial_merge_size
        width = curr_video_grid_thw[2] // spatial_merge_size
        video_token_indices = np.arange(curr_video_grid_thw[0]).reshape(-1, 1, 1)
        video_token_indices = np.broadcast_to(
            video_token_indices, (video_token_indices.shape[0], height, width)
        ).reshape(-1)
        video_token_indices = (
            (video_token_indices + shift)
            * next(iter([video_second_per_grid_t]))
            * position_id_per_seconds
        )
        video_data_index, audio_data_index = 0, 0
        updates = [audio_start_token_id]
        while video_data_index < len(video_token_indices) and audio_data_index < len(
            audio_token_indices
        ):
            if (
                video_token_indices[video_data_index]
                <= audio_token_indices[audio_data_index]
            ):
                updates += [video_token_id]
                video_data_index += 1
            else:
                updates += [audio_token_id]
                audio_data_index += 1
        if video_data_index < len(video_token_indices):
            updates += [video_token_id] * (len(video_token_indices) - video_data_index)
        if audio_data_index < len(audio_token_indices):
            updates += [audio_token_id] * (len(audio_token_indices) - audio_data_index)
        updates += [audio_end_token_id]
        return updates

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        tokenizer = self.info.get_tokenizer()
        image_processor = self.info.get_image_processor(**hf_processor_mm_kwargs)
        vocab = tokenizer.get_vocab()

        audio_token = processor.audio_token
        image_token = processor.image_token
        video_token = processor.video_token
        audio_token_id = vocab[audio_token]
        image_token_id = vocab[image_token]
        video_token_id = vocab[video_token]

        out_mm_data = out_mm_kwargs.get_data()
        audio_feature_lengths = out_mm_data.get("audio_feature_lengths")
        feature_attention_mask = out_mm_data.get("feature_attention_mask")
        if audio_feature_lengths is None and feature_attention_mask is None:
            audio_output_lengths = []
        elif audio_feature_lengths is not None:
            _, audio_output_lens = self._get_feat_extract_output_lengths(
                audio_feature_lengths
            )
            audio_output_lengths = audio_output_lens.tolist()
        elif feature_attention_mask is not None:
            assert isinstance(feature_attention_mask, torch.Tensor)
            _, audio_output_lens = self._get_feat_extract_output_lengths(
                feature_attention_mask.sum(-1)
            )
            audio_output_lengths = audio_output_lens.tolist()

        # number of audios read from video.
        audio_in_video_item_idx = 0
        audio_item_idx = 0

        def get_replacement_qwen2_audio(item_idx: int):
            nonlocal audio_item_idx
            item_idx += audio_in_video_item_idx

            audio_item_idx += 1

            num_features = audio_output_lengths[item_idx]
            if num_features == 0:
                audios = mm_items.get_items("audio", AudioProcessorItems)
                audio = audios.get(item_idx)
                raise ValueError(
                    f"The audio {audio} (len={len(audio)}) is too short "
                    "to be represented inside the model"
                )

            return [audio_token_id] * num_features

        def get_replacement_qwen2_vision(item_idx: int, modality: str):
            grid_thw = out_mm_data[f"{modality}_grid_thw"][item_idx]
            assert isinstance(grid_thw, torch.Tensor)
            merge_length = image_processor.merge_size**2

            token_id = image_token_id if modality == "image" else video_token_id
            return [token_id] * (int(grid_thw.prod()) // merge_length)

        use_audio_in_video = hf_processor_mm_kwargs.get("use_audio_in_video", False)
        thinker_config = self.info.get_hf_config()

        def get_replacement_qwen2_use_audio_in_video(item_idx: int):
            nonlocal audio_in_video_item_idx
            audio_num_features = audio_output_lengths[audio_item_idx + item_idx]
            video_grid_thw = out_mm_data["video_grid_thw"][item_idx]

            audio_in_video_item_idx += 1

            second_per_grid_ts = hf_processor_mm_kwargs.get("second_per_grid_ts", None)
            if second_per_grid_ts:
                video_second_per_grid_t = second_per_grid_ts[item_idx]
            else:
                video_second_per_grid_t = 1.0

            return self.get_updates_use_audio_in_video(
                thinker_config=thinker_config,
                audio_len=audio_num_features,
                video_grid_thw=video_grid_thw,
                video_second_per_grid_t=video_second_per_grid_t,
            )

        video_replacement_fn = (
            get_replacement_qwen2_use_audio_in_video
            if use_audio_in_video
            else partial(get_replacement_qwen2_vision, modality="video")
        )

        return [
            PromptReplacement(
                modality="audio",
                target=audio_token,
                replacement=get_replacement_qwen2_audio,
            ),
            PromptReplacement(
                modality="image",
                target=image_token,
                replacement=partial(get_replacement_qwen2_vision, modality="image"),
            ),
            PromptReplacement(
                modality="video",
                target=video_token,
                replacement=video_replacement_fn,
            ),
        ]

    def _validate_mm_placeholders(
        self,
        mm_placeholders: Mapping[str, list[PlaceholderFeaturesInfo]],
        mm_item_counts: Mapping[str, int],
    ) -> None:
        BaseMultiModalProcessor[
            Qwen2_5OmniThinkerProcessingInfo
        ]._validate_mm_placeholders(self, mm_placeholders, mm_item_counts)

    def _get_raw_input_ids(
        self,
        token_ids: list[int],
        use_audio_in_video: bool = False,
    ) -> list[int]:
        tokenizer = self.info.get_tokenizer()
        vision_bos_token = tokenizer.encode(tokenizer.vision_bos_token)[0]
        vision_eos_token = tokenizer.encode(tokenizer.vision_eos_token)[0]
        audio_bos_token = tokenizer.encode(tokenizer.audio_bos_token)[0]
        audio_eos_token = tokenizer.encode(tokenizer.audio_eos_token)[0]
        audio_token = tokenizer.encode("<|audio_pad|>")[0]
        image_token = tokenizer.encode("<|image_pad|>")[0]
        video_token = tokenizer.encode("<|video_pad|>")[0]

        result = token_ids[:]
        if use_audio_in_video:
            while True:
                start = None
                for i in range(len(result) - 1):
                    if result[i : i + 2] == [vision_bos_token, audio_bos_token]:
                        start = i
                        break
                if start is not None:
                    end = None
                    for i in range(start + 2, len(result) - 1):
                        if result[i : i + 2] == [audio_eos_token, vision_eos_token]:
                            end = i
                            break
                    if end is not None:
                        result = (
                            result[:start]
                            + [vision_bos_token, video_token, vision_eos_token]
                            + result[end + 2 :]
                        )
                else:
                    break

        for mm_token in [audio_token, image_token, video_token]:
            compressed = []
            for x in result:
                if x != mm_token or (not compressed or compressed[-1] != mm_token):
                    compressed.append(x)
            result = compressed

        return result
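
The get_updates_use_audio_in_video method above interleaves video and audio placeholder tokens by timestamp. A simplified, temporal-only sketch of that merge (made-up lengths; the real code also expands each temporal grid by its spatial height * width):

import numpy as np

AUDIO, VIDEO, AUDIO_BOS, AUDIO_EOS = "audio", "video", "<audio_bos>", "<audio_eos>"
position_id_per_seconds = 25
audio_indices = np.arange(6)                                  # 6 audio frames
video_indices = np.arange(3) * 0.5 * position_id_per_seconds  # 3 grids, 0.5 s apart

updates, v, a = [AUDIO_BOS], 0, 0
while v < len(video_indices) and a < len(audio_indices):
    if video_indices[v] <= audio_indices[a]:
        updates.append(VIDEO)
        v += 1
    else:
        updates.append(AUDIO)
        a += 1
updates += [VIDEO] * (len(video_indices) - v)
updates += [AUDIO] * (len(audio_indices) - a)
updates += [AUDIO_EOS]
print(updates)
# ['<audio_bos>', 'video', 'audio', 'audio', 'audio', 'audio', 'audio', 'audio',
#  'video', 'video', '<audio_eos>']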

_get_feat_extract_output_lengths

_get_feat_extract_output_lengths(
    input_lengths: Tensor,
) -> tuple[Tensor, Tensor]
Source code in vllm/model_executor/models/qwen3_omni_moe_thinker.py
def _get_feat_extract_output_lengths(
    self, input_lengths: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    input_lengths_leave = input_lengths % 100
    feat_lengths = (input_lengths_leave - 1) // 2 + 1
    output_lengths = (
        ((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13
    )
    return feat_lengths, output_lengths

_get_prompt_updates

_get_prompt_updates(
    mm_items: MultiModalDataItems,
    hf_processor_mm_kwargs: Mapping[str, Any],
    out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]
Source code in vllm/model_executor/models/qwen3_omni_moe_thinker.py
def _get_prompt_updates(
    self,
    mm_items: MultiModalDataItems,
    hf_processor_mm_kwargs: Mapping[str, Any],
    out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]:
    processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
    tokenizer = self.info.get_tokenizer()
    image_processor = self.info.get_image_processor(**hf_processor_mm_kwargs)
    vocab = tokenizer.get_vocab()

    audio_token = processor.audio_token
    image_token = processor.image_token
    video_token = processor.video_token
    audio_token_id = vocab[audio_token]
    image_token_id = vocab[image_token]
    video_token_id = vocab[video_token]

    out_mm_data = out_mm_kwargs.get_data()
    audio_feature_lengths = out_mm_data.get("audio_feature_lengths")
    feature_attention_mask = out_mm_data.get("feature_attention_mask")
    if audio_feature_lengths is None and feature_attention_mask is None:
        audio_output_lengths = []
    elif audio_feature_lengths is not None:
        _, audio_output_lens = self._get_feat_extract_output_lengths(
            audio_feature_lengths
        )
        audio_output_lengths = audio_output_lens.tolist()
    elif feature_attention_mask is not None:
        assert isinstance(feature_attention_mask, torch.Tensor)
        _, audio_output_lens = self._get_feat_extract_output_lengths(
            feature_attention_mask.sum(-1)
        )
        audio_output_lengths = audio_output_lens.tolist()

    # number of audios read from video.
    audio_in_video_item_idx = 0
    audio_item_idx = 0

    def get_replacement_qwen2_audio(item_idx: int):
        nonlocal audio_item_idx
        item_idx += audio_in_video_item_idx

        audio_item_idx += 1

        num_features = audio_output_lengths[item_idx]
        if num_features == 0:
            audios = mm_items.get_items("audio", AudioProcessorItems)
            audio = audios.get(item_idx)
            raise ValueError(
                f"The audio {audio} (len={len(audio)}) is too short "
                "to be represented inside the model"
            )

        return [audio_token_id] * num_features

    def get_replacement_qwen2_vision(item_idx: int, modality: str):
        grid_thw = out_mm_data[f"{modality}_grid_thw"][item_idx]
        assert isinstance(grid_thw, torch.Tensor)
        merge_length = image_processor.merge_size**2

        token_id = image_token_id if modality == "image" else video_token_id
        return [token_id] * (int(grid_thw.prod()) // merge_length)

    use_audio_in_video = hf_processor_mm_kwargs.get("use_audio_in_video", False)
    thinker_config = self.info.get_hf_config()

    def get_replacement_qwen2_use_audio_in_video(item_idx: int):
        nonlocal audio_in_video_item_idx
        audio_num_features = audio_output_lengths[audio_item_idx + item_idx]
        video_grid_thw = out_mm_data["video_grid_thw"][item_idx]

        audio_in_video_item_idx += 1

        second_per_grid_ts = hf_processor_mm_kwargs.get("second_per_grid_ts", None)
        if second_per_grid_ts:
            video_second_per_grid_t = second_per_grid_ts[item_idx]
        else:
            video_second_per_grid_t = 1.0

        return self.get_updates_use_audio_in_video(
            thinker_config=thinker_config,
            audio_len=audio_num_features,
            video_grid_thw=video_grid_thw,
            video_second_per_grid_t=video_second_per_grid_t,
        )

    video_replacement_fn = (
        get_replacement_qwen2_use_audio_in_video
        if use_audio_in_video
        else partial(get_replacement_qwen2_vision, modality="video")
    )

    return [
        PromptReplacement(
            modality="audio",
            target=audio_token,
            replacement=get_replacement_qwen2_audio,
        ),
        PromptReplacement(
            modality="image",
            target=image_token,
            replacement=partial(get_replacement_qwen2_vision, modality="image"),
        ),
        PromptReplacement(
            modality="video",
            target=video_token,
            replacement=video_replacement_fn,
        ),
    ]
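
Each replacement callback above expands one placeholder token in the prompt into as many tokens as the encoder will emit features. A minimal standalone sketch of the vision count (the grid and merge size below are made-up values, not model defaults):

import torch

grid_thw = torch.tensor([2, 32, 32])  # hypothetical video patch grid (t, h, w)
merge_size = 2                        # hypothetical image_processor.merge_size

merge_length = merge_size ** 2
num_placeholder_tokens = int(grid_thw.prod()) // merge_length
print(num_placeholder_tokens)  # 2 * 32 * 32 // 4 = 512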

_get_raw_input_ids

_get_raw_input_ids(
    token_ids: list[int], use_audio_in_video: bool = False
) -> list[int]
Source code in vllm/model_executor/models/qwen3_omni_moe_thinker.py
def _get_raw_input_ids(
    self,
    token_ids: list[int],
    use_audio_in_video: bool = False,
) -> list[int]:
    tokenizer = self.info.get_tokenizer()
    vision_bos_token = tokenizer.encode(tokenizer.vision_bos_token)[0]
    vision_eos_token = tokenizer.encode(tokenizer.vision_eos_token)[0]
    audio_bos_token = tokenizer.encode(tokenizer.audio_bos_token)[0]
    audio_eos_token = tokenizer.encode(tokenizer.audio_eos_token)[0]
    audio_token = tokenizer.encode("<|audio_pad|>")[0]
    image_token = tokenizer.encode("<|image_pad|>")[0]
    video_token = tokenizer.encode("<|video_pad|>")[0]

    result = token_ids[:]
    if use_audio_in_video:
        while True:
            start = None
            for i in range(len(result) - 1):
                if result[i : i + 2] == [vision_bos_token, audio_bos_token]:
                    start = i
                    break
            if start is not None:
                end = None
                for i in range(start + 2, len(result) - 1):
                    if result[i : i + 2] == [audio_eos_token, vision_eos_token]:
                        end = i
                        break
                if end is not None:
                    result = (
                        result[:start]
                        + [vision_bos_token, video_token, vision_eos_token]
                        + result[end + 2 :]
                    )
            else:
                break

    for mm_token in [audio_token, image_token, video_token]:
        compressed = []
        for x in result:
            if x != mm_token or (not compressed or compressed[-1] != mm_token):
                compressed.append(x)
        result = compressed

    return result
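
A toy illustration of the run-length collapse in the loop above (the token IDs are made up, not the real vocabulary): consecutive copies of the same multimodal pad token are reduced to a single placeholder so the prompt can be re-expanded from scratch.

AUDIO_PAD, IMAGE_PAD, VIDEO_PAD = 101, 102, 103  # hypothetical token IDs

def collapse_pads(token_ids: list[int], pad_ids: tuple[int, ...]) -> list[int]:
    result = token_ids[:]
    for mm_token in pad_ids:
        compressed: list[int] = []
        for x in result:
            if x != mm_token or (not compressed or compressed[-1] != mm_token):
                compressed.append(x)
        result = compressed
    return result

prompt = [1, VIDEO_PAD, VIDEO_PAD, VIDEO_PAD, 2, AUDIO_PAD, AUDIO_PAD, 3]
assert collapse_pads(prompt, (AUDIO_PAD, IMAGE_PAD, VIDEO_PAD)) == [1, VIDEO_PAD, 2, AUDIO_PAD, 3]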

_maybe_apply_prompt_updates

_maybe_apply_prompt_updates(
    mm_items: MultiModalDataItems,
    prompt_ids: list[int],
    mm_kwargs: MultiModalKwargsItems,
    mm_prompt_updates: MultiModalPromptUpdates,
    is_update_applied: bool,
) -> tuple[
    list[int],
    str,
    Mapping[str, list[PlaceholderFeaturesInfo]],
]

Qwen3-Omni reimplements this function to handle use_audio_in_video.

Source code in vllm/model_executor/models/qwen3_omni_moe_thinker.py
def _maybe_apply_prompt_updates(
    self,
    mm_items: MultiModalDataItems,
    prompt_ids: list[int],
    mm_kwargs: MultiModalKwargsItems,
    mm_prompt_updates: MultiModalPromptUpdates,
    is_update_applied: bool,
) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]:
    """
    Qwen3-Omni reimplements this function to handle `use_audio_in_video`.
    """
    mm_item_counts = mm_items.get_all_counts()
    self._validate_mm_kwargs(mm_kwargs, mm_item_counts)

    use_audio_in_video = (
        all(item["use_audio_in_video"].data for item in mm_kwargs["video"])
        if "video" in mm_kwargs
        else False
    )

    if use_audio_in_video and "video" in mm_item_counts:
        assert "audio" in mm_item_counts
        mm_item_counts["audio"] -= mm_item_counts["video"]

    if is_update_applied:
        prompt_ids = self._get_raw_input_ids(prompt_ids, use_audio_in_video)

    (
        prompt_ids,
        mm_placeholders,
    ) = self._apply_prompt_updates(
        prompt_ids,
        mm_prompt_updates,
    )
    self._validate_mm_placeholders(mm_placeholders, mm_item_counts)

    return prompt_ids, mm_placeholders
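
The audio-count adjustment above reflects that, when use_audio_in_video is enabled, every video item also yields one audio item, so the standalone audio count excludes the per-video audios. A toy example with made-up counts:

mm_item_counts = {"audio": 3, "video": 2}  # hypothetical: 1 standalone audio + 2 audios from videos
mm_item_counts["audio"] -= mm_item_counts["video"]
print(mm_item_counts)  # {'audio': 1, 'video': 2}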

_validate_mm_placeholders

_validate_mm_placeholders(
    mm_placeholders: Mapping[
        str, list[PlaceholderFeaturesInfo]
    ],
    mm_item_counts: Mapping[str, int],
) -> None
Source code in vllm/model_executor/models/qwen3_omni_moe_thinker.py
def _validate_mm_placeholders(
    self,
    mm_placeholders: Mapping[str, list[PlaceholderFeaturesInfo]],
    mm_item_counts: Mapping[str, int],
) -> None:
    BaseMultiModalProcessor[
        Qwen2_5OmniThinkerProcessingInfo
    ]._validate_mm_placeholders(self, mm_placeholders, mm_item_counts)

get_updates_use_audio_in_video

get_updates_use_audio_in_video(
    thinker_config: PretrainedConfig,
    audio_len: int,
    video_grid_thw: Union[list[int], Tensor],
    video_second_per_grid_t: float,
) -> list[int]
Source code in vllm/model_executor/models/qwen3_omni_moe_thinker.py
def get_updates_use_audio_in_video(
    self,
    thinker_config: PretrainedConfig,
    audio_len: int,
    video_grid_thw: Union[list[int], torch.Tensor],
    video_second_per_grid_t: float,
) -> list[int]:
    shift = 0
    audio_token_id = thinker_config.audio_token_id
    video_token_id = thinker_config.video_token_id
    audio_start_token_id = thinker_config.audio_start_token_id
    audio_end_token_id = thinker_config.audio_end_token_id
    spatial_merge_size = thinker_config.vision_config.spatial_merge_size
    position_id_per_seconds = thinker_config.position_id_per_seconds
    audio_token_indices = np.arange(next(iter([audio_len])))
    curr_video_grid_thw = next(iter([video_grid_thw]))
    height = curr_video_grid_thw[1] // spatial_merge_size
    width = curr_video_grid_thw[2] // spatial_merge_size
    video_token_indices = np.arange(curr_video_grid_thw[0]).reshape(-1, 1, 1)
    video_token_indices = np.broadcast_to(
        video_token_indices, (video_token_indices.shape[0], height, width)
    ).reshape(-1)
    video_token_indices = (
        (video_token_indices + shift)
        * next(iter([video_second_per_grid_t]))
        * position_id_per_seconds
    )
    video_data_index, audio_data_index = 0, 0
    updates = [audio_start_token_id]
    while video_data_index < len(video_token_indices) and audio_data_index < len(
        audio_token_indices
    ):
        if (
            video_token_indices[video_data_index]
            <= audio_token_indices[audio_data_index]
        ):
            updates += [video_token_id]
            video_data_index += 1
        else:
            updates += [audio_token_id]
            audio_data_index += 1
    if video_data_index < len(video_token_indices):
        updates += [video_token_id] * (len(video_token_indices) - video_data_index)
    if audio_data_index < len(audio_token_indices):
        updates += [audio_token_id] * (len(audio_token_indices) - audio_data_index)
    updates += [audio_end_token_id]
    return updates
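
The loop above merge-sorts video and audio tokens by time: each video frame's spatial tokens are stamped with frame_index * video_second_per_grid_t * position_id_per_seconds, while audio tokens advance one index per feature, and whichever timestamp is earlier is emitted first. A standalone sketch with made-up sizes (all values below are assumptions, not model defaults):

import numpy as np

audio_len = 6                   # hypothetical number of audio features
video_grid_thw = [2, 4, 4]      # hypothetical (t, h, w)
spatial_merge_size = 2
second_per_grid_t = 0.5
position_id_per_seconds = 25

h = video_grid_thw[1] // spatial_merge_size
w = video_grid_thw[2] // spatial_merge_size
video_idx = np.repeat(np.arange(video_grid_thw[0]), h * w) * second_per_grid_t * position_id_per_seconds
audio_idx = np.arange(audio_len)

order, v, a = [], 0, 0
while v < len(video_idx) and a < len(audio_idx):
    if video_idx[v] <= audio_idx[a]:
        order.append("V")
        v += 1
    else:
        order.append("A")
        a += 1
order += ["V"] * (len(video_idx) - v) + ["A"] * (len(audio_idx) - a)
print("".join(order))  # VVVVAAAAAAVVVV: frame-0 tokens, audio up to t=12.5, then frame-1 tokens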

Qwen3OmniMoeThinkerProcessingInfo

Bases: Qwen2AudioProcessingInfo, Qwen2_5_VLProcessingInfo

Source code in vllm/model_executor/models/qwen3_omni_moe_thinker.py
class Qwen3OmniMoeThinkerProcessingInfo(
    Qwen2AudioProcessingInfo, Qwen2_5_VLProcessingInfo
):
    def get_hf_config(self):
        return self.ctx.get_hf_config(Qwen3OmniMoeConfig).thinker_config

    def get_hf_processor(self, **kwargs: object) -> Qwen3OmniMoeProcessor:
        processor = self.ctx.get_hf_processor(
            Qwen3OmniMoeProcessor,
            use_fast=kwargs.pop("use_fast", True),
            **kwargs,
        )
        if not hasattr(processor, "audio_token"):
            processor.audio_token = "<|audio_pad|>"
        if not hasattr(processor, "image_token"):
            processor.image_token = "<|image_pad|>"
        if not hasattr(processor, "video_token"):
            processor.video_token = "<|video_pad|>"
        return processor

    def get_feature_extractor(self, **kwargs: object):
        hf_processor = self.get_hf_processor(**kwargs)
        feature_extractor = hf_processor.feature_extractor  # type: ignore
        assert isinstance(feature_extractor, WhisperFeatureExtractor)
        return feature_extractor

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        return {"audio": None, "image": None, "video": None}

get_feature_extractor

get_feature_extractor(**kwargs: object)
Source code in vllm/model_executor/models/qwen3_omni_moe_thinker.py
def get_feature_extractor(self, **kwargs: object):
    hf_processor = self.get_hf_processor(**kwargs)
    feature_extractor = hf_processor.feature_extractor  # type: ignore
    assert isinstance(feature_extractor, WhisperFeatureExtractor)
    return feature_extractor

get_hf_config

get_hf_config()
Source code in vllm/model_executor/models/qwen3_omni_moe_thinker.py
def get_hf_config(self):
    return self.ctx.get_hf_config(Qwen3OmniMoeConfig).thinker_config

get_hf_processor

get_hf_processor(**kwargs: object) -> Qwen3OmniMoeProcessor
Source code in vllm/model_executor/models/qwen3_omni_moe_thinker.py
def get_hf_processor(self, **kwargs: object) -> Qwen3OmniMoeProcessor:
    processor = self.ctx.get_hf_processor(
        Qwen3OmniMoeProcessor,
        use_fast=kwargs.pop("use_fast", True),
        **kwargs,
    )
    if not hasattr(processor, "audio_token"):
        processor.audio_token = "<|audio_pad|>"
    if not hasattr(processor, "image_token"):
        processor.image_token = "<|image_pad|>"
    if not hasattr(processor, "video_token"):
        processor.video_token = "<|video_pad|>"
    return processor

get_supported_mm_limits

get_supported_mm_limits() -> Mapping[str, Optional[int]]
Source code in vllm/model_executor/models/qwen3_omni_moe_thinker.py
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
    return {"audio": None, "image": None, "video": None}

Qwen3Omni_VisionTransformer

Bases: Module

Source code in vllm/model_executor/models/qwen3_omni_moe_thinker.py
class Qwen3Omni_VisionTransformer(nn.Module):
    def __init__(
        self,
        vision_config,
        norm_eps: float = 1e-6,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = vision_config.hidden_size
        self.num_heads = vision_config.num_heads
        self.image_size = vision_config.image_size
        self.patch_size = vision_config.patch_size
        self.spatial_merge_size = vision_config.spatial_merge_size
        self.spatial_merge_unit = self.spatial_merge_size**2
        self.temporal_patch_size = vision_config.temporal_patch_size
        self.apply_vit_abs_pos_embed = vision_config.apply_vit_abs_pos_embed
        self.deepstack_visual_indexes = vision_config.deepstack_visual_indexes

        self.patch_embed = Qwen3_VisionPatchEmbed(
            patch_size=self.patch_size,
            temporal_patch_size=self.temporal_patch_size,
            in_channels=vision_config.in_channels,
            hidden_size=self.hidden_size,
        )

        # vit pos embedding, TODO: spatial_patch_size vs patch_size
        if self.apply_vit_abs_pos_embed:
            self.pos_embed = nn.Embedding(
                (self.image_size // self.patch_size) ** 2, self.hidden_size
            )
        else:
            self.pos_embed = nn.Parameter(
                torch.empty(
                    [1, (self.image_size // self.patch_size) ** 2, self.hidden_size]
                )
            )

        norm_layer = partial(nn.LayerNorm, eps=norm_eps)
        head_dim = self.hidden_size // self.num_heads
        self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)

        self.blocks = nn.ModuleList(
            [
                Qwen3_VisionBlock(
                    dim=self.hidden_size,
                    num_heads=self.num_heads,
                    mlp_hidden_dim=vision_config.intermediate_size,
                    act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act],
                    norm_layer=norm_layer,
                    quant_config=quant_config,
                    prefix=f"{prefix}.blocks.{layer_idx}",
                )
                for layer_idx in range(vision_config.depth)
            ]
        )
        self.merger = Qwen3_VisionPatchMerger(
            d_model=vision_config.out_hidden_size,
            context_dim=self.hidden_size,
            norm_layer=norm_layer,
            spatial_merge_size=self.spatial_merge_size,
            quant_config=quant_config,
            prefix=f"{prefix}.merger",
        )
        if self.deepstack_visual_indexes is not None:
            self.merger_list = nn.ModuleList(
                [
                    Qwen3_VisionPatchMerger(
                        d_model=vision_config.out_hidden_size,
                        context_dim=self.hidden_size,
                        spatial_merge_size=self.spatial_merge_size,
                        use_postshuffle_norm=True,
                        norm_layer=norm_layer,
                        quant_config=quant_config,
                        prefix=f"{prefix}.merger_list.{layer_idx}",
                    )
                    for layer_idx in range(len(self.deepstack_visual_indexes))
                ]
            )

        self.attn_backend = get_vit_attn_backend(
            head_size=head_dim, dtype=torch.get_default_dtype()
        )
        if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
            torch.get_default_dtype()
        ):
            self.attn_backend = _Backend.FLASH_ATTN

    @property
    def dtype(self) -> torch.dtype:
        return self.patch_embed.proj.weight.dtype

    @property
    def device(self) -> torch.device:
        return self.patch_embed.proj.weight.device

    def rot_pos_emb(self, grid_thw):
        pos_ids = []
        for t, h, w in grid_thw:
            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            hpos_ids = hpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            )
            hpos_ids = hpos_ids.permute(0, 2, 1, 3)
            hpos_ids = hpos_ids.flatten()

            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            wpos_ids = wpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            )
            wpos_ids = wpos_ids.permute(0, 2, 1, 3)
            wpos_ids = wpos_ids.flatten()
            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
        pos_ids = torch.cat(pos_ids, dim=0)
        max_grid_size = grid_thw[:, 1:].max()
        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
        return rotary_pos_emb

    def fast_pos_embed_interpolate(self, grid_thw: list[list[int]]) -> torch.Tensor:
        num_grid_per_side = self.num_grid_per_side
        m_size = self.spatial_merge_size
        hidden_dim = self.pos_embed.embedding_dim

        outputs = []
        for t, h, w in grid_thw:
            h_idxs = torch.linspace(
                0, num_grid_per_side - 1, h, dtype=torch.float32, device=self.device
            )
            w_idxs = torch.linspace(
                0, num_grid_per_side - 1, w, dtype=torch.float32, device=self.device
            )

            h_floor = h_idxs.to(torch.long)
            w_floor = w_idxs.to(torch.long)
            h_ceil = torch.clamp(h_floor + 1, max=num_grid_per_side - 1)
            w_ceil = torch.clamp(w_floor + 1, max=num_grid_per_side - 1)

            dh = h_idxs - h_floor
            dw = w_idxs - w_floor

            # Create meshgrid view for all h, w vars
            dh_grid, dw_grid = torch.meshgrid(dh, dw, indexing="ij")
            h_floor_grid, w_floor_grid = torch.meshgrid(h_floor, w_floor, indexing="ij")
            h_ceil_grid, w_ceil_grid = torch.meshgrid(h_ceil, w_ceil, indexing="ij")
            h_floor_grid_idx = h_floor_grid * num_grid_per_side
            h_ceil_grid_idx = h_ceil_grid * num_grid_per_side

            # original computation of weights
            # w00 = (1 - dh_grid) * (1 - dw_grid)
            # w01 = (1 - dh_grid) * dw_grid
            # w10 = dh_grid * (1 - dw_grid)
            # w11 = dh_grid * dw_grid
            # we reuse w11 here to avoid duplicate
            # dh_grid * dw_grid computation
            w11 = dh_grid * dw_grid
            w10 = dh_grid - w11
            w01 = dw_grid - w11
            w00 = 1 - dh_grid - dw_grid + w11

            idx00 = h_floor_grid_idx + w_floor_grid
            idx01 = h_floor_grid_idx + w_ceil_grid
            idx10 = h_ceil_grid_idx + w_floor_grid
            idx11 = h_ceil_grid_idx + w_ceil_grid

            indices = torch.stack([idx00, idx01, idx10, idx11], dim=0).reshape(4, -1)
            weights = torch.stack([w00, w01, w10, w11], dim=0).reshape(4, -1, 1)
            weights = weights.to(dtype=self.dtype, device=self.device)

            embeds = self.pos_embed(indices)
            weighted_embeds = embeds * weights
            p0, p1, p2, p3 = weighted_embeds.unbind(dim=0)
            combined = p0 + p1 + p2 + p3

            combined = combined.view(h * w, hidden_dim)
            repeated = combined.unsqueeze(0).expand(t, -1, -1).contiguous()
            repeated = repeated.view(
                t, h // m_size, m_size, w // m_size, m_size, hidden_dim
            )
            repeated = repeated.permute(0, 1, 3, 2, 4, 5).reshape(-1, hidden_dim)
            outputs.append(repeated)

        return torch.cat(outputs, dim=0)

    def compute_attn_mask_seqlen(
        self,
        cu_seqlens: torch.Tensor,
    ) -> tuple[Optional[int], Optional[list[int]]]:
        max_seqlen, seqlens = None, None
        if self.attn_backend == _Backend.FLASH_ATTN:
            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
        elif self.attn_backend == _Backend.XFORMERS:
            seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
        return max_seqlen, seqlens

    def forward(
        self,
        x: torch.Tensor,
        grid_thw: list[list[int]],
    ) -> torch.Tensor:
        hidden_states = x.to(device=self.device, dtype=self.dtype)
        hidden_states = self.patch_embed(hidden_states)

        if self.apply_vit_abs_pos_embed:
            pos_embeds = self.fast_pos_embed_interpolate(grid_thw)
            hidden_states = hidden_states + pos_embeds
        rotary_pos_emb = self.rot_pos_emb(grid_thw)

        cu_seqlens = torch.repeat_interleave(
            grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
        ).cumsum(
            dim=0,
            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
        )
        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)

        hidden_states = hidden_states.unsqueeze(1)
        rotary_pos_emb = rotary_pos_emb.to(hidden_states.device)
        max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)

        hidden_states_list = []
        deepstack_visual_indexes = self.deepstack_visual_indexes

        for layer_num, blk in enumerate(self.blocks):
            hidden_states = blk(
                hidden_states,
                cu_seqlens=cu_seqlens,
                rotary_pos_emb=rotary_pos_emb,
                max_seqlen=max_seqlen,
                seqlens=seqlens,
            )
            if (
                deepstack_visual_indexes is not None
                and layer_num in deepstack_visual_indexes
            ):
                hidden_states_list.append(hidden_states)

        hidden_states = self.merger(hidden_states)

        # processing deepstack
        if deepstack_visual_indexes is not None:
            processed_hidden_states_list = [hidden_states]
            for idx, x in enumerate(hidden_states_list):
                x = self.merger_list[idx](x)
                processed_hidden_states_list.append(x)
            # we cat the original visual features and deepstack features
            # along the feature dim
            hidden_states = torch.cat(
                processed_hidden_states_list, dim=1
            )  # [seq_len, hidden_size * (1 + depth_of_deepstack)]

        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("attn.qkv.", "attn.q.", "q"),
            ("attn.qkv.", "attn.k.", "k"),
            ("attn.qkv.", "attn.v.", "v"),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

apply_vit_abs_pos_embed instance-attribute

apply_vit_abs_pos_embed = apply_vit_abs_pos_embed

attn_backend instance-attribute

attn_backend = get_vit_attn_backend(
    head_size=head_dim, dtype=get_default_dtype()
)

blocks instance-attribute

blocks = ModuleList(
    [
        (
            Qwen3_VisionBlock(
                dim=hidden_size,
                num_heads=num_heads,
                mlp_hidden_dim=intermediate_size,
                act_fn=_ACTIVATION_REGISTRY[hidden_act],
                norm_layer=norm_layer,
                quant_config=quant_config,
                prefix=f"{prefix}.blocks.{layer_idx}",
            )
        )
        for layer_idx in (range(depth))
    ]
)

deepstack_visual_indexes instance-attribute

deepstack_visual_indexes = deepstack_visual_indexes

device property

device: device

dtype property

dtype: dtype

hidden_size instance-attribute

hidden_size = hidden_size

image_size instance-attribute

image_size = image_size

merger instance-attribute

merger = Qwen3_VisionPatchMerger(
    d_model=out_hidden_size,
    context_dim=hidden_size,
    norm_layer=norm_layer,
    spatial_merge_size=spatial_merge_size,
    quant_config=quant_config,
    prefix=f"{prefix}.merger",
)

merger_list instance-attribute

merger_list = ModuleList(
    [
        (
            Qwen3_VisionPatchMerger(
                d_model=out_hidden_size,
                context_dim=hidden_size,
                spatial_merge_size=spatial_merge_size,
                use_postshuffle_norm=True,
                norm_layer=norm_layer,
                quant_config=quant_config,
                prefix=f"{prefix}.merger_list.{layer_idx}",
            )
        )
        for layer_idx in (
            range(len(deepstack_visual_indexes))
        )
    ]
)

num_heads instance-attribute

num_heads = num_heads

patch_embed instance-attribute

patch_embed = Qwen3_VisionPatchEmbed(
    patch_size=patch_size,
    temporal_patch_size=temporal_patch_size,
    in_channels=in_channels,
    hidden_size=hidden_size,
)

patch_size instance-attribute

patch_size = patch_size

pos_embed instance-attribute

pos_embed = Embedding(
    (image_size // patch_size) ** 2, hidden_size
)

rotary_pos_emb instance-attribute

rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(
    head_dim // 2
)

spatial_merge_size instance-attribute

spatial_merge_size = spatial_merge_size

spatial_merge_unit instance-attribute

spatial_merge_unit = spatial_merge_size ** 2

temporal_patch_size instance-attribute

temporal_patch_size = temporal_patch_size

__init__

__init__(
    vision_config,
    norm_eps: float = 1e-06,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
) -> None
Source code in vllm/model_executor/models/qwen3_omni_moe_thinker.py
def __init__(
    self,
    vision_config,
    norm_eps: float = 1e-6,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
) -> None:
    super().__init__()
    self.hidden_size = vision_config.hidden_size
    self.num_heads = vision_config.num_heads
    self.image_size = vision_config.image_size
    self.patch_size = vision_config.patch_size
    self.spatial_merge_size = vision_config.spatial_merge_size
    self.spatial_merge_unit = self.spatial_merge_size**2
    self.temporal_patch_size = vision_config.temporal_patch_size
    self.apply_vit_abs_pos_embed = vision_config.apply_vit_abs_pos_embed
    self.deepstack_visual_indexes = vision_config.deepstack_visual_indexes

    self.patch_embed = Qwen3_VisionPatchEmbed(
        patch_size=self.patch_size,
        temporal_patch_size=self.temporal_patch_size,
        in_channels=vision_config.in_channels,
        hidden_size=self.hidden_size,
    )

    # vit pos embedding, TODO: spatial_patch_size vs patch_size
    if self.apply_vit_abs_pos_embed:
        self.pos_embed = nn.Embedding(
            (self.image_size // self.patch_size) ** 2, self.hidden_size
        )
    else:
        self.pos_embed = nn.Parameter(
            torch.empty(
                [1, (self.image_size // self.patch_size) ** 2, self.hidden_size]
            )
        )

    norm_layer = partial(nn.LayerNorm, eps=norm_eps)
    head_dim = self.hidden_size // self.num_heads
    self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)

    self.blocks = nn.ModuleList(
        [
            Qwen3_VisionBlock(
                dim=self.hidden_size,
                num_heads=self.num_heads,
                mlp_hidden_dim=vision_config.intermediate_size,
                act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act],
                norm_layer=norm_layer,
                quant_config=quant_config,
                prefix=f"{prefix}.blocks.{layer_idx}",
            )
            for layer_idx in range(vision_config.depth)
        ]
    )
    self.merger = Qwen3_VisionPatchMerger(
        d_model=vision_config.out_hidden_size,
        context_dim=self.hidden_size,
        norm_layer=norm_layer,
        spatial_merge_size=self.spatial_merge_size,
        quant_config=quant_config,
        prefix=f"{prefix}.merger",
    )
    if self.deepstack_visual_indexes is not None:
        self.merger_list = nn.ModuleList(
            [
                Qwen3_VisionPatchMerger(
                    d_model=vision_config.out_hidden_size,
                    context_dim=self.hidden_size,
                    spatial_merge_size=self.spatial_merge_size,
                    use_postshuffle_norm=True,
                    norm_layer=norm_layer,
                    quant_config=quant_config,
                    prefix=f"{prefix}.merger_list.{layer_idx}",
                )
                for layer_idx in range(len(self.deepstack_visual_indexes))
            ]
        )

    self.attn_backend = get_vit_attn_backend(
        head_size=head_dim, dtype=torch.get_default_dtype()
    )
    if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
        torch.get_default_dtype()
    ):
        self.attn_backend = _Backend.FLASH_ATTN

compute_attn_mask_seqlen

compute_attn_mask_seqlen(
    cu_seqlens: Tensor,
) -> tuple[Optional[int], Optional[list[int]]]
Source code in vllm/model_executor/models/qwen3_omni_moe_thinker.py
def compute_attn_mask_seqlen(
    self,
    cu_seqlens: torch.Tensor,
) -> tuple[Optional[int], Optional[list[int]]]:
    max_seqlen, seqlens = None, None
    if self.attn_backend == _Backend.FLASH_ATTN:
        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
    elif self.attn_backend == _Backend.XFORMERS:
        seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
    return max_seqlen, seqlens
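
A standalone example of the per-backend metadata derived above (the cu_seqlens values are made up): FLASH_ATTN only needs the longest attention window, while xFormers needs every window length.

import torch

cu_seqlens = torch.tensor([0, 16, 32, 96], dtype=torch.int32)  # hypothetical boundaries
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
print(max_seqlen, seqlens)  # 64 [16, 16, 64]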

fast_pos_embed_interpolate

fast_pos_embed_interpolate(
    grid_thw: list[list[int]],
) -> Tensor
Source code in vllm/model_executor/models/qwen3_omni_moe_thinker.py
def fast_pos_embed_interpolate(self, grid_thw: list[list[int]]) -> torch.Tensor:
    num_grid_per_side = self.num_grid_per_side
    m_size = self.spatial_merge_size
    hidden_dim = self.pos_embed.embedding_dim

    outputs = []
    for t, h, w in grid_thw:
        h_idxs = torch.linspace(
            0, num_grid_per_side - 1, h, dtype=torch.float32, device=self.device
        )
        w_idxs = torch.linspace(
            0, num_grid_per_side - 1, w, dtype=torch.float32, device=self.device
        )

        h_floor = h_idxs.to(torch.long)
        w_floor = w_idxs.to(torch.long)
        h_ceil = torch.clamp(h_floor + 1, max=num_grid_per_side - 1)
        w_ceil = torch.clamp(w_floor + 1, max=num_grid_per_side - 1)

        dh = h_idxs - h_floor
        dw = w_idxs - w_floor

        # Create meshgrid view for all h, w vars
        dh_grid, dw_grid = torch.meshgrid(dh, dw, indexing="ij")
        h_floor_grid, w_floor_grid = torch.meshgrid(h_floor, w_floor, indexing="ij")
        h_ceil_grid, w_ceil_grid = torch.meshgrid(h_ceil, w_ceil, indexing="ij")
        h_floor_grid_idx = h_floor_grid * num_grid_per_side
        h_ceil_grid_idx = h_ceil_grid * num_grid_per_side

        # original computation of weights
        # w00 = (1 - dh_grid) * (1 - dw_grid)
        # w01 = (1 - dh_grid) * dw_grid
        # w10 = dh_grid * (1 - dw_grid)
        # w11 = dh_grid * dw_grid
        # we reuse w11 here to avoid duplicate
        # dh_grid * dw_grid computation
        w11 = dh_grid * dw_grid
        w10 = dh_grid - w11
        w01 = dw_grid - w11
        w00 = 1 - dh_grid - dw_grid + w11

        idx00 = h_floor_grid_idx + w_floor_grid
        idx01 = h_floor_grid_idx + w_ceil_grid
        idx10 = h_ceil_grid_idx + w_floor_grid
        idx11 = h_ceil_grid_idx + w_ceil_grid

        indices = torch.stack([idx00, idx01, idx10, idx11], dim=0).reshape(4, -1)
        weights = torch.stack([w00, w01, w10, w11], dim=0).reshape(4, -1, 1)
        weights = weights.to(dtype=self.dtype, device=self.device)

        embeds = self.pos_embed(indices)
        weighted_embeds = embeds * weights
        p0, p1, p2, p3 = weighted_embeds.unbind(dim=0)
        combined = p0 + p1 + p2 + p3

        combined = combined.view(h * w, hidden_dim)
        repeated = combined.unsqueeze(0).expand(t, -1, -1).contiguous()
        repeated = repeated.view(
            t, h // m_size, m_size, w // m_size, m_size, hidden_dim
        )
        repeated = repeated.permute(0, 1, 3, 2, 4, 5).reshape(-1, hidden_dim)
        outputs.append(repeated)

    return torch.cat(outputs, dim=0)
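
The factored weights above are algebraically the standard bilinear weights; reusing w11 just saves one elementwise product. A quick standalone check with random fractional offsets:

import torch

dh = torch.rand(5)
dw = torch.rand(5)
w11 = dh * dw
w10, w01, w00 = dh - w11, dw - w11, 1 - dh - dw + w11

assert torch.allclose(w00, (1 - dh) * (1 - dw))
assert torch.allclose(w01, (1 - dh) * dw)
assert torch.allclose(w10, dh * (1 - dw))
assert torch.allclose(w00 + w01 + w10 + w11, torch.ones(5))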

forward

forward(x: Tensor, grid_thw: list[list[int]]) -> Tensor
Source code in vllm/model_executor/models/qwen3_omni_moe_thinker.py
def forward(
    self,
    x: torch.Tensor,
    grid_thw: list[list[int]],
) -> torch.Tensor:
    hidden_states = x.to(device=self.device, dtype=self.dtype)
    hidden_states = self.patch_embed(hidden_states)

    if self.apply_vit_abs_pos_embed:
        pos_embeds = self.fast_pos_embed_interpolate(grid_thw)
        hidden_states = hidden_states + pos_embeds
    rotary_pos_emb = self.rot_pos_emb(grid_thw)

    cu_seqlens = torch.repeat_interleave(
        grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
    ).cumsum(
        dim=0,
        dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
    )
    cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)

    hidden_states = hidden_states.unsqueeze(1)
    rotary_pos_emb = rotary_pos_emb.to(hidden_states.device)
    max_seqlen, seqlens = self.compute_attn_mask_seqlen(cu_seqlens)

    hidden_states_list = []
    deepstack_visual_indexes = self.deepstack_visual_indexes

    for layer_num, blk in enumerate(self.blocks):
        hidden_states = blk(
            hidden_states,
            cu_seqlens=cu_seqlens,
            rotary_pos_emb=rotary_pos_emb,
            max_seqlen=max_seqlen,
            seqlens=seqlens,
        )
        if (
            deepstack_visual_indexes is not None
            and layer_num in deepstack_visual_indexes
        ):
            hidden_states_list.append(hidden_states)

    hidden_states = self.merger(hidden_states)

    # processing deepstack
    if deepstack_visual_indexes is not None:
        processed_hidden_states_list = [hidden_states]
        for idx, x in enumerate(hidden_states_list):
            x = self.merger_list[idx](x)
            processed_hidden_states_list.append(x)
        # we cat the original visual features and deepstack features
        # along the feature dim
        hidden_states = torch.cat(
            processed_hidden_states_list, dim=1
        )  # [seq_len, hidden_size * (1 + depth_of_deepstack)]

    return hidden_states
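
A standalone sketch of the cu_seqlens construction above with a made-up grid: each temporal frame of each visual item gets its own attention window of h * w patches, and the cumulative sum (left-padded with 0) marks the window boundaries.

import torch
import torch.nn.functional as F

grid_thw = torch.tensor([[2, 4, 4], [1, 8, 8]])  # hypothetical: two visual items
seq_lens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0])
cu_seqlens = F.pad(seq_lens.cumsum(dim=0, dtype=torch.int32), (1, 0), value=0)
print(seq_lens.tolist())    # [16, 16, 64]
print(cu_seqlens.tolist())  # [0, 16, 32, 96]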

load_weights

load_weights(
    weights: Iterable[tuple[str, Tensor]],
) -> set[str]
Source code in vllm/model_executor/models/qwen3_omni_moe_thinker.py
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("attn.qkv.", "attn.q.", "q"),
        ("attn.qkv.", "attn.k.", "k"),
        ("attn.qkv.", "attn.v.", "v"),
    ]
    params_dict = dict(self.named_parameters(remove_duplicate=False))
    loaded_params: set[str] = set()

    for name, loaded_weight in weights:
        for param_name, weight_name, shard_id in stacked_params_mapping:
            if weight_name not in name:
                continue
            name = name.replace(weight_name, param_name)

            param = params_dict[name]
            weight_loader = param.weight_loader
            weight_loader(param, loaded_weight, shard_id)
            break
        else:
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, loaded_weight)
        loaded_params.add(name)
    return loaded_params
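
A toy illustration of the name remapping above (the checkpoint name is hypothetical): separate q/k/v projection weights from the checkpoint are folded into the fused qkv parameter and tagged with a shard id for the sharded weight loader.

stacked_params_mapping = [
    ("attn.qkv.", "attn.q.", "q"),
    ("attn.qkv.", "attn.k.", "k"),
    ("attn.qkv.", "attn.v.", "v"),
]

name = "blocks.0.attn.q.weight"  # hypothetical checkpoint name
for param_name, weight_name, shard_id in stacked_params_mapping:
    if weight_name in name:
        print(name.replace(weight_name, param_name), shard_id)  # blocks.0.attn.qkv.weight q
        break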

rot_pos_emb

rot_pos_emb(grid_thw)
Source code in vllm/model_executor/models/qwen3_omni_moe_thinker.py
def rot_pos_emb(self, grid_thw):
    pos_ids = []
    for t, h, w in grid_thw:
        hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
        hpos_ids = hpos_ids.reshape(
            h // self.spatial_merge_size,
            self.spatial_merge_size,
            w // self.spatial_merge_size,
            self.spatial_merge_size,
        )
        hpos_ids = hpos_ids.permute(0, 2, 1, 3)
        hpos_ids = hpos_ids.flatten()

        wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
        wpos_ids = wpos_ids.reshape(
            h // self.spatial_merge_size,
            self.spatial_merge_size,
            w // self.spatial_merge_size,
            self.spatial_merge_size,
        )
        wpos_ids = wpos_ids.permute(0, 2, 1, 3)
        wpos_ids = wpos_ids.flatten()
        pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
    pos_ids = torch.cat(pos_ids, dim=0)
    max_grid_size = grid_thw[:, 1:].max()
    rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
    rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
    return rotary_pos_emb
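
A standalone sketch of the (h, w) index shuffling above for a 4x4 grid with spatial_merge_size=2: the reshape/permute/flatten groups positions so that each 2x2 merge window is contiguous in the flattened order.

import torch

h, w, m = 4, 4, 2
hpos = torch.arange(h).unsqueeze(1).expand(-1, w)
hpos = hpos.reshape(h // m, m, w // m, m).permute(0, 2, 1, 3).flatten()
wpos = torch.arange(w).unsqueeze(0).expand(h, -1)
wpos = wpos.reshape(h // m, m, w // m, m).permute(0, 2, 1, 3).flatten()
print(list(zip(hpos.tolist(), wpos.tolist()))[:4])  # [(0, 0), (0, 1), (1, 0), (1, 1)]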

Qwen3_VisionBlock

Bases: Module

Source code in vllm/model_executor/models/qwen3_omni_moe_thinker.py
class Qwen3_VisionBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_hidden_dim: int,
        act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
        norm_layer: Optional[Callable[[int], nn.Module]] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = partial(nn.LayerNorm, eps=1e-6)
        self.norm1 = norm_layer(dim)
        self.norm2 = norm_layer(dim)
        self.attn = Qwen2_5_VisionAttention(
            embed_dim=dim,
            num_heads=num_heads,
            projection_size=dim,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )
        self.mlp = Qwen3_VisionMLP(
            dim,
            mlp_hidden_dim,
            act_fn=act_fn,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: torch.Tensor,
        max_seqlen: Optional[int] = None,  # Only used for Flash Attention
        seqlens: Optional[list[int]] = None,  # Only used for xFormers
    ) -> torch.Tensor:
        x = x + self.attn(
            self.norm1(x),
            cu_seqlens=cu_seqlens,
            rotary_pos_emb=rotary_pos_emb,
            max_seqlen=max_seqlen,
            seqlens=seqlens,
        )

        x = x + self.mlp(self.norm2(x))
        return x

attn instance-attribute

attn = Qwen2_5_VisionAttention(
    embed_dim=dim,
    num_heads=num_heads,
    projection_size=dim,
    quant_config=quant_config,
    prefix=f"{prefix}.attn",
)

mlp instance-attribute

mlp = Qwen3_VisionMLP(
    dim,
    mlp_hidden_dim,
    act_fn=act_fn,
    bias=True,
    quant_config=quant_config,
    prefix=f"{prefix}.mlp",
)

norm1 instance-attribute

norm1 = norm_layer(dim)

norm2 instance-attribute

norm2 = norm_layer(dim)

__init__

__init__(
    dim: int,
    num_heads: int,
    mlp_hidden_dim: int,
    act_fn: Callable[[Tensor], Tensor] = silu,
    norm_layer: Optional[Callable[[int], Module]] = None,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
) -> None
Source code in vllm/model_executor/models/qwen3_omni_moe_thinker.py
def __init__(
    self,
    dim: int,
    num_heads: int,
    mlp_hidden_dim: int,
    act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
    norm_layer: Optional[Callable[[int], nn.Module]] = None,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
) -> None:
    super().__init__()
    if norm_layer is None:
        norm_layer = partial(nn.LayerNorm, eps=1e-6)
    self.norm1 = norm_layer(dim)
    self.norm2 = norm_layer(dim)
    self.attn = Qwen2_5_VisionAttention(
        embed_dim=dim,
        num_heads=num_heads,
        projection_size=dim,
        quant_config=quant_config,
        prefix=f"{prefix}.attn",
    )
    self.mlp = Qwen3_VisionMLP(
        dim,
        mlp_hidden_dim,
        act_fn=act_fn,
        bias=True,
        quant_config=quant_config,
        prefix=f"{prefix}.mlp",
    )

forward

forward(
    x: Tensor,
    cu_seqlens: Tensor,
    rotary_pos_emb: Tensor,
    max_seqlen: Optional[int] = None,
    seqlens: Optional[list[int]] = None,
) -> Tensor
Source code in vllm/model_executor/models/qwen3_omni_moe_thinker.py
def forward(
    self,
    x: torch.Tensor,
    cu_seqlens: torch.Tensor,
    rotary_pos_emb: torch.Tensor,
    max_seqlen: Optional[int] = None,  # Only used for Flash Attention
    seqlens: Optional[list[int]] = None,  # Only used for xFormers
) -> torch.Tensor:
    x = x + self.attn(
        self.norm1(x),
        cu_seqlens=cu_seqlens,
        rotary_pos_emb=rotary_pos_emb,
        max_seqlen=max_seqlen,
        seqlens=seqlens,
    )

    x = x + self.mlp(self.norm2(x))
    return x

Qwen3_VisionMLP

Bases: Module

Source code in vllm/model_executor/models/qwen3_omni_moe_thinker.py
class Qwen3_VisionMLP(nn.Module):
    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        bias: bool = False,
        act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.linear_fc1 = ColumnParallelLinear(
            in_features,
            hidden_features,
            bias=bias,
            quant_config=quant_config,
            return_bias=False,
            prefix=f"{prefix}.linear_fc1",
        )
        self.linear_fc2 = RowParallelLinear(
            hidden_features,
            in_features,
            bias=bias,
            quant_config=quant_config,
            return_bias=False,
            prefix=f"{prefix}.linear_fc2",
        )
        self.act_fn = act_fn

    def forward(self, x: torch.Tensor):
        mlp_output = self.linear_fc2(self.act_fn(self.linear_fc1(x)))
        return mlp_output

act_fn instance-attribute

act_fn = act_fn

linear_fc1 instance-attribute

linear_fc1 = ColumnParallelLinear(
    in_features,
    hidden_features,
    bias=bias,
    quant_config=quant_config,
    return_bias=False,
    prefix=f"{prefix}.linear_fc1",
)

linear_fc2 instance-attribute

linear_fc2 = RowParallelLinear(
    hidden_features,
    in_features,
    bias=bias,
    quant_config=quant_config,
    return_bias=False,
    prefix=f"{prefix}.linear_fc2",
)

__init__

__init__(
    in_features: int,
    hidden_features: int,
    bias: bool = False,
    act_fn: Callable[[Tensor], Tensor] = silu,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
)
Source code in vllm/model_executor/models/qwen3_omni_moe_thinker.py
def __init__(
    self,
    in_features: int,
    hidden_features: int,
    bias: bool = False,
    act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
):
    super().__init__()
    self.linear_fc1 = ColumnParallelLinear(
        in_features,
        hidden_features,
        bias=bias,
        quant_config=quant_config,
        return_bias=False,
        prefix=f"{prefix}.linear_fc1",
    )
    self.linear_fc2 = RowParallelLinear(
        hidden_features,
        in_features,
        bias=bias,
        quant_config=quant_config,
        return_bias=False,
        prefix=f"{prefix}.linear_fc2",
    )
    self.act_fn = act_fn

forward

forward(x: Tensor)
Source code in vllm/model_executor/models/qwen3_omni_moe_thinker.py
def forward(self, x: torch.Tensor):
    mlp_output = self.linear_fc2(self.act_fn(self.linear_fc1(x)))
    return mlp_output

Qwen3_VisionPatchEmbed

Bases: Module

Source code in vllm/model_executor/models/qwen3_omni_moe_thinker.py
class Qwen3_VisionPatchEmbed(nn.Module):
    def __init__(
        self,
        patch_size: int = 14,
        temporal_patch_size: int = 2,
        in_channels: int = 3,
        hidden_size: int = 1152,
    ) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.hidden_size = hidden_size

        kernel_size = (temporal_patch_size, patch_size, patch_size)
        self.proj = nn.Conv3d(
            in_channels,
            hidden_size,
            kernel_size=kernel_size,
            stride=kernel_size,
            bias=True,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        L, C = x.shape
        x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size)
        x = self.proj(x).view(L, self.hidden_size)
        return x

hidden_size instance-attribute

hidden_size = hidden_size

patch_size instance-attribute

patch_size = patch_size

proj instance-attribute

proj = Conv3d(
    in_channels,
    hidden_size,
    kernel_size=kernel_size,
    stride=kernel_size,
    bias=True,
)

temporal_patch_size instance-attribute

temporal_patch_size = temporal_patch_size

__init__

__init__(
    patch_size: int = 14,
    temporal_patch_size: int = 2,
    in_channels: int = 3,
    hidden_size: int = 1152,
) -> None
Source code in vllm/model_executor/models/qwen3_omni_moe_thinker.py
def __init__(
    self,
    patch_size: int = 14,
    temporal_patch_size: int = 2,
    in_channels: int = 3,
    hidden_size: int = 1152,
) -> None:
    super().__init__()
    self.patch_size = patch_size
    self.temporal_patch_size = temporal_patch_size
    self.hidden_size = hidden_size

    kernel_size = (temporal_patch_size, patch_size, patch_size)
    self.proj = nn.Conv3d(
        in_channels,
        hidden_size,
        kernel_size=kernel_size,
        stride=kernel_size,
        bias=True,
    )

forward

forward(x: Tensor) -> Tensor
Source code in vllm/model_executor/models/qwen3_omni_moe_thinker.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    L, C = x.shape
    x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size)
    x = self.proj(x).view(L, self.hidden_size)
    return x
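
A standalone shape sketch of the patch embedding above (the sizes mirror the defaults; the patch count L is made up): each input row packs one flattened 3D patch, and a Conv3d with kernel == stride projects it to a single hidden_size vector.

import torch
import torch.nn as nn

in_channels, temporal_patch_size, patch_size, hidden_size = 3, 2, 14, 1152
L = 16  # hypothetical number of patches
x = torch.randn(L, in_channels * temporal_patch_size * patch_size * patch_size)

kernel = (temporal_patch_size, patch_size, patch_size)
proj = nn.Conv3d(in_channels, hidden_size, kernel_size=kernel, stride=kernel, bias=True)
out = proj(x.view(L, in_channels, *kernel)).view(L, hidden_size)
print(out.shape)  # torch.Size([16, 1152])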

Qwen3_VisionPatchMerger

Bases: Module

Source code in vllm/model_executor/models/qwen3_omni_moe_thinker.py
class Qwen3_VisionPatchMerger(nn.Module):
    def __init__(
        self,
        d_model: int,
        context_dim: int,
        norm_layer: Optional[Callable[[int], nn.Module]] = None,
        spatial_merge_size: int = 2,
        use_postshuffle_norm: bool = False,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = context_dim * (spatial_merge_size**2)

        self.use_postshuffle_norm = use_postshuffle_norm
        if self.use_postshuffle_norm:
            context_dim = self.hidden_size

        if norm_layer is None:
            norm_layer = partial(nn.LayerNorm, eps=1e-6)
        self.use_postshuffle_norm = use_postshuffle_norm
        self.ln_q = norm_layer(
            self.hidden_size if use_postshuffle_norm else context_dim
        )
        self.mlp = nn.ModuleList(
            [
                ColumnParallelLinear(
                    self.hidden_size,
                    self.hidden_size,
                    bias=True,
                    quant_config=quant_config,
                    prefix=f"{prefix}.mlp.0",
                ),
                nn.GELU(),
                RowParallelLinear(
                    self.hidden_size,
                    d_model,
                    bias=True,
                    quant_config=quant_config,
                    prefix=f"{prefix}.mlp.2",
                ),
            ]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.use_postshuffle_norm:
            x = self.ln_q(x.view(-1, self.hidden_size))
        else:
            x = self.ln_q(x).view(-1, self.hidden_size)

        mlp_fc1, mlp_act, mlp_fc2 = self.mlp
        x_parallel, _ = mlp_fc1(x)
        x_parallel = mlp_act(x_parallel)
        out, _ = mlp_fc2(x_parallel)
        return out

hidden_size instance-attribute

hidden_size = context_dim * spatial_merge_size ** 2

ln_q instance-attribute

ln_q = norm_layer(
    hidden_size if use_postshuffle_norm else context_dim
)

mlp instance-attribute

mlp = ModuleList(
    [
        ColumnParallelLinear(
            hidden_size,
            hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp.0",
        ),
        GELU(),
        RowParallelLinear(
            hidden_size,
            d_model,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp.2",
        ),
    ]
)

use_postshuffle_norm instance-attribute

use_postshuffle_norm = use_postshuffle_norm

__init__

__init__(
    d_model: int,
    context_dim: int,
    norm_layer: Optional[Callable[[int], Module]] = None,
    spatial_merge_size: int = 2,
    use_postshuffle_norm: bool = False,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
) -> None
Source code in vllm/model_executor/models/qwen3_omni_moe_thinker.py
def __init__(
    self,
    d_model: int,
    context_dim: int,
    norm_layer: Optional[Callable[[int], nn.Module]] = None,
    spatial_merge_size: int = 2,
    use_postshuffle_norm: bool = False,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
) -> None:
    super().__init__()
    self.hidden_size = context_dim * (spatial_merge_size**2)

    self.use_postshuffle_norm = use_postshuffle_norm
    if self.use_postshuffle_norm:
        context_dim = self.hidden_size

    if norm_layer is None:
        norm_layer = partial(nn.LayerNorm, eps=1e-6)
    self.use_postshuffle_norm = use_postshuffle_norm
    self.ln_q = norm_layer(
        self.hidden_size if use_postshuffle_norm else context_dim
    )
    self.mlp = nn.ModuleList(
        [
            ColumnParallelLinear(
                self.hidden_size,
                self.hidden_size,
                bias=True,
                quant_config=quant_config,
                prefix=f"{prefix}.mlp.0",
            ),
            nn.GELU(),
            RowParallelLinear(
                self.hidden_size,
                d_model,
                bias=True,
                quant_config=quant_config,
                prefix=f"{prefix}.mlp.2",
            ),
        ]
    )

forward

forward(x: Tensor) -> Tensor
Source code in vllm/model_executor/models/qwen3_omni_moe_thinker.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    if self.use_postshuffle_norm:
        x = self.ln_q(x.view(-1, self.hidden_size))
    else:
        x = self.ln_q(x).view(-1, self.hidden_size)

    mlp_fc1, mlp_act, mlp_fc2 = self.mlp
    x_parallel, _ = mlp_fc1(x)
    x_parallel = mlp_act(x_parallel)
    out, _ = mlp_fc2(x_parallel)
    return out
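
A standalone shape sketch of the merge step above (all sizes are made-up assumptions): spatial_merge_size**2 neighboring patch features are flattened into one vector of context_dim * spatial_merge_size**2 elements, which the two-layer MLP then projects down to d_model.

import torch

context_dim, spatial_merge_size, d_model = 1152, 2, 2048  # hypothetical sizes
seq_len = 64  # patch count; assumed divisible by spatial_merge_size**2

x = torch.randn(seq_len, context_dim)
merged = x.view(-1, context_dim * spatial_merge_size**2)
print(merged.shape)  # torch.Size([16, 4608]); the MLP maps 4608 -> d_model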