class OCP_MX_Scheme(str, Enum):
w_mxfp4_a_mxfp4 = "w_mxfp4_a_mxfp4"
w_mxfp4_a_mxfp6_e3m2 = "w_mxfp4_a_mxfp6_e3m2"
w_mxfp4_a_mxfp6_e2m3 = "w_mxfp4_a_mxfp6_e2m3"
w_mxfp6_e3m2_a_mxfp6_e3m2 = "w_mxfp6_e3m2_a_mxfp6_e3m2"
w_mxfp6_e2m3_a_mxfp6_e2m3 = "w_mxfp6_e2m3_a_mxfp6_e2m3"
@classmethod
def from_quant_dtype(
cls, input_dtype: Union[str, None], weight_dtype: Union[str, None]
):
if input_dtype not in OCP_MX_DTYPES or weight_dtype not in OCP_MX_DTYPES:
return None
elif input_dtype == "mxfp4" and weight_dtype == "mxfp4":
return cls.w_mxfp4_a_mxfp4
elif input_dtype == "mxfp6_e3m2" and weight_dtype == "mxfp4":
return cls.w_mxfp4_a_mxfp6_e3m2
elif input_dtype == "mxfp6_e2m3" and weight_dtype == "mxfp4":
return cls.w_mxfp4_a_mxfp6_e2m3
elif input_dtype == "mxfp6_e3m2" and weight_dtype == "mxfp6_e3m2":
return cls.w_mxfp6_e3m2_a_mxfp6_e3m2
elif input_dtype == "mxfp6_e2m3" and weight_dtype == "mxfp6_e2m3":
return cls.w_mxfp6_e2m3_a_mxfp6_e2m3
else:
logger.warning(
"input_dtype='%s' and"
" weight_dtype='%s' is not supported "
"in OCP_MX_Scheme at the moment.",
input_dtype,
weight_dtype,
)
return None