github.com/cs230-stanford/cs230-code-examples

Function input_fn

tensorflow/nlp/model/input_fn.py, lines 28–79

Input function for NER (named-entity recognition). It zips a dataset of sentences with a dataset of tags, shuffles the pairs in 'train' mode, pads them into batches of params.batch_size, and returns the iterator outputs together with the iterator's initializer op.
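
In 'train' mode input_fn shuffles with a buffer of params.buffer_size elements. In every other mode it uses a buffer of 1. The tf.data documentation describes Dataset.shuffle as filling a buffer with buffer_size elements and then sampling randomly from that buffer, replacing each selected element with the next input. The plain-Python sketch below follows that documented behavior to show why a buffer of 1 preserves order. The helper buffered_shuffle is hypothetical and is not TensorFlow's implementation.

```python
import random
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def buffered_shuffle(items: Iterable[T], buffer_size: int, seed: int = 0) -> Iterator[T]:
    """Buffer-based shuffle in the style documented for tf.data's Dataset.shuffle.

    Hypothetical helper for illustration; it is not TensorFlow's implementation.
    """
    if buffer_size < 1:
        raise ValueError("buffer_size must be at least 1")
    rng = random.Random(seed)
    buffer: List[T] = []
    for item in items:
        if len(buffer) < buffer_size:
            buffer.append(item)  # still filling the buffer
            continue
        # Buffer is full: emit a random buffered element, put the new item in its slot
        idx = rng.randrange(buffer_size)
        yield buffer[idx]
        buffer[idx] = item
    # Input exhausted: emit the remaining buffered elements in random order
    while buffer:
        yield buffer.pop(rng.randrange(len(buffer)))


# A buffer of 1 (every mode except 'train') keeps the original order
print(list(buffered_shuffle(range(5), buffer_size=1)))  # [0, 1, 2, 3, 4]
# A buffer at least as large as the data yields a uniformly random permutation
print(list(buffered_shuffle(range(5), buffer_size=100, seed=42)))
```

A buffer smaller than the dataset only mixes nearby elements. For example, the first output is always drawn from the first buffer_size inputs.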

(mode, sentences, labels, params)
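
The docstring gives params.num_epochs as an example hyperparameter, but input_fn itself reads only four attributes of params: buffer_size (used in 'train' mode), batch_size, id_pad_word and id_pad_tag. The minimal sketch below defines a stand-in with exactly those fields. InputParams and effective_buffer_size are hypothetical names, and the values are placeholders. In the repository, params is a Params object.

```python
from dataclasses import dataclass


@dataclass
class InputParams:
    """Hypothetical stand-in for the attributes input_fn reads from params."""
    buffer_size: int  # shuffle buffer size, used only when mode == 'train'
    batch_size: int   # number of sentences per padded batch
    id_pad_word: int  # word id used to right-pad sentences
    id_pad_tag: int   # tag id used to right-pad label sequences


def effective_buffer_size(mode: str, params: InputParams) -> int:
    """Same rule as input_fn: shuffle buffer only in 'train' mode, else 1."""
    return params.buffer_size if mode == 'train' else 1


# Placeholder values for illustration, not the project's settings
params = InputParams(buffer_size=1000, batch_size=5, id_pad_word=0, id_pad_tag=0)
print(effective_buffer_size('train', params))  # 1000
print(effective_buffer_size('eval', params))   # 1
```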


import tensorflow as tf  # uses the TensorFlow 1.x API (Dataset.make_initializable_iterator)


def input_fn(mode, sentences, labels, params):
    """Input function for NER

    Args:
        mode: (string) 'train', 'eval' or any other mode you can think of
                     At training, we shuffle the data (epochs are run by re-initializing the iterator)
        sentences: (tf.Dataset) yielding (list of ids of words, number of words)
        labels: (tf.Dataset) yielding (list of ids of tags, number of tags)
        params: (Params) contains hyperparameters of the model (ex: `params.num_epochs`)

    """
    # When training, shuffle with a buffer of params.buffer_size elements
    # (the whole dataset if buffer_size >= its size); a buffer of 1 keeps the original order
    is_training = (mode == 'train')
    buffer_size = params.buffer_size if is_training else 1

    # Zip the sentences and the labels together
    dataset = tf.data.Dataset.zip((sentences, labels))

    # Create batches and pad the sentences of different length
    padded_shapes = ((tf.TensorShape([None]),  # sentence of unknown size
                      tf.TensorShape([])),     # size(words)
                     (tf.TensorShape([None]),  # labels of unknown size
                      tf.TensorShape([])))     # size(tags)

    padding_values = ((params.id_pad_word,  # sentence padded on the right with id_pad_word
                       0),                  # size(words) -- unused
                      (params.id_pad_tag,   # labels padded on the right with id_pad_tag
                       0))                  # size(tags) -- unused

    dataset = (dataset
        .shuffle(buffer_size=buffer_size)
        .padded_batch(params.batch_size, padded_shapes=padded_shapes, padding_values=padding_values)
        .prefetch(1)  # make sure you always have one batch ready to serve
    )

    # Create an initializable iterator from this dataset so that we can reset it at each epoch
    iterator = dataset.make_initializable_iterator()

    # Query the output of the iterator for input to the model
    ((sentence, sentence_lengths), (labels, _)) = iterator.get_next()
    init_op = iterator.initializer

    # Build and return a dictionary containing the nodes / ops
    inputs = {
        'sentence': sentence,
        'labels': labels,
        'sentence_lengths': sentence_lengths,
        'iterator_init_op': init_op
    }

    return inputs
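
The following plain-Python sketch shows the shape of one batch. It mimics the padding that padded_batch performs with the padded_shapes and padding_values above. Each sequence in a batch is right-padded to the longest sequence in that batch, and the scalar lengths are left unpadded. The sketch also builds a boolean mask from sentence_lengths. Building such a mask is one common reason to return the lengths, because it lets a model ignore padded positions. This page does not show how the repository's model actually uses them. The helper names and the sample ids are hypothetical.

```python
from typing import Iterator, List, Sequence, Tuple

# One zipped element: ((word_ids, n_words), (tag_ids, n_tags))
Example = Tuple[Tuple[List[int], int], Tuple[List[int], int]]
Batch = Tuple[Tuple[List[List[int]], List[int]], List[List[int]]]


def pad_right(seqs: Sequence[List[int]], pad: int) -> List[List[int]]:
    """Right-pad every sequence to the length of the longest one in the batch."""
    width = max((len(s) for s in seqs), default=0)
    return [list(s) + [pad] * (width - len(s)) for s in seqs]


def padded_batches(examples: Sequence[Example], batch_size: int,
                   id_pad_word: int, id_pad_tag: int) -> Iterator[Batch]:
    """Yield ((sentence, sentence_lengths), labels) like input_fn's iterator.

    The last batch may be smaller, as padded_batch does by default.
    Tag lengths are dropped, as input_fn does.
    """
    for start in range(0, len(examples), batch_size):
        chunk = examples[start:start + batch_size]
        sentence = pad_right([w for (w, _), _ in chunk], id_pad_word)
        sentence_lengths = [n for (_, n), _ in chunk]  # lengths are not padded
        labels = pad_right([t for _, (t, _) in chunk], id_pad_tag)
        yield (sentence, sentence_lengths), labels


def sequence_mask(lengths: Sequence[int], width: int) -> List[List[bool]]:
    """True for real tokens, False for padding (the idea behind tf.sequence_mask)."""
    return [[j < n for j in range(width)] for n in lengths]


# Hypothetical ids; 9 and 3 play the roles of id_pad_word and id_pad_tag
examples = [
    (([4, 7, 2], 3), ([1, 0, 2], 3)),
    (([5], 1), ([1], 1)),
    (([8, 3], 2), ([0, 2], 2)),
]
batches = list(padded_batches(examples, batch_size=2, id_pad_word=9, id_pad_tag=3))
(sentence, sentence_lengths), labels = batches[0]
print(sentence)  # [[4, 7, 2], [5, 9, 9]]
print(labels)    # [[1, 0, 2], [1, 3, 3]]
print(sequence_mask(sentence_lengths, len(sentence[0])))
# [[True, True, True], [True, False, False]]
```

Padding each batch only to its own longest sentence, instead of to a global maximum, keeps batches of short sentences small.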

Callers (2): train.py, evaluate.py
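
The dataset is never repeated. Instead, the code comments say the iterator is initializable "so that we can reset at each epoch". A caller such as train.py would therefore run inputs['iterator_init_op'] before each epoch and then fetch batches until the input is exhausted. TensorFlow 1.x signals exhaustion by raising tf.errors.OutOfRangeError. The plain-Python sketch below simulates that contract. InitializableIterator and run_epochs are hypothetical, and StopIteration stands in for OutOfRangeError.

```python
from typing import Callable, Generic, Iterable, Iterator, List, Optional, TypeVar

T = TypeVar("T")


class InitializableIterator(Generic[T]):
    """Hypothetical plain-Python stand-in for a TF1 initializable iterator."""

    def __init__(self, make_dataset: Callable[[], Iterable[T]]) -> None:
        self._make_dataset = make_dataset
        self._it: Optional[Iterator[T]] = None

    def initializer(self) -> None:
        """Restart at the first batch (the role of inputs['iterator_init_op'])."""
        self._it = iter(self._make_dataset())

    def get_next(self) -> T:
        """Return the next batch; StopIteration signals the end of the epoch."""
        if self._it is None:
            raise RuntimeError("run initializer() before get_next()")
        return next(self._it)


def run_epochs(iterator: InitializableIterator[T], num_epochs: int) -> List[List[T]]:
    """Re-initialize at the start of every epoch, then drain the iterator."""
    history: List[List[T]] = []
    for _ in range(num_epochs):
        iterator.initializer()
        epoch: List[T] = []
        while True:
            try:
                epoch.append(iterator.get_next())
            except StopIteration:  # stands in for tf.errors.OutOfRangeError
                break
        history.append(epoch)
    return history


batches = ["batch-0", "batch-1", "batch-2"]
it = InitializableIterator(lambda: batches)
print(run_epochs(it, num_epochs=2))
# [['batch-0', 'batch-1', 'batch-2'], ['batch-0', 'batch-1', 'batch-2']]
```

One advantage of re-initializing over repeating the dataset is that every epoch has a well-defined end, so the caller can, for example, evaluate between epochs.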

Calls: no outgoing calls

Tested by: no test coverage detected