MCPcopy
hub / github.com/dennybritz/cnn-text-classification-tf / preprocess

Function preprocess

train.py:44–73  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

42# print("")
43
def preprocess():
    """Load the labelled corpus, vectorize it, and split it into train/dev sets.

    Reads the example files named by ``FLAGS.positive_data_file`` and
    ``FLAGS.negative_data_file`` via ``data_helpers.load_data_and_labels``,
    maps every document to a row of word ids with a ``VocabularyProcessor``
    sized to the longest document, shuffles the rows with a fixed seed, and
    holds out the last ``FLAGS.dev_sample_percentage`` fraction of them
    (truncated to a whole number of samples) as a dev set.

    Side effect: seeds NumPy's *global* RNG with 10 via ``np.random.seed``,
    so any later code drawing from ``np.random`` is seeded as well.

    Returns:
        Tuple ``(x_train, y_train, vocab_processor, x_dev, y_dev)``.

    Raises:
        ValueError: if ``FLAGS.dev_sample_percentage`` is not a fraction in
            ``[0, 1)``, or if no documents were loaded.
    """
    # Data Preparation
    # ==================================================

    # The value is multiplied by the sample count below, i.e. it is used as a
    # fraction, not a 0-100 percentage. Validate it before doing any I/O.
    dev_fraction = FLAGS.dev_sample_percentage
    if not 0.0 <= dev_fraction < 1.0:
        raise ValueError(
            "dev_sample_percentage must be a fraction in [0, 1), got {!r}".format(dev_fraction))

    # Load data
    print("Loading data...")
    x_text, y = data_helpers.load_data_and_labels(FLAGS.positive_data_file, FLAGS.negative_data_file)
    # len() rather than truthiness so this works for both lists and arrays.
    if len(x_text) == 0:
        raise ValueError("No documents were loaded; check the positive/negative data files.")

    # Build vocabulary
    # Length is measured by splitting on single spaces. This assumes the loader
    # has already normalized whitespace -- TODO confirm it agrees with the
    # VocabularyProcessor's own tokenizer.
    max_document_length = max(len(doc.split(" ")) for doc in x_text)
    vocab_processor = learn.preprocessing.VocabularyProcessor(max_document_length)
    x = np.array(list(vocab_processor.fit_transform(x_text)))

    # Randomly shuffle data
    # This seeds the global RNG (not a local RandomState), so the shuffle is
    # reproducible and later np.random draws are affected too.
    np.random.seed(10)
    shuffle_indices = np.random.permutation(np.arange(len(y)))
    x_shuffled = x[shuffle_indices]
    y_shuffled = y[shuffle_indices]

    # Split train/test set
    # TODO: This is very crude, should use cross-validation
    # Slice at an explicit non-negative boundary. A negative index such as
    # ``-int(...)`` collapses to 0 whenever the dev size truncates to zero,
    # which would make ``[:0]`` an empty train set and ``[0:]`` a dev set
    # holding every sample.
    num_samples = len(y)
    dev_size = int(dev_fraction * float(num_samples))
    train_end = num_samples - dev_size
    x_train, x_dev = x_shuffled[:train_end], x_shuffled[train_end:]
    y_train, y_dev = y_shuffled[:train_end], y_shuffled[train_end:]

    # Release this function's references to the full-size arrays.
    del x, y, x_shuffled, y_shuffled

    print("Vocabulary Size: {:d}".format(len(vocab_processor.vocabulary_)))
    print("Train/Dev split: {:d}/{:d}".format(len(y_train), len(y_dev)))
    return x_train, y_train, vocab_processor, x_dev, y_dev
74
75def train(x_train, y_train, vocab_processor, x_dev, y_dev):
76 # Training

Callers 1

main (Function) · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected