(data_dir, backend='nccl', filename="databunch.pkl")
| 333 | |
@staticmethod
def load(data_dir, backend='nccl', filename="databunch.pkl"):
    """Load a pickled databunch from ``<data_dir>/tmp/<filename>``.

    Before unpickling, this makes a best-effort attempt to create a
    single-process default process group (rank 0, world size 1) on
    ``tcp://localhost:23459`` if none exists yet. Presumably this is because
    the pickled object may reference distributed state. TODO confirm against
    the code that saves the pickle.

    Args:
        data_dir: Base data directory (``str`` or ``os.PathLike``). The pickle
            is read from its ``tmp`` sub-directory.
        backend: ``torch.distributed`` backend used for the process group.
            Pass ``None`` to skip process-group initialisation entirely.
        filename: Name of the pickle file inside ``<data_dir>/tmp``.

    Returns:
        The unpickled object.

    Raises:
        Errors from opening or unpickling the file (e.g. ``FileNotFoundError``)
        propagate. A failed process-group initialisation only emits a warning.
    """
    import warnings
    from pathlib import Path

    # Only attempt initialisation when distributed support exists and no
    # default group is already set up; re-initialising is an error in torch.
    if (
        backend is not None
        and torch.distributed.is_available()
        and not torch.distributed.is_initialized()
    ):
        try:
            torch.distributed.init_process_group(
                backend=backend,
                init_method="tcp://localhost:23459",
                rank=0,
                world_size=1,
            )
        except (RuntimeError, ValueError) as err:
            # Best-effort: loading should still work without a process group
            # (e.g. NCCL unavailable, port in use), but don't hide the failure.
            warnings.warn(
                f"Could not initialise torch.distributed with backend "
                f"{backend!r}: {err}"
            )

    # SECURITY: pickle.load can execute arbitrary code; only load trusted files.
    # Path() accepts both str and Path, so callers may pass either.
    pickle_path = Path(data_dir) / 'tmp' / filename
    with open(pickle_path, "rb") as f:
        return pickle.load(f)
| 349 | |
| 350 | def __init__(self, data_dir, label_dir, tokenizer, train_file='train.csv', val_file='val.csv', test_data=None, |
| 351 | label_file='labels.csv', text_col='text', label_col='label', bs=32, maxlen=512, |
no outgoing calls
no test coverage detected