(self)
| 15 | class TestDataset(TestAutoData): |
| 16 | @pytest.mark.slow |
| 17 | def testTSDataset(self): |
| 18 | tsdh = TSDatasetH( |
| 19 | handler={ |
| 20 | "class": "Alpha158", |
| 21 | "module_path": "qlib.contrib.data.handler", |
| 22 | "kwargs": { |
| 23 | "start_time": "2017-01-01", |
| 24 | "end_time": "2020-08-01", |
| 25 | "fit_start_time": "2017-01-01", |
| 26 | "fit_end_time": "2017-12-31", |
| 27 | "instruments": "csi300", |
| 28 | "infer_processors": [ |
| 29 | {"class": "FilterCol", "kwargs": {"col_list": ["RESI5", "WVMA5", "RSQR5"]}}, |
| 30 | {"class": "RobustZScoreNorm", "kwargs": {"fields_group": "feature", "clip_outlier": "true"}}, |
| 31 | {"class": "Fillna", "kwargs": {"fields_group": "feature"}}, |
| 32 | ], |
| 33 | "learn_processors": [ |
| 34 | "DropnaLabel", |
| 35 | {"class": "CSRankNorm", "kwargs": {"fields_group": "label"}}, # CSRankNorm |
| 36 | ], |
| 37 | }, |
| 38 | }, |
| 39 | segments={ |
| 40 | "train": ("2017-01-01", "2017-12-31"), |
| 41 | "valid": ("2018-01-01", "2018-12-31"), |
| 42 | "test": ("2019-01-01", "2020-08-01"), |
| 43 | }, |
| 44 | ) |
| 45 | tsds_train = tsdh.prepare("train", data_key=DataHandlerLP.DK_L) # Test the correctness |
| 46 | tsds = tsdh.prepare("valid", data_key=DataHandlerLP.DK_L) |
| 47 | |
| 48 | t = time.time() |
| 49 | for idx in np.random.randint(0, len(tsds_train), size=2000): |
| 50 | _ = tsds_train[idx] |
| 51 | print(f"2000 sample takes {time.time() - t}s") |
| 52 | |
| 53 | t = time.time() |
| 54 | for _ in range(20): |
| 55 | data = tsds_train[np.random.randint(0, len(tsds_train), size=2000)] |
| 56 | print(data.shape) |
| 57 | print(f"2000 sample(batch index) * 20 times takes {time.time() - t}s") |
| 58 | |
| 59 | # The dimension of a sample is the same as in the tabular data, but indexing returns the time-series data of that sample |
| 60 | |
| 61 | # There are two ways to get the time-series of a sample |
| 62 | |
| 63 | # 1) sample by int index directly |
| 64 | tsds[len(tsds) - 1] |
| 65 | |
| 66 | # 2) sample by <datetime,instrument> index |
| 67 | data_from_ds = tsds["2017-12-31", "SZ300315"] |
| 68 | |
| 69 | # Check the data |
| 70 | # Get data from DataFrame Directly |
| 71 | data_from_df = ( |
| 72 | tsdh.handler.fetch(data_key=DataHandlerLP.DK_L) |
| 73 | .loc(axis=0)["2017-01-01":"2017-12-31", "SZ300315"] |
| 74 | .iloc[-30:] |
nothing calls this directly
no test coverage detected