Convert between text-label and text-index
| 271 | |
| 272 | |
| 273 | class CTCLabelConverter(object): |
| 274 | """ Convert between text-label and text-index """ |
| 275 | |
def __init__(self, character, separator_list=None, dict_pathlist=None):
    """Build the character/index tables and load optional word dictionaries.

    Args:
        character: iterable of the possible characters (typically a str).
            The character at position ``k`` is mapped to index ``k + 1``;
            index 0 is reserved for the CTC ``'[blank]'`` token.
        separator_list: optional mapping ``lang -> separator characters``.
            When omitted, a new empty dict is created for this instance.
        dict_pathlist: optional mapping ``lang -> path`` of a text file
            holding one dictionary entry per line. When omitted, no
            dictionary is loaded.

    Sets ``self.dict``, ``self.character``, ``self.separator_list``,
    ``self.ignore_idx`` and ``self.dict_list``.
    """
    # Fresh containers per call: a shared ``{}`` default would be stored on
    # every instance and any later mutation would leak between instances.
    if separator_list is None:
        separator_list = {}
    if dict_pathlist is None:
        dict_pathlist = {}

    dict_character = list(character)

    # Index 0 is left free for the blank token, so real characters start at 1.
    self.dict = {char: i for i, char in enumerate(dict_character, start=1)}

    # dummy '[blank]' token for CTCLoss (index 0)
    self.character = ['[blank]'] + dict_character

    self.separator_list = separator_list
    separator_char = []
    for sep in separator_list.values():
        separator_char += sep
    # Blank index plus one index per separator character (1..n).
    # NOTE(review): this is positional, so it presumably relies on the
    # separators being the first entries of `character` - confirm with caller.
    self.ignore_idx = list(range(len(separator_char) + 1))

    if not separator_list:
        # No separators: merge every readable dictionary into one flat list.
        dict_list = []
        for dict_path in dict_pathlist.values():
            try:
                with open(dict_path, "r", encoding="utf-8-sig") as input_file:
                    dict_list += input_file.read().splitlines()
            except (OSError, ValueError):
                # Best effort: a missing, unreadable or undecodable
                # dictionary file is skipped instead of aborting.
                continue
    else:
        # Separators present: keep one word list per language. A missing
        # file is an error here, exactly as before.
        dict_list = {}
        for lang, dict_path in dict_pathlist.items():
            with open(dict_path, "r", encoding="utf-8-sig") as input_file:
                dict_list[lang] = input_file.read().splitlines()

    self.dict_list = dict_list
| 310 | |
def encode(self, text, batch_max_length=25):
    """Turn a batch of text labels into flat index and length tensors.

    Args:
        text: sequence of label strings, one per image. [batch_size]
        batch_max_length: part of the signature but not used by this method.

    Returns:
        A pair ``(indices, lengths)`` of ``torch.IntTensor``.
        ``indices`` holds the ``self.dict`` index of every character of
        every label, concatenated in batch order, so its size is
        ``sum(lengths)``. ``lengths`` holds the length of each label.
        [batch_size]

    Raises:
        KeyError: if a character is not present in ``self.dict``.
    """
    # Per-label sizes are taken before the labels are merged.
    label_sizes = list(map(len, text))

    # CTCLoss takes all targets as one concatenated sequence.
    flat_chars = ''.join(text)
    lookup = self.dict
    flat_indices = list(map(lookup.__getitem__, flat_chars))

    return torch.IntTensor(flat_indices), torch.IntTensor(label_sizes)
| 326 | |
| 327 | def decode_greedy(self, text_index, length): |
| 328 | """ convert text-index into text-label. """ |
| 329 | texts = [] |
| 330 | index = 0 |
no outgoing calls
no test coverage detected
searching dependent graphs…