Distillation gives very low accuracy, what's the potential reason?

I have an MAE model, and I hope to make it much faster while keeping similar accuracy.

I have already pruned the model, which made it a bit faster at the cost of slightly lower accuracy. I am now trying to use distillation to recover the accuracy, following the guide (Distillation — Model Optimizer 0.15.0): the teacher model is my original model (architecture and weights), and the student model is the pruned model (only the pruned architecture, with random weights). I can get a distilled model, but its accuracy is much lower than that of my pruned model.

What is the potential reason for this? I hope someone can give me advice on it. Thank you.
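For reference, what I am trying to achieve corresponds to plain logits distillation. Below is a minimal sketch in ordinary PyTorch, independent of Model Optimizer's actual API; `teacher_model`, `student_model`, and `train_loader` are placeholders for my own objects, and the temperature and loss weighting are assumed values, not anything prescribed by the guide:

```python
import torch
import torch.nn.functional as F

# Placeholders: substitute your own models and data loader.
# teacher_model: original architecture + trained weights (kept frozen)
# student_model: pruned architecture (the model being trained)
teacher_model.eval()                      # teacher must stay in eval mode
for p in teacher_model.parameters():
    p.requires_grad_(False)               # no gradients through the teacher

optimizer = torch.optim.AdamW(student_model.parameters(), lr=1e-4)
T = 4.0        # softmax temperature (assumed value)
alpha = 0.9    # weight of the KD term vs. the hard-label loss (assumed)

for images, labels in train_loader:
    with torch.no_grad():
        teacher_logits = teacher_model(images)
    student_logits = student_model(images)

    # Soft-target KL divergence, scaled by T^2 as in Hinton et al.
    kd_loss = F.kl_div(
        F.log_softmax(student_logits / T, dim=1),
        F.softmax(teacher_logits / T, dim=1),
        reduction="batchmean",
    ) * (T * T)
    ce_loss = F.cross_entropy(student_logits, labels)
    loss = alpha * kd_loss + (1 - alpha) * ce_loss

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```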

Hi @relaxtheo ,
Can you help us with the performance metrics,
and the model/scripts/logs?

I can reproduce this problem with a ResNet model. I am new to distillation, so I may be making mistakes in how I use it.

The model is based on resnet152 and trained on the hymenoptera dataset (https://download.pytorch.org/tutorial/hymenoptera_data.zip).
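For context, the teacher follows the usual PyTorch transfer-learning recipe for this dataset. The sketch below is only illustrative (paths, batch size, and the weights argument are my assumptions; the actual code is in the attached scripts):

```python
import torch
import torch.nn as nn
from torchvision import datasets, models, transforms

# Standard ImageNet preprocessing for a resnet backbone.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# hymenoptera_data has train/ and val/ folders with two classes (ants, bees).
train_set = datasets.ImageFolder("hymenoptera_data/train", transform)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=32, shuffle=True)

# resnet152 backbone with the classifier head replaced for 2 classes.
model = models.resnet152(weights=models.ResNet152_Weights.IMAGENET1K_V1)
model.fc = nn.Linear(model.fc.in_features, 2)
```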

t2d.py is for distillation; t2dt.py is for testing.

t2d.py.txt (1.8 KB)
t2dt.py.txt (4.7 KB)

The accuracy of the trained model is 0.9477, while the distilled model reaches only 0.5425.
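For completeness, the accuracies above come from an ordinary top-1 evaluation over the validation split, along the lines of this sketch (`model` and `val_loader` are placeholders for my own objects):

```python
import torch

@torch.no_grad()
def evaluate(model, val_loader):
    """Top-1 accuracy over a validation loader (placeholder objects)."""
    model.eval()
    correct, total = 0, 0
    for images, labels in val_loader:
        preds = model(images).argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.numel()
    return correct / total
```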

The log is below:

(pruning) user@server/data/WS/PruningWS/code$ cd /data/WS/PruningWS/code ; /usr/bin/env /data/WS/anaconda3/envs/pruning/bin/python /home/jzyq/.vscode/extensions/ms-python.debugpy-2024.10.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher 34285 -- /data/WS/PruningWS/code/t2dt.py --config config/mae1-prune3.25.yaml --gpus 0 --run ../run/ --cycle 11
/data/WS/anaconda3/envs/pruning/lib/python3.11/site-packages/modelopt/torch/quantization/tensor_quant.py:92: FutureWarning: torch.library.impl_abstract was renamed to torch.library.register_fake. Please use that instead; we will remove torch.library.impl_abstract in a future version of PyTorch.
  scaled_e4m3_abstract = torch.library.impl_abstract("trt::quantize_fp8")(
TRAINING 1...
VALIDATING 2...
Loss: 0.1485 Acc: 0.9477
ORIGINAL MODEL: 4.07870, 3.39219
/data/WS/anaconda3/envs/pruning/lib/python3.11/tempfile.py:934: ResourceWarning: Implicitly cleaning up <TemporaryDirectory '/tmp/tmpkde989bv'>
  _warnings.warn(warn_message, ResourceWarning)
(pruning) user@server/data/WS/PruningWS/code$ ^C
(pruning) user@server/data/WS/PruningWS/code$ cd /data/WS/PruningWS/code ; /usr/bin/env /data/WS/anaconda3/envs/pruning/bin/python /home/jzyq/.vscode/extensions/ms-python.debugpy-2024.10.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher 43897 -- /data/WS/PruningWS/code/t2dt.py --config config/mae1-prune3.25.yaml --gpus 0 --run ../run/ --cycle 11
/data/WS/anaconda3/envs/pruning/lib/python3.11/site-packages/modelopt/torch/quantization/tensor_quant.py:92: FutureWarning: torch.library.impl_abstract was renamed to torch.library.register_fake. Please use that instead; we will remove torch.library.impl_abstract in a future version of PyTorch.
  scaled_e4m3_abstract = torch.library.impl_abstract("trt::quantize_fp8")(
TRAINING 1...
VALIDATING 2...
Loss: 5.8928 Acc: 0.5425
DISTILL MODEL: 3.14478, 2.89372
/data/WS/anaconda3/envs/pruning/lib/python3.11/tempfile.py:934: ResourceWarning: Implicitly cleaning up <TemporaryDirectory '/tmp/tmpmuv5fgua'>
  _warnings.warn(warn_message, ResourceWarning)