github.com/haitongli/knowledge-distillation-pytorch @v1.0 sqlite

137 symbols · 317 edges · 19 files · 46 documented (34%)
README

knowledge-distillation-pytorch

  • Exploring knowledge distillation of DNNs for efficient hardware solutions
  • Author: Haitong Li
  • Framework: PyTorch
  • Dataset: CIFAR-10

Features

  • A framework for exploring "shallow" and "deep" knowledge distillation (KD) experiments
  • Hyperparameters defined universally in "params.json" files (avoiding long argparse commands)
  • Hyperparameter search and result synthesis (as a table)
  • Progress bar, TensorBoard support, and checkpoint saving/loading (utils.py)
  • Pretrained teacher models available for download
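Because every run is configured by a params.json file rather than command-line flags, the only plumbing needed is a small object that exposes the JSON keys as attributes. The sketch below shows one way to do that; the class name Params and the keys used in the example ("learning_rate", "alpha", "temperature") are assumptions for illustration, and the repo's utils.py may implement this differently.

```python
import json


class Params:
    """Hypothetical hyperparameter container loaded from a params.json file.

    Every top-level JSON key becomes an attribute, e.g. params.alpha.
    """

    def __init__(self, json_path):
        with open(json_path) as f:
            self.__dict__.update(json.load(f))

    def save(self, json_path):
        # Write the current hyperparameters back to disk, pretty-printed.
        with open(json_path, "w") as f:
            json.dump(self.__dict__, f, indent=4)
```

Usage, assuming the experiment directory layout shown later in this README: params = Params("experiments/cnn_distill/params.json"), then read params.alpha or params.temperature wherever the training loop needs them.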

Install

  • Clone the repo:
    git clone https://github.com/peterliht/knowledge-distillation-pytorch.git

  • Install the dependencies (including PyTorch):
    pip install -r requirements.txt

Organization:

  • ./train.py: main entry point for training/evaluation, with or without KD, on CIFAR-10
  • ./experiments/: JSON files for each experiment; directories for hyperparameter search
  • ./model/: teacher and student DNNs, knowledge distillation (KD) loss definition, dataloader
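The KD loss combines an ordinary cross-entropy term on the hard labels with a term that pulls the student's temperature-softened output distribution toward the teacher's, following Hinton et al. (2015). The framework-free sketch below makes that arithmetic explicit for a single example; the weighting by alpha and the T² factor are the common formulation, not necessarily the exact one in model/net.py. In PyTorch the soft-target term is usually computed with torch.nn.functional.kl_div applied to log_softmax outputs.

```python
import math


def log_softmax(logits, temperature=1.0):
    """Numerically stable log-softmax of logits divided by a temperature."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    log_sum = m + math.log(sum(math.exp(z - m) for z in scaled))
    return [z - log_sum for z in scaled]


def kd_loss(student_logits, teacher_logits, label, alpha, temperature):
    """Hinton-style KD loss for one example (sketch, not the repo's code).

    (1 - alpha) * CE(label, student) + alpha * T^2 * KL(teacher_T || student_T)
    """
    # Hard-label cross-entropy at temperature 1.
    ce = -log_softmax(student_logits)[label]
    # Soft targets: both distributions are softened by the same temperature.
    log_p_t = log_softmax(teacher_logits, temperature)
    log_p_s = log_softmax(student_logits, temperature)
    kl = sum(math.exp(lt) * (lt - ls) for lt, ls in zip(log_p_t, log_p_s))
    # The T^2 factor keeps soft-target gradients on a comparable scale
    # to the hard-label term as the temperature changes.
    return (1.0 - alpha) * ce + alpha * temperature ** 2 * kl
```

With alpha = 0 the loss reduces to plain cross-entropy (training without KD); with alpha = 1 it depends only on matching the teacher's softened outputs.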

Key notes about usage for your experiments:

  • Download the zip file for pretrained teacher model checkpoints from this Box folder
  • Simply move the unzipped subfolders into 'knowledge-distillation-pytorch/experiments/' (replacing the existing ones if necessary; follow the default path naming)
  • Call train.py to train a 5-layer CNN with dark knowledge distilled from ResNet-18, or to train ResNet-18 with knowledge distilled from deeper state-of-the-art models
  • Use search_hyperparams.py for hyperparameter search
  • Hyperparameters are defined in params.json files universally. Refer to the header of search_hyperparams.py for details
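A hyperparameter search over KD settings amounts to writing one params.json per combination into its own subdirectory and training each one. The sketch below builds such a grid; the function name make_search_dirs, the directory naming scheme, and the swept keys are hypothetical, and search_hyperparams.py may organize its runs differently (its header documents the actual behavior).

```python
import itertools
import json
import os


def make_search_dirs(parent_dir, base_params, grid):
    """Hypothetical grid search layout: one params.json per combination.

    grid maps a hyperparameter name to the list of values to sweep.
    Returns the created job directories in a deterministic order.
    """
    names = sorted(grid)
    job_dirs = []
    for values in itertools.product(*(grid[n] for n in names)):
        params = dict(base_params)
        params.update(zip(names, values))
        # Directory name encodes the swept values, e.g. "alpha_0.9_temperature_20".
        job_name = "_".join(f"{n}_{v}" for n, v in zip(names, values))
        job_dir = os.path.join(parent_dir, job_name)
        os.makedirs(job_dir, exist_ok=True)
        with open(os.path.join(job_dir, "params.json"), "w") as f:
            json.dump(params, f, indent=4)
        job_dirs.append(job_dir)
    return job_dirs
```

Each returned directory can then be passed to the documented entry point, e.g. python train.py --model_dir <job_dir>.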

Train (dataset: CIFAR-10)

Note: all the hyperparameters can be found and modified in 'params.json' under 'model_dir'

-- Train a 5-layer CNN with knowledge distilled from a pre-trained ResNet-18 model

python train.py --model_dir experiments/cnn_distill

-- Train a ResNet-18 model with knowledge distilled from a pre-trained ResNeXt-29 teacher

python train.py --model_dir experiments/resnet18_distill/resnext_teacher

-- Hyperparameter search for a specified experiment ('parent_dir/params.json')

python search_hyperparams.py --parent_dir experiments/cnn_distill_alpha_temp

-- Synthesize the results of the most recent hyperparameter search experiments

python synthesize_results.py --parent_dir experiments/cnn_distill_alpha_temp
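Synthesizing results means walking the search directory, reading each run's saved metrics, and laying them out as one table. The sketch below does this with the standard library only (the dependency list includes tabulate, which a real script could use instead). The metrics file name "metrics_val_best_weights.json" and both function names are assumptions; adjust them to whatever train.py actually writes into each model_dir.

```python
import json
import os


def collect_metrics(parent_dir, metrics_file="metrics_val_best_weights.json"):
    """Gather {run_name: {metric: value}} from every run under parent_dir.

    The metrics file name is an assumption, not a confirmed repo convention.
    """
    results = {}
    for root, _dirs, files in os.walk(parent_dir):
        if metrics_file in files:
            with open(os.path.join(root, metrics_file)) as f:
                results[os.path.relpath(root, parent_dir)] = json.load(f)
    return results


def to_table(results):
    """Render the collected metrics as a pipe-separated plain-text table."""
    metrics = sorted({m for r in results.values() for m in r})
    header = ["run"] + metrics
    rows = [
        [run] + [f"{results[run].get(m, float('nan')):.4f}" for m in metrics]
        for run in sorted(results)
    ]
    lines = [" | ".join(header), " | ".join("---" for _ in header)]
    lines += [" | ".join(row) for row in rows]
    return "\n".join(lines)
```

Usage: print(to_table(collect_metrics("experiments/cnn_distill_alpha_temp"))).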

Results

Quick takeaways (more details to be added):

  • Knowledge distillation provides regularization for both shallow DNNs and state-of-the-art DNNs
  • KD can also help when training on an unlabeled dataset or on only a small amount of data
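One intuition for the regularization effect (Hinton et al., 2015): raising the temperature flattens the teacher's output distribution so the non-target classes, the "dark knowledge", carry visible probability mass that the student can learn from. The snippet below illustrates this with made-up teacher logits; it does not reproduce any result from this repo.

```python
import math


def softmax(logits, temperature=1.0):
    """Softmax of logits / temperature (max-subtracted for stability)."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def entropy(probs):
    """Shannon entropy in nats; higher means a softer distribution."""
    return -sum(p * math.log(p) for p in probs if p > 0)


teacher_logits = [8.0, 2.0, 1.0]  # made-up logits for illustration only
for t in (1.0, 4.0, 20.0):
    p = softmax(teacher_logits, t)
    # The predicted class never changes, but the runner-up classes gain mass.
    print(t, [round(x, 3) for x in p], round(entropy(p), 3))
```

The top class stays the same at every temperature; only the relative information about the other classes becomes more pronounced.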

References

Hinton, G., Vinyals, O., & Dean, J. (2015). Distilling the knowledge in a neural network. arXiv preprint arXiv:1503.02531.

Romero, A., Ballas, N., Kahou, S. E., Chassang, A., Gatta, C., & Bengio, Y. (2014). FitNets: Hints for thin deep nets. arXiv preprint arXiv:1412.6550.

https://github.com/cs230-stanford/cs230-stanford.github.io

https://github.com/bearpaw/pytorch-classification

Core symbols most depended-on inside this repo

save · called by 7 · utils.py
forward · called by 7 · model/resnext.py
update · called by 4 · utils.py
loss_fn · called by 4 · model/net.py
_make_layer · called by 4 · model/resnet.py
block · called by 3 · model/resnext.py
_make_denseblock · called by 3 · model/densenet.py
_make_layer · called by 3 · model/preresnet.py

Shape

Method 61
Function 51
Class 25

Languages

Python 100%

Modules by API surface

utils.py · 18 symbols
model/resnet.py · 17 symbols
model/densenet.py · 16 symbols
model/preresnet.py · 13 symbols
model/wrn.py · 12 symbols
mnist/distill_mnist_unlabeled.py · 10 symbols
mnist/distill_mnist.py · 10 symbols
model/resnext.py · 9 symbols
model/net.py · 6 symbols
mnist/teacher_mnist.py · 6 symbols
mnist/student_mnist.py · 6 symbols
train.py · 5 symbols

Dependencies from manifests, versioned

Pillow 5.0.0 · 1×
numpy 1.14.0 · 1×
scipy 1.0.0 · 1×
tabulate 0.8.2 · 1×
tensorflow 1.7.0rc0 · 1×
torch 0.3.0.post4 · 1×
torchvision 0.2.0 · 1×
tqdm 4.19.8 · 1×

For agents

$ claude mcp add knowledge-distillation-pytorch \
  -- python -m otcore.mcp_server <graph>
