Forward pass through GRU layer
(tparams, state_below, options, prefix='gru', mask=None, **kwargs)
| 366 | |
| 367 | |
def gru_layer(tparams, state_below, options, prefix='gru', mask=None, **kwargs):
    """
    Forward pass through GRU layer.

    Runs a gated recurrent unit over the leading (time) axis of
    ``state_below`` with ``theano.scan``.

    :param tparams: mapping from parameter names (built with
        ``_p(prefix, ...)``) to Theano variables; this function reads
        ``W``, ``b``, ``Wx``, ``bx``, ``U`` and ``Ux``. The hidden size
        ``dim`` is taken from ``Ux.shape[1]``.
    :param state_below: input sequence. Axis 0 is iterated as time steps;
        if the tensor is 3-D, axis 1 is treated as the sample axis,
        otherwise a single sample is assumed.
    :param options: not used in this function body.
    :param prefix: name prefix used to look up parameters and to name the
        scan loop.
    :param mask: per-step mask; where it is 0 the previous hidden state is
        carried forward unchanged. Defaults to all ones of shape
        ``(nsteps, 1)``. Presumably ``(nsteps, n_samples)`` when supplied --
        TODO confirm against callers.
    :param kwargs: ignored.
    :returns: a one-element list holding the scan output, i.e. each step's
        ``(n_samples, dim)`` hidden state stacked along a leading time axis.
    """
    nsteps = state_below.shape[0]
    if state_below.ndim == 3:
        n_samples = state_below.shape[1]
    else:
        n_samples = 1

    dim = tparams[_p(prefix, 'Ux')].shape[1]

    # Identity test against the sentinel: never an equality comparison on a
    # tensor/array argument.
    if mask is None:
        mask = tensor.alloc(1., state_below.shape[0], 1)

    def _slice(_x, n, dim):
        """Return the n-th block of width `dim` along the last axis of `_x`."""
        if _x.ndim == 3:
            return _x[:, :, n*dim:(n+1)*dim]
        return _x[:, n*dim:(n+1)*dim]

    # Input projections for all time steps are computed once, outside scan.
    # state_below_ feeds the two gates; state_belowx feeds the candidate state.
    state_below_ = tensor.dot(state_below, tparams[_p(prefix, 'W')]) + \
        tparams[_p(prefix, 'b')]
    state_belowx = tensor.dot(state_below, tparams[_p(prefix, 'Wx')]) + \
        tparams[_p(prefix, 'bx')]
    U = tparams[_p(prefix, 'U')]
    Ux = tparams[_p(prefix, 'Ux')]

    # Argument order follows scan's convention:
    #   sequences (m_, x_, xx_), previous output (h_), non_sequences (U, Ux).
    # The step must use the scan-supplied U/Ux, not the enclosing locals,
    # because strict=True requires shared variables to be passed explicitly.
    def _step_slice(m_, x_, xx_, h_, U, Ux):
        preact = tensor.dot(h_, U)
        preact += x_

        # reset (r) and update (u) gates: consecutive width-`dim` blocks
        r = tensor.nnet.sigmoid(_slice(preact, 0, dim))
        u = tensor.nnet.sigmoid(_slice(preact, 1, dim))

        # candidate hidden state; the reset gate scales the recurrent term
        preactx = tensor.dot(h_, Ux)
        preactx = preactx * r
        preactx = preactx + xx_
        h = tensor.tanh(preactx)

        # update gate interpolates between the previous state and the candidate
        h = u * h_ + (1. - u) * h
        # where m_ == 0 the previous state is kept unchanged
        h = m_[:, None] * h + (1. - m_)[:, None] * h_

        return h

    seqs = [mask, state_below_, state_belowx]
    init_states = [tensor.alloc(0., n_samples, dim)]
    shared_vars = [U, Ux]

    # `profile` is a name defined outside this function (not visible here).
    rval, _ = theano.scan(_step_slice,
                          sequences=seqs,
                          outputs_info=init_states,
                          non_sequences=shared_vars,
                          name=_p(prefix, '_layers'),
                          n_steps=nsteps,
                          profile=profile,
                          strict=True)
    rval = [rval]
    return rval