| 1190 | |
| 1191 | # Test prior collection using (UniformArityPrior, HardLengthPrior) |
| 1192 | def test_PriorCollection(self): |
| 1193 | |
| 1194 | # ------- TEST CASE ------- |
| 1195 | # Library |
| 1196 | args_make_tokens = { |
| 1197 | # operations |
| 1198 | "op_names" : ["mul", "add", "neg", "inv", "cos"], |
| 1199 | "use_protected_ops" : False, |
| 1200 | # input variables |
| 1201 | "input_var_ids" : {"x" : 0 , "v" : 1 , "t" : 2, }, |
| 1202 | "input_var_units" : {"x" : [1, 0, 0] , "v" : [1, -1, 0] , "t" : [0, 1, 0] }, |
| 1203 | "input_var_complexity" : {"x" : 0. , "v" : 1. , "t" : 0., }, |
| 1204 | # constants |
| 1205 | "constants" : {"pi" : np.pi , "c" : 3e8 , "M" : 1e6 }, |
| 1206 | "constants_units" : {"pi" : [0, 0, 0] , "c" : [1, -1, 0], "M" : [0, 0, 1] }, |
| 1207 | "constants_complexity" : {"pi" : 0. , "c" : 0. , "M" : 1. }, |
| 1208 | } |
| 1209 | my_lib = Lib.Library(args_make_tokens = args_make_tokens, |
| 1210 | superparent_units = [1, -2, 1], superparent_name = "y") |
| 1211 | # Programs test case |
| 1212 | max_length = 5 |
| 1213 | min_length = 3 |
| 1214 | test_case_str = np.array([ |
| 1215 | # 0 1 2 3 4 |
| 1216 | ["add", "cos", "x" , "cos", "c" ], # -> should begin enforcing arity == 0 tokens at pos = 3 |
| 1217 | # -> should begin enforcing arity <= 1 tokens at pos = 1 |
| 1218 | # -> should enforce arity >= 1 tokens until pos = 0 |
| 1219 | |
| 1220 | ["cos", "cos", "cos", "cos", "x" ], # -> should begin enforcing arity == 0 tokens at pos = 3 |
| 1221 | # -> should begin enforcing arity <= 1 tokens at pos = 2 |
| 1222 | # -> should enforce arity >= 1 tokens until pos = 1 |
| 1223 | |
| 1224 | ["add", "add", "x" , "pi" , "c" ], # -> should begin enforcing arity == 0 tokens at pos = 1 |
| 1225 | # -> should begin enforcing arity <= 1 tokens at pos = 1 |
| 1226 | # -> should enforce arity >= 1 tokens until pos = 0 |
| 1227 | |
| 1228 | ["add", "x" , "c" , "-" , "-" ], # -> should begin enforcing arity == 0 tokens at pos = inf |
| 1229 | # -> should begin enforcing arity <= 1 tokens at pos = inf |
| 1230 | # -> should enforce arity >= 1 tokens until pos = 0 |
| 1231 | |
| 1232 | ["add", "cos", "x" , "c" , "-" ], # -> should begin enforcing arity == 0 tokens at pos = inf |
| 1233 | # -> should begin enforcing arity <= 1 tokens at pos = 1 |
| 1234 | # -> should enforce arity >= 1 tokens until pos = 0 |
| 1235 | |
| 1236 | ["cos", "add", "x" , "v" , "-" ], # -> should begin enforcing arity == 0 tokens at pos = inf |
| 1237 | # -> should begin enforcing arity <= 1 tokens at pos = 1 |
| 1238 | # -> should enforce arity >= 1 tokens until pos = 1 |
| 1239 | ]) |
| 1240 | pos_begin_max_arity_is_0 = np.array([3, 3, 1, np.inf, np.inf, np.inf]) |
| 1241 | pos_begin_max_arity_is_1 = np.array([1, 2, 1, np.inf, 1 , 1 ]) |
| 1242 | pos_end_min_arity_is_1 = np.array([0, 1, 0, 0 , 0 , 1 ]) |
| 1243 | |
| 1244 | # Using a valid placeholder existing in the library that will be ignored anyway instead of '-' |
| 1245 | test_case_str = np.where(test_case_str == "-", "x", test_case_str) |
| 1246 | |
| 1247 | # Creating idx that will be appended |
| 1248 | n_progs, n_steps = test_case_str.shape |
| 1249 | test_case = np.zeros((n_progs, n_steps)).astype(int) |