#coding=utf-8
import os
import sys
import re
from typing import List, Union
import jiwer
import pdb
def cal_wer(path_ref, path_hyp, metric_type, output_detail, path_output):
    """Score a hypothesis file against a reference file and write a report.

    Args:
        path_ref: Path of the reference file, read by ``_read_file``.
        path_hyp: Path of the hypothesis file, read by ``_read_file``.
        metric_type: ``"wer"`` or ``"cer"``.
        output_detail: Forwarded unchanged to ``cal_wer_from_list``.
        path_output: Path of the report written by ``cal_wer_from_list``.
    """
    # _read_file returns the hypotheses already re-ordered to follow the
    # reference keys, so the three lists are index-aligned.
    references, hypotheses, keys = _read_file(path_ref, path_hyp, metric_type)
    cal_wer_from_list(
        reference=references,
        hypothesis=hypotheses,
        key=keys,
        metric_type=metric_type,
        output_detail=output_detail,
        path_output=path_output,
    )
def cal_wer_from_list(
    reference: Union[str, List[str]],
    hypothesis: Union[str, List[str]],
    key: Union[str, List[str]],
    metric_type: str,
    output_detail: bool,
    path_output: str
):
    """Compute WER/CER for in-memory utterances and write a report.

    A bare string passed for ``reference``, ``hypothesis`` or ``key`` is
    treated as a one-element list.

    Args:
        reference: Reference text(s).
        hypothesis: Hypothesis text(s), index-aligned with ``reference``.
        key: Utterance id(s), index-aligned with ``reference``.
        metric_type: ``"wer"`` or ``"cer"``.
        output_detail: Forwarded to ``_summary``.
        path_output: Path of the report file.

    Raises:
        ValueError: If the three inputs do not have the same length.
    """
    if isinstance(reference, str):
        reference = [reference]
    if isinstance(hypothesis, str):
        hypothesis = [hypothesis]
    if isinstance(key, str):
        key = [key]

    # Fail early: a longer hypothesis/key list would otherwise be silently
    # truncated, and a shorter one would surface as a bare IndexError.
    if not (len(reference) == len(hypothesis) == len(key)):
        raise ValueError(
            "reference, hypothesis and key must have the same length, got "
            + str(len(reference)) + ", " + str(len(hypothesis)) + ", " + str(len(key))
        )

    # Split utterances by whether the reference is empty; each group is
    # scored separately and the results are merged in _summary.
    ref_normal, hyp_normal, key_normal = [], [], []
    ref_empty, hyp_empty, key_empty = [], [], []
    for ref, hyp, utt_key in zip(reference, hypothesis, key):
        if len(ref) != 0:
            ref_normal.append(ref)
            hyp_normal.append(hyp)
            key_normal.append(utt_key)
        else:
            ref_empty.append(ref)
            hyp_empty.append(hyp)
            key_empty.append(utt_key)

    res_normal, out_normal = _cal_wer_normal(ref_normal, hyp_normal, metric_type)
    res_empty, out_empty = _cal_wer_empty(hyp_empty, metric_type)
    _summary(ref_normal, hyp_normal, res_normal, out_normal.alignments, key_normal,
             hyp_empty, res_empty, out_empty, key_empty,
             metric_type, output_detail, path_output)
def _read_file(path_ref, path_hyp, metric_type):
    """Load reference and hypothesis files and align hypotheses to references.

    Both files are parsed by ``_preprocess`` with a tab separator.

    Returns:
        ``(ref_text, hyp_text, ref_key)`` where ``hyp_text[i]`` is the
        hypothesis whose key equals ``ref_key[i]``, or ``""`` when the
        hypothesis file has no such key.
    """
    ref_key, ref_text = _preprocess(path_ref, '\t', metric_type)
    hyp_key, hyp_text = _preprocess(path_hyp, '\t', metric_type)

    # Index hypotheses by key; the first occurrence of a key wins and later
    # duplicates are only reported.
    hyp_by_key = {}
    for utt_key, utt_text in zip(hyp_key, hyp_text):
        if utt_key in hyp_by_key:
            print ("repeated key")
            continue
        hyp_by_key[utt_key] = utt_text

    # Hypothesis keys that never appear in the reference are dropped here.
    aligned_hyp = [hyp_by_key.get(utt_key, "") for utt_key in ref_key]
    return ref_text, aligned_hyp, ref_key
def _preprocess(path_in, sep, metric_type):
    """Read a ``key<sep>text`` file and normalise each text for scoring.

    Each line is stripped and split once on ``sep``. A line without a text
    field yields an empty text. The markers ``<s>``, ``</s>``, ``<unk>`` and
    the ``@@``/``@`` sub-word joiners are removed and the text is lowercased.

    * ``"wer"``: every character in U+4E00..U+9FFF becomes its own token and
      all whitespace runs are collapsed to one space.
    * ``"cer"``: the text is stripped and all spaces are removed.

    Args:
        path_in: Path of the UTF-8 text file.
        sep: Separator between key and text.
        metric_type: ``"wer"`` or ``"cer"``.

    Returns:
        ``(keys, texts)``, two lists in file order.

    Raises:
        ValueError: If ``metric_type`` is neither ``"wer"`` nor ``"cer"``.
    """
    if metric_type not in ("wer", "cer"):
        raise ValueError("unsupported metric_type: {!r}".format(metric_type))

    # Order matters: the variants followed by a space are removed first so
    # that the space a joiner is attached to disappears with it.
    markers = ("<s>", "</s>", "<unk>", "@@ ", "@ ", "@@", "@")

    res_key, res_text = [], []
    with open(path_in, "r", encoding="utf-8") as f_in:
        for line in f_in:
            fields = line.strip().split(sep, 1)
            if len(fields) == 2:
                key, text = fields
                for marker in markers:
                    text = text.replace(marker, "")
                text = text.lower()
            else:
                key, text = fields[0], ""

            if metric_type == "wer":
                spaced = "".join(
                    " " + ch + " " if '\u4e00' <= ch <= '\u9fff' else ch
                    for ch in text
                )
                # Padding every CJK character creates runs of spaces; collapse
                # them so that splitting on a single space later yields exactly
                # one token per word.
                text = " ".join(spaced.split())
            else:
                text = text.strip().replace(" ", "")

            res_key.append(key)
            res_text.append(text)
    return res_key, res_text
def _cal_wer_normal(reference, hypothesis, metric_type):
    """Score utterances whose reference is non-empty using jiwer.

    Args:
        reference: Reference texts.
        hypothesis: Hypothesis texts, index-aligned with ``reference``.
        metric_type: ``"wer"`` (``jiwer.process_words``) or ``"cer"``
            (``jiwer.process_characters``).

    Returns:
        ``(res, out)`` where ``res`` is ``[ERR, N, S, D, I]`` with
        ``N = hits + substitutions + deletions``, and ``out`` is the jiwer
        result object. For an empty batch ``out`` is a stand-in object with
        zero counts and an empty ``alignments`` list.

    Raises:
        ValueError: If ``metric_type`` is neither ``"wer"`` nor ``"cer"``.
    """
    if metric_type not in ("wer", "cer"):
        raise ValueError("unsupported metric_type: {!r}".format(metric_type))

    if len(reference) == 0:
        # Nothing to align (e.g. every reference was empty). An error rate
        # over zero reference tokens is undefined, so the batch is answered
        # locally instead of being handed to jiwer; callers only read the
        # counts and ``alignments``.
        from types import SimpleNamespace
        empty_out = SimpleNamespace(
            hits=0, substitutions=0, deletions=0, insertions=0, alignments=[])
        return [0, 0, 0, 0, 0], empty_out

    if metric_type == "wer":
        out = jiwer.process_words(reference=reference, hypothesis=hypothesis)
        ERR = out.wer
    else:
        out = jiwer.process_characters(reference=reference, hypothesis=hypothesis)
        ERR = out.cer

    H = out.hits
    S = out.substitutions
    D = out.deletions
    I = out.insertions
    N = H + S + D  # number of reference tokens
    res = [ERR, N, S, D, I]
    return res, out
def _cal_wer_empty(hypothesis, metric_type):
    """Count insertions for utterances whose reference is empty.

    With an empty reference every hypothesis token is an insertion.

    Args:
        hypothesis: Hypothesis texts of the empty-reference utterances.
        metric_type: ``"wer"`` counts whitespace-separated words, ``"cer"``
            counts characters.

    Returns:
        ``(res, out)`` where ``res`` is ``[0, 0, 0, 0, I]`` (``I`` is the
        total insertion count) and ``out`` lists the count per utterance.

    Raises:
        ValueError: If ``metric_type`` is neither ``"wer"`` nor ``"cer"``.
    """
    if metric_type not in ("wer", "cer"):
        raise ValueError("unsupported metric_type: {!r}".format(metric_type))

    out = []
    for hyp in hypothesis:
        if metric_type == "wer":
            # split() without an argument ignores repeated, leading and
            # trailing whitespace, so blanks are never counted as words.
            out.append(len(hyp.split()))
        else:
            out.append(len(hyp))
    res = [0, 0, 0, 0, sum(out)]
    return res, out
def _summary(ref_normal, hyp_normal, res_normal, out_normal, key_normal,
             hyp_empty, res_empty, out_empty, key_empty,
             metric_type, output_detail, path_output):
    """Merge both utterance groups into corpus metrics and write the report.

    Args:
        ref_normal, hyp_normal, key_normal: Utterances with a non-empty
            reference.
        res_normal: ``[ERR, N, S, D, I]`` for those utterances.
        out_normal: Per-utterance alignments for those utterances.
        hyp_empty, key_empty: Utterances with an empty reference.
        res_empty: List whose last element is their total insertion count.
        out_empty: Per-utterance insertion counts for those utterances.
        metric_type, output_detail, path_output: Forwarded to the analysis
            helpers and ``_format_output``.
    """
    # Corpus-level WER/CER.
    _, N, S, D, I = res_normal
    I += res_empty[-1]
    if N != 0:
        ERR = (S + D + I) / N
        SUB = S / N
        DEL = D / N
        INS = I / N
        N_WORD = N
    else:
        # No reference tokens at all: the rate is 0 when nothing was
        # inserted and 1 otherwise; INS carries the raw insertion count.
        ERR = 0 if I == 0 else 1
        SUB, DEL, INS, N_WORD = 0, 0, I, 0

    # Utterance accuracy, per-utterance alignments and error statistics.
    utt_normal, alignments_normal, statistics_normal = _analyse_normal(
        ref_normal, hyp_normal, out_normal, key_normal, metric_type)
    utt_empty, alignments_empty, statistics_empty = _analyse_empty(
        hyp_empty, out_empty, key_empty, metric_type)
    utt = utt_normal + utt_empty
    alignments = alignments_normal + alignments_empty

    # Fold the insertion counts of the empty-reference group into the
    # statistics of the normal group.
    merged_insert = statistics_normal['insert']
    for key, count in statistics_empty['insert'].items():
        merged_insert[key] = merged_insert.get(key, 0) + count

    N_SENT = len(out_normal) + len(out_empty)
    # With no utterances there is nothing to be accurate about; report 0
    # instead of dividing by zero.
    ACC_UTT = utt / N_SENT if N_SENT != 0 else 0
    res = [ERR, SUB, DEL, INS, N_WORD, ACC_UTT, N_SENT]
    _format_output(res, alignments, statistics_normal, metric_type, output_detail, path_output)
