MCP · copy
hub / github.com/QwenLM/Qwen-Audio / AverageShiftCalculator

Class AverageShiftCalculator

eval_audio/evaluate_srwt.py:125–239  ·  view source on GitHub ↗

Source retrieved from the content-addressed store (hash-verified).

123 return distance/ref_length
124
125class AverageShiftCalculator():
def __init__(self):
    """Announce on stdout that an average-shift computation is starting."""
    banner = "Calculating average shift."
    print(banner)
def __call__(self, refs, hyps):
    """Parse reference and hypothesis transcripts and report their average shift.

    Args:
        refs: transcripts passed to ``self.read_timestamps`` (parsed first).
        hyps: transcripts passed to ``self.read_timestamps`` (parsed second).

    Returns:
        Whatever ``self.as_cal`` returns for the two parsed lists.
    """
    parsed_refs = self.read_timestamps(refs)
    parsed_hyps = self.read_timestamps(hyps)
    average_shift = self.as_cal(parsed_refs, parsed_hyps)
    # Only the first 8 characters of the value's string form are printed.
    shown = str(average_shift)[:8]
    print("Average shift : {}.".format(shown))
    # NOTE(review): max_shift / max_shift_uttid are not set in the visible
    # code; presumably as_cal populates them — confirm.
    print("Following timestamp pair differs most: {}, detail:{}".format(self.max_shift, self.max_shift_uttid))
    return average_shift
def _intersection(self, list1, list2):
    """Return the identifiers present in both ``list1`` and ``list2``.

    The result is always a de-duplicated ``list`` ordered by first appearance
    in ``list1``. Previously a ``set`` was returned when both inputs held the
    same ids and an arbitrarily ordered ``list`` otherwise; the element set
    and the printed messages are unchanged.

    Args:
        list1: first sequence of hashable ids (called uttids in the messages).
        list2: second sequence of hashable ids.

    Returns:
        list of ids common to both inputs.
    """
    set1 = set(list1)
    set2 = set(list2)
    # dict.fromkeys de-duplicates while preserving list1's order, so the
    # result is deterministic (iteration order of a raw set is not).
    common = list(dict.fromkeys(x for x in list1 if x in set2))
    if set1 == set2:
        print("Uttid same checked.")
        return common
    print("Uttid differs: file1 {}, file2 {}, lines same {}.".format(len(list1), len(list2), len(common)))
    return common
def read_timestamps(self, body_list):
    """Parse timestamped transcripts into ``(text, segments)`` pairs.

    For each body, the tokens ``<|startoftranscript|>`` and ``<|transcribe|>``
    are removed, plus one language tag (``<|en|>`` if present, otherwise
    ``<|zh|>``). The remainder is expected to alternate timestamp tokens and
    words, e.g. ``<|0.00|>hello<|0.50|>world<|1.20|>``. Text before the
    first and after the last timestamp token is ignored. Bodies whose
    timestamp count is not exactly one more than their word count are
    skipped; the number skipped is printed.

    Args:
        body_list: iterable of decoded transcript strings.

    Returns:
        list of ``(text, ts)`` tuples, where ``text`` is the words joined by
        tabs and ``ts`` is a list of ``[start, end]`` float pairs, one per word.
    """
    # Compiled once because it is loop-invariant. Matches tokens like
    # "<|1.23|>" (one or two integer digits, then a fractional part).
    ts_pattern = re.compile(r"<\|\d{1,2}\.\d+\|>")
    ts_list = []
    pattern_error = 0
    for body in body_list:
        body = body.replace("<|startoftranscript|>", "").replace("<|transcribe|>", "")
        # Only one language tag is stripped; <|en|> takes precedence.
        if "<|en|>" in body:
            body = body.replace("<|en|>", "")
        elif "<|zh|>" in body:
            body = body.replace("<|zh|>", "")
        # Every match has the form "<|X|>", so [2:-2] yields the number text.
        all_time_stamps = [float(tok[2:-2]) for tok in ts_pattern.findall(body)]
        # Drop the text before the first and after the last timestamp.
        all_word_list = ts_pattern.split(body)[1:-1]
        if len(all_time_stamps) != len(all_word_list) + 1:
            pattern_error += 1
            continue
        text = "\t".join(all_word_list)
        # Word i spans [stamp_i, stamp_{i+1}]; the length check above
        # guarantees exactly one segment per word.
        ts = [[start, end] for start, end in zip(all_time_stamps, all_time_stamps[1:])]
        ts_list.append((text, ts))
    print(f"pattern_error_num: {pattern_error}")
    return ts_list
def _shift(self, filtered_timestamp_list1, filtered_timestamp_list2):
    """Accumulate the absolute start and end offsets of paired segments.

    Segments are paired positionally. Each pair contributes
    ``|start1 - start2| + |end1 - end2|``.

    Returns:
        ``(total_shift, num_tokens)`` where ``num_tokens`` is
        ``len(filtered_timestamp_list1)``.
    """
    # NOTE(review): zip stops at the shorter list, but the token count always
    # uses list1's length — confirm callers pass equal-length lists.
    total_shift = 0
    for ref_seg, hyp_seg in zip(filtered_timestamp_list1, filtered_timestamp_list2):
        start_gap = abs(ref_seg[0] - hyp_seg[0])
        end_gap = abs(ref_seg[1] - hyp_seg[1])
        total_shift += start_gap + end_gap
    return total_shift, len(filtered_timestamp_list1)
180 def as_cal(self, ts_list1, ts_list2):
181 # calculate average shift between timestamp1 and timestamp2
182 # when characters differ, use edit distance alignment

Callers 1

evaluate_srwt.py · File · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected