MCPcopy
hub / github.com/hpcaitech/Open-Sora / main

Function main

tools/scoring/ocr/inference.py:89–154  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

87
88
89def main():
90 args = parse_args()
91
92 meta_path = args.meta_path
93 if not os.path.exists(meta_path):
94 print(f"Meta file '{meta_path}' not found. Exit.")
95 exit()
96
97 wo_ext, ext = os.path.splitext(meta_path)
98 out_path = f"{wo_ext}_ocr{ext}"
99 if args.skip_if_existing and os.path.exists(out_path):
100 print(f"Output meta file '{out_path}' already exists. Exit.")
101 exit()
102
103 cfg = Config.fromfile("./tools/scoring/ocr/dbnetpp.py")
104 colossalai.launch_from_torch({})
105
106 device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
107 DefaultScope.get_instance("ocr", scope_name="mmocr") # use mmocr Registry as default
108
109 # build model
110 model = MODELS.build(cfg.model)
111 model.init_weights()
112 model.to(device) # set data_preprocessor._device
113 print("==> Model built.")
114
115 # build dataset
116 transform = Compose(cfg.test_pipeline)
117 dataset = VideoTextDataset(meta_path=meta_path, transform=transform)
118 dataloader = DataLoader(
119 dataset,
120 batch_size=args.bs,
121 num_workers=args.num_workers,
122 sampler=DistributedSampler(
123 dataset,
124 num_replicas=dist.get_world_size(),
125 rank=dist.get_rank(),
126 shuffle=False,
127 drop_last=False,
128 ),
129 collate_fn=default_collate,
130 )
131 print("==> Dataloader built.")
132
133 # compute scores
134 dataset.meta["ocr"] = np.nan
135 indices_list = []
136 scores_list = []
137 model.eval()
138 for data in tqdm(dataloader, disable=dist.get_rank() != 0):
139 indices_i = data["index"]
140 indices_list.extend(indices_i.tolist())
141 del data["index"]
142
143 pred = model.test_step(data) # this line will cast data to device
144
145 num_texts_i = [(x.pred_instances.scores > 0.3).sum().item() for x in pred]
146 scores_list.extend(num_texts_i)

Callers 1

inference.py · File · 0.70

Calls 6

tqdm · Function · 0.85
to · Method · 0.80
parse_args · Function · 0.70
VideoTextDataset · Class · 0.70
merge_scores · Function · 0.70
device · Method · 0.45

Tested by

no test coverage detected