(self, sample)
| 239 | self._train_dataloader.num_samples = num_samples |
| 240 | |
def load_external_caption(self, sample):
    """Overwrite ``sample['txt']`` with a caption loaded from an external file.

    The caption file is chosen by the sample's origin:

    * ``'SA1B'`` in ``sample['__key__']``:
      ``{self.external_caption_path}/<last '/'-component of __key__>.txt``
    * otherwise ``'laion'`` in ``sample['__url__']``:
      ``{self.external_laion12m_caption_path}/<shard>/<__key__>.caption``,
      where ``<shard>`` is the URL basename up to its first ``'.'``.

    Only the first line of the caption file is used. A missing or empty file
    yields an empty caption.

    When ``self.is_captioning`` is true, the caption is used as-is. If
    ``self.add_caption_prompt`` is not None, the caption is prefixed by a
    prompt drawn at random from ``self.caption_prompt`` and a space.
    Otherwise (generation), with probability 0.5 only the text before the
    first ``'.'`` is kept, and the result is passed through ``remove_prefix``.

    Args:
        sample: mutable sample dict (presumably a webdataset sample). It must
            contain ``'__key__'``. It must contain ``'__url__'`` unless the
            key contains ``'SA1B'``.

    Returns:
        The same ``sample`` dict, mutated in place. Samples from neither
        source are returned with ``'txt'`` defaulted to ``''`` if it was
        absent, and otherwise unchanged.
    """
    # Guarantee a 'txt' entry even for samples that receive no external caption.
    sample.setdefault('txt', '')

    key = sample['__key__']
    if 'SA1B' in key:
        caption_file = f"{self.external_caption_path}/{key.split('/')[-1]}.txt"
    elif 'laion' in sample['__url__']:
        shard = sample['__url__'].split('/')[-1].split('.')[0]
        caption_file = f"{self.external_laion12m_caption_path}/{shard}/{key}.caption"
    else:
        # Not an externally-captioned source: hand the sample back instead of
        # falling through to an implicit ``None`` return.
        return sample

    try:
        with open(caption_file, "r", encoding="utf-8") as reader:
            # readline() yields '' for an empty file, where readlines()[0]
            # would raise IndexError.
            captions = reader.readline().replace('\n', '')
    except FileNotFoundError:
        captions = ""

    if self.is_captioning:
        if self.add_caption_prompt is not None:
            prompt = random.sample(self.caption_prompt, 1)[0]
            sample['txt'] = prompt + ' ' + captions
        else:
            sample['txt'] = captions
    else:
        # Generation: randomly use the short form (text before the first '.')
        # or the full caption.
        if random.random() < 0.5:
            text = captions.split('.')[0]
        else:
            text = captions
        # NOTE(review): confirm remove_prefix is meant only for generation
        # captions and not for captioning targets as well.
        sample['txt'] = remove_prefix(text)

    return sample
nothing calls this directly
no test coverage detected