| 123 | return distance/ref_length |
| 124 | |
| 125 | class AverageShiftCalculator(): |
def __init__(self):
    """Print a one-line banner announcing the average-shift calculation."""
    banner = "Calculating average shift."
    print(banner)
| 128 | |
def __call__(self, refs, hyps):
    """Compute, report and return the average shift between two inputs.

    Both ``refs`` and ``hyps`` are parsed with ``self.read_timestamps`` and
    the parsed results are handed to ``self.as_cal``. The value returned by
    ``as_cal`` is printed (string form cut to its first 8 characters),
    followed by ``self.max_shift`` and ``self.max_shift_uttid``.

    Args:
        refs: reference transcripts, in the form ``read_timestamps`` accepts.
        hyps: hypothesis transcripts, in the same form.

    Returns:
        The untruncated value returned by ``self.as_cal``.
    """
    ref_entries = self.read_timestamps(refs)
    hyp_entries = self.read_timestamps(hyps)
    average_shift = self.as_cal(ref_entries, hyp_entries)

    # Only the printed representation is truncated; the returned value is not.
    shown = str(average_shift)[:8]
    print("Average shift : {}.".format(shown))
    # NOTE(review): max_shift / max_shift_uttid are not set in this method;
    # presumably as_cal assigns them -- confirm.
    print("Following timestamp pair differs most: {}, detail:{}".format(self.max_shift, self.max_shift_uttid))
    return average_shift
| 136 | |
def _intersection(self, list1, list2):
    """Return the ids common to ``list1`` and ``list2`` and report on stdout.

    Args:
        list1: ids from the first file (hashable items).
        list2: ids from the second file (hashable items).

    Returns:
        When both inputs contain the same set of ids, that ``set``.
        Otherwise a ``list`` of the shared ids, in arbitrary order.
    """
    ids_a, ids_b = set(list1), set(list2)

    if ids_a != ids_b:
        shared = list(ids_a.intersection(ids_b))
        # The reported file sizes are the raw list lengths, so duplicated
        # ids are counted there but not in the shared count.
        print("Uttid differs: file1 {}, file2 {}, lines same {}.".format(len(list1), len(list2), len(shared)))
        return shared

    # NOTE(review): this path returns a set while the path above returns a
    # list; callers must not rely on list-only or set-only operations.
    print("Uttid same checked.")
    return ids_a
| 146 | |
def read_timestamps(self, body_list):
    """Extract tab-joined text and timestamp pairs from tagged transcripts.

    For every string in ``body_list`` the ``<|startoftranscript|>`` and
    ``<|transcribe|>`` tokens are removed, then one language token is
    removed (``<|en|>`` if present, otherwise ``<|zh|>`` if present).
    Timestamp tokens of the form ``<|D.F|>`` (one or two integer digits, a
    dot, one or more fractional digits) are then used to cut the string
    into pieces.

    Args:
        body_list: iterable of transcript strings.

    Returns:
        A list with one ``(text, ts)`` tuple per accepted entry. ``text``
        is the pieces found between consecutive timestamp tokens, joined
        with tabs; anything before the first or after the last token is
        dropped. ``ts`` is a list of ``[start, end]`` float pairs built
        from consecutive timestamps, one pair per piece.

        Entries failing the count check are skipped, so the result can be
        shorter than ``body_list``. The number skipped is printed as
        ``pattern_error_num``.
    """
    # Compiled once: the pattern does not depend on the entry being parsed.
    ts_pattern = re.compile(r"<\|\d{1,2}\.\d+\|>")

    ts_list = []
    pattern_error = 0
    for body in body_list:
        body = body.replace("<|startoftranscript|>", "").replace("<|transcribe|>", "")
        # Only one language token is stripped; "<|en|>" takes precedence.
        if "<|en|>" in body:
            body = body.replace("<|en|>", "")
        elif "<|zh|>" in body:
            body = body.replace("<|zh|>", "")

        # Every match looks like "<|12.34|>", so slicing off the two-char
        # delimiters leaves just the number.
        all_time_stamps = [float(token[2:-2]) for token in ts_pattern.findall(body)]
        # split() yields one more piece than there are tokens; dropping the
        # first and last keeps only the text enclosed by two timestamps.
        all_word_list = ts_pattern.split(body)[1:-1]

        # With the slicing above this only fires when no timestamp token
        # was found at all.
        if len(all_time_stamps) != len(all_word_list) + 1:
            pattern_error += 1
            continue

        ts = [[start, end] for start, end in zip(all_time_stamps, all_time_stamps[1:])]
        # Internal invariant: one [start, end] pair per enclosed piece.
        assert len(ts) == len(all_word_list), f"{body}"
        ts_list.append(("\t".join(all_word_list), ts))

    print(f"pattern_error_num: {pattern_error}")
    return ts_list
| 172 | |
def _shift(self, filtered_timestamp_list1, filtered_timestamp_list2):
    """Accumulate absolute start/end differences over paired timestamps.

    The two lists are walked in lockstep with ``zip``, so pairing stops at
    the shorter one. For each pair, the absolute difference of element 0
    and the absolute difference of element 1 are added to the running
    total.

    Args:
        filtered_timestamp_list1: sequence of indexable ``[start, end]`` items.
        filtered_timestamp_list2: sequence of indexable ``[start, end]`` items.

    Returns:
        ``(shift_time, num_tokens)`` where ``num_tokens`` is the length of
        ``filtered_timestamp_list1`` (not the number of pairs visited).
    """
    total = 0
    paired = zip(filtered_timestamp_list1, filtered_timestamp_list2)
    for left, right in paired:
        start_gap = abs(left[0] - right[0])
        end_gap = abs(left[1] - right[1])
        # Kept as a plain running addition so float results stay identical
        # regardless of the interpreter's sum() implementation.
        total += start_gap + end_gap
    return total, len(filtered_timestamp_list1)
| 179 | |
| 180 | def as_cal(self, ts_list1, ts_list2): |
| 181 | # calculate average shift between timestamp1 and timestamp2 |
| 182 | # when characters differ, use edit distance alignment |