| 220 | |
| 221 | |
def test_assembler_with_identity(tmpdir_factory, real_assemblies):
    """Run ``inferenceutils.Assembler`` on detections that carry identity scores.

    Random identity predictions are attached to every non-metadata entry of the
    pickled detections. Assembly is performed first on a reduced list of edges
    and then again with ``identity_only`` switched on; the assembler is finally
    exported through ``to_h5`` and ``to_pickle`` into a temporary directory.
    """

    def count_assemblies(per_image):
        # Count by iteration so any iterable container of assemblies works.
        return sum(1 for group in per_image.values() for _ in group)

    def has_single_identity(assembly):
        # The last data column holds the identity; NaN entries are ignored.
        labels = assembly.data[:, -1]
        labels = labels[~np.isnan(labels)]
        return np.all(labels == labels[0])

    pickle_path = os.path.join(TEST_DATA_DIR, "trimouse_full.pickle")
    with open(pickle_path, "rb") as handle:
        data = pickle.load(handle)

    # Fake identity predictions: one random (n, 3) array per confidence array,
    # where n is that array's first dimension.
    for key, entry in data.items():
        if key == "metadata":
            continue
        entry["identity"] = [
            np.random.rand(scores.shape[0], 3) for scores in entry["confidence"]
        ]

    assembler = inferenceutils.Assembler(
        data,
        max_n_individuals=3,
        n_multibodyparts=12,
    )
    assert assembler._has_identity
    assert len(assembler.metadata["imnames"]) == 50
    assert assembler.n_keypoints == 12
    assert len(assembler.graph) == len(assembler.paf_inds) == 66

    # Restrict assembly to the smallest graph so the test runs quickly.
    naive_graph = [
        [0, 1], [7, 8], [6, 7], [10, 11],
        [4, 5], [5, 6], [8, 9], [9, 10],
        [0, 3], [3, 4], [0, 2],
    ]
    assembler.paf_inds = list(map(assembler.graph.index, naive_graph))
    assembler.assemble()
    assert not assembler.unique
    assert len(assembler.assemblies) == len(real_assemblies)
    assert count_assemblies(assembler.assemblies) == count_assemblies(real_assemblies)
    for group in assembler.assemblies.values():
        for assembly in group:
            assert np.all(assembly.data[:, -1] != -1)

    # Second pass relying on identity only: every assembly must then be made
    # of parts that share one and the same group ID.
    assembler.identity_only = True
    assembler.assemble()
    assert len(assembler.assemblies) == len(real_assemblies)
    consistency = [
        has_single_identity(assembly)
        for group in assembler.assemblies.values()
        for assembly in group
    ]
    assert all(consistency)

    out_dir = tmpdir_factory.mktemp("data")
    assembler.to_h5(out_dir.join("fake.h5"))
    assembler.to_pickle(out_dir.join("fake.pickle"))
| 275 | |
| 276 | |
| 277 | def test_assembler_calibration(real_assemblies): |