Update README.md
Browse files
README.md
CHANGED
@@ -97,7 +97,7 @@ from transformers import (
|
|
97 |
from torchao.quantization.quant_api import (
|
98 |
IntxWeightOnlyConfig,
|
99 |
Int8DynamicActivationIntxWeightConfig,
|
100 |
-
|
101 |
quantize_,
|
102 |
)
|
103 |
from torchao.quantization.granularity import PerGroup, PerAxis
|
@@ -119,7 +119,7 @@ linear_config = Int8DynamicActivationIntxWeightConfig(
|
|
119 |
weight_granularity=PerGroup(32),
|
120 |
weight_scale_dtype=torch.bfloat16,
|
121 |
)
|
122 |
-
quant_config =
|
123 |
quantization_config = TorchAoConfig(quant_type=quant_config, include_embedding=True, untie_embedding_weights=True, modules_to_not_convert=[])
|
124 |
|
125 |
# either use `untied_model_id` or `untied_model_local_path`
|
|
|
97 |
from torchao.quantization.quant_api import (
|
98 |
IntxWeightOnlyConfig,
|
99 |
Int8DynamicActivationIntxWeightConfig,
|
100 |
+
ModuleFqnToConfig,
|
101 |
quantize_,
|
102 |
)
|
103 |
from torchao.quantization.granularity import PerGroup, PerAxis
|
|
|
119 |
weight_granularity=PerGroup(32),
|
120 |
weight_scale_dtype=torch.bfloat16,
|
121 |
)
|
122 |
+
quant_config = ModuleFqnToConfig({"_default": linear_config, "model.embed_tokens": embedding_config})
|
123 |
quantization_config = TorchAoConfig(quant_type=quant_config, include_embedding=True, untie_embedding_weights=True, modules_to_not_convert=[])
|
124 |
|
125 |
# either use `untied_model_id` or `untied_model_local_path`
|