def _analyse_normal(ref_normal, hyp_normal, out_normal, key_normal, metric_type):
    """Build per-utterance alignment rows and error statistics.

    Args:
        ref_normal: Reference texts (non-empty references only).
        hyp_normal: Hypothesis texts, index-aligned with ``ref_normal``.
        out_normal: One sequence of alignment chunks per utterance; each
            chunk exposes ``type``, ``ref_start_idx``, ``ref_end_idx``,
            ``hyp_start_idx`` and ``hyp_end_idx``.
        key_normal: Utterance ids, index-aligned with ``ref_normal``.
        metric_type: ``"wer"`` or ``"cer"``, forwarded to ``_extract_string``.

    Returns:
        ``(n_correct, rows, stats)``. ``n_correct`` counts utterances with no
        error chunk. Each row is ``(key, ref_align, hyp_align, sub_align,
        del_align, ins_align, n_hit, n_sub, n_del, n_ins)``. ``stats`` maps
        ``'substitute'``/``'delete'``/``'insert'`` to ``{label: count}``.
    """
    stats = {'substitute' : {}, 'delete' : {}, 'insert' : {}}

    def _count(kind, label):
        # Increment the occurrence counter of one error label.
        bucket = stats[kind]
        bucket[label] = bucket.get(label, 0) + 1

    n_correct = 0
    rows = []
    for idx, chunks in enumerate(out_normal):
        n_bad_chunks = 0  # error chunks, not error tokens
        n_hit = n_sub = n_del = n_ins = 0
        ref_line = hyp_line = ""
        sub_line = del_line = ins_line = ""

        for chunk in chunks:
            # In word mode consecutive chunks are separated by one space.
            if metric_type == "wer" and (ref_line or hyp_line):
                ref_line += " "
                hyp_line += " "

            kind = chunk.type
            if kind == 'equal':
                n_hit += chunk.ref_end_idx - chunk.ref_start_idx
                ref_line += _extract_string(
                    ref_normal[idx], chunk.ref_start_idx, chunk.ref_end_idx, metric_type)
                hyp_line += _extract_string(
                    hyp_normal[idx], chunk.hyp_start_idx, chunk.hyp_end_idx, metric_type)
            elif kind == 'substitute':
                n_bad_chunks += 1
                n_sub += chunk.ref_end_idx - chunk.ref_start_idx
                ref_piece = _extract_string(
                    ref_normal[idx], chunk.ref_start_idx, chunk.ref_end_idx, metric_type)
                hyp_piece = _extract_string(
                    hyp_normal[idx], chunk.hyp_start_idx, chunk.hyp_end_idx, metric_type)
                ref_line += ref_piece
                hyp_line += hyp_piece
                label = "(" + ref_piece + ") --> (" + hyp_piece + ")"
                sub_line += label + "\t"
                _count('substitute', label)
            elif kind == 'delete':
                n_bad_chunks += 1
                n_del += chunk.ref_end_idx - chunk.ref_start_idx
                ref_piece = _extract_string(
                    ref_normal[idx], chunk.ref_start_idx, chunk.ref_end_idx, metric_type)
                ref_line += ref_piece
                hyp_line += "*"  # placeholder for the missing hypothesis span
                del_line += ref_piece + "\t"
                _count('delete', ref_piece)
            elif kind == 'insert':
                n_bad_chunks += 1
                n_ins += chunk.hyp_end_idx - chunk.hyp_start_idx
                hyp_piece = _extract_string(
                    hyp_normal[idx], chunk.hyp_start_idx, chunk.hyp_end_idx, metric_type)
                ref_line += "*"  # placeholder for the missing reference span
                hyp_line += hyp_piece
                ins_line += hyp_piece + "\t"
                _count('insert', hyp_piece)
            else:
                assert False

        if n_bad_chunks == 0:
            n_correct += 1
        rows.append((key_normal[idx], ref_line, hyp_line,
                     sub_line, del_line, ins_line,
                     n_hit, n_sub, n_del, n_ins))
    return n_correct, rows, stats
def _analyse_empty(hyp_empty, out_empty, key_empty, metric_type):
    """Build alignment rows and insertion statistics for empty references.

    Args:
        hyp_empty: Hypothesis texts of the empty-reference utterances.
        out_empty: Insertion count per utterance, index-aligned.
        key_empty: Utterance ids, index-aligned.
        metric_type: ``"wer"`` or ``"cer"``, forwarded to ``_extract_string``.

    Returns:
        ``(n_correct, rows, stats)``. ``n_correct`` counts utterances with
        zero insertions. Each row is ``(key, ref_align, hyp_align, sub_align,
        del_align, ins_align, 0, 0, 0, n_ins)``. ``stats`` is
        ``{'insert': {label: count}}``.
    """
    n_correct = 0
    rows = []
    insert_counts = {}
    for idx, n_ins in enumerate(out_empty):
        if n_ins == 0:
            # Empty reference and empty hypothesis: a correct utterance.
            n_correct += 1
            ref_line = hyp_line = ins_line = ""
        else:
            # The whole hypothesis is a single inserted span.
            hyp_text = hyp_empty[idx]
            inserted = _extract_string(hyp_text, 0, len(hyp_text), metric_type)
            ref_line = "*"
            hyp_line = inserted
            ins_line = inserted + "\t"
            insert_counts[inserted] = insert_counts.get(inserted, 0) + 1
        rows.append((key_empty[idx], ref_line, hyp_line,
                     "", "", ins_line,
                     0, 0, 0, n_ins))
    return n_correct, rows, {'insert' : insert_counts}
def _extract_string(s, begin, end, metric_type):
    """Return tokens ``begin`` (inclusive) to ``end`` (exclusive) of ``s``.

    Args:
        s: Source text.
        begin: Index of the first token to return.
        end: Index one past the last token to return.
        metric_type: ``"wer"`` indexes whitespace-separated words and joins
            them with single spaces; ``"cer"`` indexes characters.

    Raises:
        ValueError: If ``metric_type`` is neither ``"wer"`` nor ``"cer"``.
    """
    if metric_type == 'wer':
        # split() without an argument never yields empty tokens, so word
        # indices stay valid when the text has repeated, leading or
        # trailing spaces.
        return ' '.join(s.split()[begin:end])
    if metric_type == 'cer':
        return s[begin:end]
    raise ValueError("unsupported metric_type: {!r}".format(metric_type))
