MCPcopy Index your code
hub / github.com/modelscope/DiffSynth-Studio / TextImageDataset

Class TextImageDataset

diffsynth/data/simple_text_image.py:8–41  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

6
7
class TextImageDataset(torch.utils.data.Dataset):
    """Dataset of caption/image pairs listed in ``train/metadata.csv``.

    The CSV is read from ``<dataset_path>/train/metadata.csv`` and must
    provide a ``file_name`` column (image paths relative to
    ``<dataset_path>/train``) and a ``text`` column.

    ``len()`` reports ``steps_per_epoch`` rather than the number of images:
    every lookup draws a random sample and offsets it by the requested index.

    Args:
        dataset_path: Root directory that contains the ``train`` folder.
        steps_per_epoch: Value returned by ``len()``.
        height: Height in pixels of the returned image tensor.
        width: Width in pixels of the returned image tensor.
        center_crop: Use ``CenterCrop`` when true, ``RandomCrop`` otherwise.
        random_flip: Add a ``RandomHorizontalFlip`` step when true.

    Raises:
        ValueError: If the metadata file lists no samples.
    """

    def __init__(self, dataset_path, steps_per_epoch=10000, height=1024, width=1024, center_crop=True, random_flip=False):
        self.steps_per_epoch = steps_per_epoch
        metadata = pd.read_csv(os.path.join(dataset_path, "train/metadata.csv"))
        self.path = [os.path.join(dataset_path, "train", file_name) for file_name in metadata["file_name"]]
        self.text = metadata["text"].to_list()
        if not self.path:
            # Fail here with a clear message instead of later inside
            # torch.randint(0, 0, ...) on the first lookup.
            raise ValueError(f"No samples listed in metadata.csv under {dataset_path!r}")
        self.height = height
        self.width = width

        if center_crop:
            crop = transforms.CenterCrop((height, width))
        else:
            crop = transforms.RandomCrop((height, width))
        steps = [crop]
        # Only add the flip when requested. An identity Lambda placeholder
        # would embed a lambda in the dataset, and lambdas cannot be pickled
        # (needed when DataLoader workers are started by spawning).
        if random_flip:
            steps.append(transforms.RandomHorizontalFlip())
        steps.append(transforms.ToTensor())
        steps.append(transforms.Normalize([0.5], [0.5]))
        self.image_processor = transforms.Compose(steps)

    def __getitem__(self, index):
        """Return ``{"text": caption, "image": tensor}`` for a random sample.

        The sample is chosen as ``(random draw + index) % number of samples``,
        so ``index`` only shifts the draw and does not address a fixed item.
        """
        num_samples = len(self.path)
        data_id = int(torch.randint(0, num_samples, (1,))[0])
        # For fixed seed: presumably keeps lookups distinct when the random
        # draw repeats across workers or runs -- confirm with the caller.
        data_id = (data_id + index) % num_samples
        text = self.text[data_id]

        # convert() returns an independent image, so the source file handle
        # can be released as soon as the block exits.
        with Image.open(self.path[data_id]) as source:
            image = source.convert("RGB")

        width, height = image.size
        # Take the larger ratio so the resized image covers the target size in
        # both dimensions before cropping.
        scale = max(self.width / width, self.height / height)
        shape = [round(height * scale), round(width * scale)]
        image = torchvision.transforms.functional.resize(image, shape, interpolation=transforms.InterpolationMode.BILINEAR)
        image = self.image_processor(image)
        return {"text": text, "image": image}

    def __len__(self):
        """Return the configured number of steps per epoch."""
        return self.steps_per_epoch

Callers 1

launch_training_task · Function · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected