MCPcopy
hub / github.com/facebookresearch/MetaCLIP / __iter__

Method __iter__

altogether/infer.py:80–134  ·  view source on GitHub ↗
(self)

Source from the content-addressed store, hash-verified

78 return f"{self.root_dir}/{shard_id % 100}/{shard_id}.tar"
79
def __iter__(self):
    """Yield ``(image, tokens, img_uuid)`` for each image/text pair in the shard.

    The shard tarball is scanned member by member. A ``.jpeg`` member
    supplies the raw image bytes and a ``.json`` member supplies the text
    metadata; a sample is emitted once both have been seen. The uuid of
    each is the member name without its extension and without a leading
    ``./``.

    Yields:
        tuple: ``(image, tokens, img_uuid)`` where ``image`` is the result
        of ``self.transform`` applied to the RGB-converted image, ``tokens``
        is a ``torch.long`` tensor of length ``self.args.rewrite_prompt``
        (left-padded with the pad token id), and ``img_uuid`` is the
        sample's uuid string.

    Raises:
        AssertionError: if the pending image and json uuids differ.
    """
    # Loop-invariant settings: read once instead of once per sample.
    pad_token_id = self.args.clipcap_args["pad_token_id"]
    max_prompt_len = self.args.rewrite_prompt

    # v2 374px is face blurred. iterate over all image-text pairs until needs legal mitigation;
    tarball_path = self._get_tarball_path(self.shard_id)

    with tarfile.open(tarball_path) as tar:
        img_uuid, json_uuid = None, None
        img_bytes, text_json = None, None
        # metaclip_v1 can be iterative but the paper uses mmap for random access.
        for member in tar.getmembers():
            # read jpeg first and json next
            if member.name.endswith(".jpeg"):
                img_uuid = member.name[:-len(".jpeg")]
                if img_uuid.startswith("./"):
                    img_uuid = img_uuid[len("./"):]
                with tar.extractfile(member) as f:
                    img_bytes = f.read()
            elif member.name.endswith(".json"):
                json_uuid = member.name[:-len(".json")]
                if json_uuid.startswith("./"):
                    json_uuid = json_uuid[len("./"):]
                with tar.extractfile(member) as f:
                    text_json = json.load(f)
            else:
                print(f"unknown file ext {member.name}")
                continue

            # Wait until both halves of a pair have been read.
            if img_uuid is None or json_uuid is None:
                continue

            assert img_uuid == json_uuid, (
                f"image/json uuid mismatch: {img_uuid!r} != {json_uuid!r}"
            )

            # Use a distinct name for the PIL handle so the raw bytes
            # variable is not rebound to an image object.
            with Image.open(BytesIO(img_bytes)) as pil_img:
                image = pil_img.convert("RGB")
                # assert "face_bbox" in text_json
                # image = fairblurbbox(image, text_json["face_bbox"])
                image = self.transform(image)

            # De-duplicate alt-texts while keeping first-seen order.
            # list(set(...)) ordered them by string hash, which changes
            # between interpreter runs and made the prompt (and which part
            # survives truncation) non-reproducible.
            unique_texts = dict.fromkeys(
                txt_tuple[1] for txt_tuple in text_json["texts"]
            )
            prompt = "; ".join(unique_texts)

            # Truncate to the fixed prompt length, then left-pad to it.
            prompt_input_ids = self.tokenizer.encode(
                prompt, add_special_tokens=False
            )[:max_prompt_len]
            padding_len = max_prompt_len - len(prompt_input_ids)
            tokens = torch.tensor(
                [pad_token_id] * padding_len + prompt_input_ids,
                dtype=torch.long,
            )

            yield image, tokens, img_uuid
            # Reset so the next sample needs a fresh jpeg and json.
            img_uuid, json_uuid = None, None
135
136
137def get_dataloader(args, batch_size, data_path, shard_id, transform, tokenize):

Callers

nothing calls this directly

Calls 2

_get_tarball_pathMethod · 0.95
encodeMethod · 0.80

Tested by

no test coverage detected