def _format_output(res, alignments, statistics, metric_type, output_detail, path_output):
    """Write the scoring report to ``path_output`` and print a completion note.

    Args:
        res: ``[ERR, SUB, DEL, INS, N_WORD, ACC_UTT, N_SENT]``; the rates are
            fractions and are written as percentages.
        alignments: Rows ``(key, ref, hyp, sub_align, del_align, ins_align,
            n_hit, n_sub, n_del, n_ins)``.
        statistics: Dict with ``'substitute'``, ``'delete'`` and ``'insert'``
            entries, each mapping an error label to its count.
        metric_type: Metric name; written upper-cased in the summary.
        output_detail: When equal to ``True``, the per-utterance listing and
            the error statistics are written before the summary.
        path_output: Path of the UTF-8 report file (overwritten).
    """
    rule = "-"*100 + "\n"

    def _pct(value):
        # Fraction -> percentage text rounded to two decimals.
        return str(round(value * 100.0, 2))

    with open(path_output, "w", encoding="utf-8") as f_out:
        # NOTE(review): the indentation of the original was lost; the error
        # statistics are assumed to be gated by output_detail together with
        # the per-utterance listing -- confirm against the intended layout.
        if output_detail == True:
            # Per-utterance alignment listing.
            f_out.write(rule)
            for sample in alignments:
                key, ref, hyp, sub_align, del_align, ins_align = sample[:6]
                n_hit, n_sub, n_del, n_ins = sample[6:]
                f_out.write("KEY: " + key + "\n")
                f_out.write("REF: " + ref + "\n")
                f_out.write("HYP: " + hyp + "\n")
                f_out.write("CNT: " + "H(" + str(n_hit) + ") "
                            + "S(" + str(n_sub) + ") "
                            + "D(" + str(n_del) + ") "
                            + "I(" + str(n_ins) + ")\n")
                f_out.write("SUB: " + sub_align + "\n")
                f_out.write("DEL: " + del_align + "\n")
                f_out.write("INS: " + ins_align + "\n\n")
            f_out.write(rule)

            # Error statistics, most frequent first. All three rankings are
            # built before anything of this section is written.
            f_out.write(rule)
            sections = (
                ("\n替换错误统计: \n", 'substitute'),
                ("\n删除错误统计: \n", 'delete'),
                ("\n插入错误统计: \n", 'insert'),
            )
            ranked = [
                (title, sorted(statistics[kind].items(),
                               key=lambda item: item[1], reverse=True))
                for title, kind in sections
            ]
            for title, entries in ranked:
                f_out.write(title)
                for label, count in entries:
                    f_out.write("\t" + label + "(" + str(count) + ")" + "\n")
            f_out.write(rule)

        # Corpus-level summary.
        f_out.write(rule)
        f_out.write(metric_type.upper() + ": " + _pct(res[0]) + '%\n')
        f_out.write("WORDS: " + str(res[4]) + "\t")
        f_out.write("SUB: " + _pct(res[1]) + "%\t")
        f_out.write("DEL: " + _pct(res[2]) + "%\t")
        f_out.write("INS: " + _pct(res[3]) + "%\n")
        f_out.write("ACC_UTT: " + _pct(res[5]) + '%\t')
        f_out.write("SENTS: " + str(res[6]) + '\n')
        f_out.write(rule)

    print (metric_type + " calculation done")
    print ("saved to " + path_output)
if __name__ == '__main__':
    # Running the module directly does nothing; the string below is only a
    # usage example. Both examples pass in-memory lists, so both must call
    # cal_wer_from_list (cal_wer takes file paths instead).
    '''
    # example of function cal_wer_from_list
    ref = ["今 天 天 气", "hello 我 ok 的", ""]
    hyp = ["今 天 天", "halo 我 ok 的 呀", "噪 声"]
    key = ["000", "001", "002"]
    path_output = "./example.wer"
    cal_wer_from_list(ref, hyp, key, "wer", True, path_output)
    ref = ["今天天气", "hello我ok的", ""]
    hyp = ["今天天", "halo我ok的呀", "噪声"]
    key = ["000", "001", "002"]
    path_output = "./example.cer"
    cal_wer_from_list(ref, hyp, key, "cer", True, path_output)
    '''