| 19 | |
| 20 | |
class DataCutter(object):
    """Split a per-user CSV log into a train file and a test file, by user.

    Each input line with exactly five comma-separated fields is grouped
    under its first field (treated as the user id). User ids are shuffled
    with ``random.shuffle``; the first ``number`` of them go to the test
    file and all remaining users go to the train file, so every line of a
    given user ends up in the same output file.
    """

    def __init__(self, inp, train, test, number):
        """
        :param inp: path of the input CSV file.
        :param train: path the train split is written to (overwritten).
        :param test: path the test split is written to (overwritten).
        :param number: number of users placed in the test split; used as a
            slice bound, so it is assumed to be non-negative.
        """
        self._input = inp
        self._train = train
        self._test = test
        self._number = number

    def cut(self):
        """Read the input, shuffle the users, and write both split files.

        The train file is written first, then the test file.
        """
        user_ids, user_behav = self._load()

        random.shuffle(user_ids)
        test_user_ids = user_ids[:self._number]
        train_user_ids = user_ids[self._number:]

        self._write(self._train, train_user_ids, user_behav)
        self._write(self._test, test_user_ids, user_behav)

    def _load(self):
        """Group well-formed input lines by user id.

        :returns: ``(user_ids, user_behav)`` where ``user_ids`` lists ids in
            first-seen order and ``user_behav`` maps each id to its raw
            lines (each guaranteed to end with a newline).
        """
        user_behav = {}
        user_ids = []
        with open(self._input) as f:
            for line in f:
                arr = line.strip().split(',')
                if len(arr) != 5:
                    # Skip malformed or blank lines. Stopping here (the old
                    # ``break``) silently dropped every later record.
                    continue

                uid = arr[0]
                if uid not in user_behav:
                    user_ids.append(uid)
                    user_behav[uid] = []

                if not line.endswith('\n'):
                    # The last line of a file may lack a newline; after the
                    # shuffle it can be written mid-file and would fuse with
                    # the next record, so terminate it explicitly.
                    line += '\n'
                user_behav[uid].append(line)
        return user_ids, user_behav

    @staticmethod
    def _write(path, uids, user_behav):
        """Write every stored line of each user in ``uids`` to ``path``."""
        with open(path, 'w') as f:
            for uid in uids:
                f.writelines(user_behav[uid])
| 57 | |
| 58 | |
| 59 | if __name__ == '__main__': |