Loads the dataset.

:type dataset: string
:param dataset: the path to the dataset (here MNIST)
(dataset)
| 173 | |
| 174 | |
| 175 | def load_data(dataset): |
| 176 | ''' Loads the dataset |
| 177 | |
| 178 | :type dataset: string |
| 179 | :param dataset: the path to the dataset (here MNIST) |
| 180 | ''' |
| 181 | |
| 182 | ############# |
| 183 | # LOAD DATA # |
| 184 | ############# |
| 185 | |
| 186 | # Download the MNIST dataset if it is not present |
| 187 | data_dir, data_file = os.path.split(dataset) |
| 188 | if data_dir == "" and not os.path.isfile(dataset): |
| 189 | # Check if dataset is in the data directory. |
| 190 | new_path = os.path.join( |
| 191 | os.path.split(__file__)[0], |
| 192 | "..", |
| 193 | "data", |
| 194 | dataset |
| 195 | ) |
| 196 | if os.path.isfile(new_path) or data_file == 'mnist.pkl.gz': |
| 197 | dataset = new_path |
| 198 | |
| 199 | if (not os.path.isfile(dataset)) and data_file == 'mnist.pkl.gz': |
| 200 | from six.moves import urllib |
| 201 | origin = ( |
| 202 | 'http://www.iro.umontreal.ca/~lisa/deep/data/mnist/mnist.pkl.gz' |
| 203 | ) |
| 204 | print('Downloading data from %s' % origin) |
| 205 | urllib.request.urlretrieve(origin, dataset) |
| 206 | |
| 207 | print('... loading data') |
| 208 | |
| 209 | # Load the dataset |
| 210 | with gzip.open(dataset, 'rb') as f: |
| 211 | try: |
| 212 | train_set, valid_set, test_set = pickle.load(f, encoding='latin1') |
| 213 | except: |
| 214 | train_set, valid_set, test_set = pickle.load(f) |
| 215 | # train_set, valid_set, test_set format: tuple(input, target) |
| 216 | # input is a numpy.ndarray of 2 dimensions (a matrix) |
| 217 | # where each row corresponds to an example. target is a |
| 218 | # numpy.ndarray of 1 dimension (vector) that has the same length as |
| 219 | # the number of rows in the input. Each entry gives the target |
| 220 | # for the example with the same index in the input. |
| 221 | |
| 222 | def shared_dataset(data_xy, borrow=True): |
| 223 | """ Function that loads the dataset into shared variables |
| 224 | |
| 225 | The reason we store our dataset in shared variables is to allow |
| 226 | Theano to copy it into the GPU memory (when code is run on GPU). |
| 227 | Since copying data into the GPU is slow, copying a minibatch every time |
| 228 | one is needed (the default behaviour if the data is not in a shared |
| 229 | variable) would lead to a large decrease in performance. |
| 230 | """ |
| 231 | data_x, data_y = data_xy |
| 232 | shared_x = theano.shared(numpy.asarray(data_x, |