(self, path, tokenizer: BertTokenizer, max_seq_length, readin: int = 2000000, dupe_factor: int = 5, small_seq_prob: float = 0.1)
| 105 | |
| 106 | class PretrainingDataCreator: |
def __init__(self, path, tokenizer: BertTokenizer, max_seq_length, readin: int = 2000000, dupe_factor: int = 5, small_seq_prob: float = 0.1):
    """Tokenize a ``<sep>``-delimited corpus and build shuffled training instances.

    Each line of the file at ``path`` is treated as one document whose
    segments are separated by the literal string ``<sep>``. Lines with three
    or fewer segments are skipped. Every segment of the remaining lines is
    tokenized with ``tokenizer.tokenize``. The documents are then passed
    ``dupe_factor`` times through ``self.create_training_instance`` (one call
    per document index), and all collected instances are shuffled.

    Args:
        path: Path of a UTF-8 text file, one document per line.
        tokenizer: Object whose ``tokenize(str)`` is applied to every segment.
        max_seq_length: Stored as ``self.max_seq_length``; presumably read by
            ``create_training_instance`` (not visible here -- confirm).
        readin: Currently unused. NOTE(review): the name suggests a cap on
            the number of lines read, but no cap is applied. Enforcing it
            would silently truncate corpora longer than the default
            2,000,000 lines, so confirm intent before wiring it in.
        dupe_factor: Number of passes over the documents when generating
            instances.
        small_seq_prob: Stored as ``self.small_seq_prob``; presumably
            consumed by ``create_training_instance`` (confirm).

    After construction, ``self.instances`` holds the shuffled instances and
    ``self.len`` holds their count. ``self.documents`` is reset to ``None`` so
    the tokenized corpus can be released.
    """
    self.dupe_factor = dupe_factor
    self.max_seq_length = max_seq_length
    self.small_seq_prob = small_seq_prob

    documents = []
    with open(path, encoding='utf-8') as fd:
        for line in tqdm(fd):
            # The expected eventual format is (Q,T,U,S,D):
            #   query, title, url, snippet, document = line.split('\t')
            # TODO: for now the whole line is treated as the document.
            # Split once; the original code split the same string twice.
            segments = line.replace('\n', '').split("<sep>")
            # Skip documents that have too few segments.
            if len(segments) <= 3:
                continue
            # Kept documents always have >= 4 segments and so are non-empty.
            # That makes the former "drop empty documents" filter a no-op.
            documents.append([tokenizer.tokenize(seq) for seq in segments])

    # create_training_instance reads the corpus through self.documents.
    self.documents = documents
    instances = []
    for _ in range(self.dupe_factor):
        for index in range(len(self.documents)):
            instances.extend(self.create_training_instance(index))

    shuffle(instances)
    self.instances = instances
    self.len = len(self.instances)
    # Drop the reference to the tokenized corpus; only instances are needed now.
    self.documents = None
| 142 | |
def __len__(self):
    """Return the number of training instances, as cached in ``self.len``."""
    count = self.len
    return count
nothing calls this directly
no test coverage detected