(model_name, img_dim, nb_classes)
| 90 | |
| 91 | |
def load(model_name, img_dim, nb_classes):
    """Instantiate the model class selected by ``model_name``.

    Args:
        model_name: One of ``"CNN"``, ``"Big_CNN"`` or ``"FCN"``. It picks the
            class and is also passed on as the ``model_name`` keyword.
        img_dim: Passed unchanged as the first positional argument of the
            selected model constructor.
        nb_classes: Passed unchanged as the second positional argument of
            the selected model constructor.

    Returns:
        The object returned by ``CNN(...)``, ``Big_CNN(...)`` or ``FCN(...)``.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
            Previously an unknown name fell through every branch and failed
            with ``UnboundLocalError`` on ``return model``.
    """
    # A single if/elif chain ensures exactly one branch is tested to completion.
    # The original used a second bare `if` for "Big_CNN", which split the chain.
    # The constructor names are only looked up inside their own branch, so
    # rejecting an unknown name never touches them.
    if model_name == "CNN":
        model = CNN(img_dim, nb_classes, model_name=model_name)
    elif model_name == "Big_CNN":
        model = Big_CNN(img_dim, nb_classes, model_name=model_name)
    elif model_name == "FCN":
        model = FCN(img_dim, nb_classes, model_name=model_name)
    else:
        raise ValueError(
            f"Unknown model_name {model_name!r}; "
            "expected one of 'CNN', 'Big_CNN', 'FCN'"
        )

    return model