# Copyright (c) Facebook, Inc. and its affiliates.

import re

from tqdm import tqdm


class EvalAIAnswerProcessor:
    """
    Processes an answer in the same way as EvalAI does.
    Copied from:
    https://github.com/facebookresearch/mmf/blob/c46b3b3391275b4181567db80943473a89ab98ab/pythia/tasks/processors.py#L897
    """

    CONTRACTIONS = {
        "aint": "ain't",
        "arent": "aren't",
        "cant": "can't",
        "couldve": "could've",
        "couldnt": "couldn't",
        "couldn'tve": "couldn't've",
        "couldnt've": "couldn't've",
        "didnt": "didn't",
        "doesnt": "doesn't",
        "dont": "don't",
        "hadnt": "hadn't",
        "hadnt've": "hadn't've",
        "hadn'tve": "hadn't've",
        "hasnt": "hasn't",
        "havent": "haven't",
        "hed": "he'd",
        "hed've": "he'd've",
        "he'dve": "he'd've",
        "hes": "he's",
        "howd": "how'd",
        "howll": "how'll",
        "hows": "how's",
        "Id've": "I'd've",
        "I'dve": "I'd've",
        "Im": "I'm",
        "Ive": "I've",
        "isnt": "isn't",
        "itd": "it'd",
        "itd've": "it'd've",
        "it'dve": "it'd've",
        "itll": "it'll",
        "let's": "let's",
        "maam": "ma'am",
        "mightnt": "mightn't",
        "mightnt've": "mightn't've",
        "mightn'tve": "mightn't've",
        "mightve": "might've",
        "mustnt": "mustn't",
        "mustve": "must've",
        "neednt": "needn't",
        "notve": "not've",
        "oclock": "o'clock",
        "oughtnt": "oughtn't",
        "ow's'at": "'ow's'at",
        "'ows'at": "'ow's'at",
        "'ow'sat": "'ow's'at",
        "shant": "shan't",
        "shed've": "she'd've",
        "she'dve": "she'd've",
        "she's": "she's",
        "shouldve": "should've",
        "shouldnt": "shouldn't",
        "shouldnt've": "shouldn't've",
        "shouldn'tve": "shouldn't've",
        "somebodyd": "somebody'd",
        "somebodyd've": "somebody'd've",
        "somebody'dve": "somebody'd've",
        "somebodyll": "somebody'll",
        "somebodys": "somebody's",
        "someoned": "someone'd",
        "someoned've": "someone'd've",
        "someone'dve": "someone'd've",
        "someonell": "someone'll",
        "someones": "someone's",
        "somethingd": "something'd",
        "somethingd've": "something'd've",
        "something'dve": "something'd've",
        "somethingll": "something'll",
        "thats": "that's",
        "thered": "there'd",
        "thered've": "there'd've",
        "there'dve": "there'd've",
        "therere": "there're",
        "theres": "there's",
        "theyd": "they'd",
        "theyd've": "they'd've",
        "they'dve": "they'd've",
        "theyll": "they'll",
        "theyre": "they're",
        "theyve": "they've",
        "twas": "'twas",
        "wasnt": "wasn't",
        "wed've": "we'd've",
        "we'dve": "we'd've",
        "weve": "we've",
        "werent": "weren't",
        "whatll": "what'll",
        "whatre": "what're",
        "whats": "what's",
        "whatve": "what've",
        "whens": "when's",
        "whered": "where'd",
        "wheres": "where's",
        "whereve": "where've",
        "whod": "who'd",
        "whod've": "who'd've",
        "who'dve": "who'd've",
        "wholl": "who'll",
        "whos": "who's",
        "whove": "who've",
        "whyll": "why'll",
        "whyre": "why're",
        "whys": "why's",
        "wont": "won't",
        "wouldve": "would've",
        "wouldnt": "wouldn't",
        "wouldnt've": "wouldn't've",
        "wouldn'tve": "wouldn't've",
        "yall": "y'all",
        "yall'll": "y'all'll",
        "y'allll": "y'all'll",
        "yall'd've": "y'all'd've",
        "y'alld've": "y'all'd've",
        "y'all'dve": "y'all'd've",
        "youd": "you'd",
        "youd've": "you'd've",
        "you'dve": "you'd've",
        "youll": "you'll",
        "youre": "you're",
        "youve": "you've",
    }

    NUMBER_MAP = {
        "none": "0",
        "zero": "0",
        "one": "1",
        "two": "2",
        "three": "3",
        "four": "4",
        "five": "5",
        "six": "6",
        "seven": "7",
        "eight": "8",
        "nine": "9",
        "ten": "10",
    }
    ARTICLES = ["a", "an", "the"]
    # NOTE: "(?!<=\d)" is very likely a typo for the negative lookbehind
    # "(?<!\d)" (keep periods inside decimals), but the official evaluator
    # ships this exact pattern, so it is kept for score parity.
    PERIOD_STRIP = re.compile(r"(?!<=\d)(\.)(?!\d)")
    COMMA_STRIP = re.compile(r"(?<=\d)(\,)+(?=\d)")
    PUNCTUATIONS = [
        ";",
        r"/",
        "[",
        "]",
        '"',
        "{",
        "}",
        "(",
        ")",
        "=",
        "+",
        "\\",
        "_",
        "-",
        ">",
        "<",
        "@",
        "`",
        ",",
        "?",
        "!",
    ]

    def __init__(self, *args, **kwargs):
        pass

    def word_tokenize(self, word):
        word = word.lower()
        word = word.replace(",", "").replace("?", "").replace("'s", " 's")
        return word.strip()

    def process_punctuation(self, in_text):
        out_text = in_text
        for p in self.PUNCTUATIONS:
            if (p + " " in in_text or " " + p in in_text) or (
                re.search(self.COMMA_STRIP, in_text) is not None
            ):
                out_text = out_text.replace(p, "")
            else:
                out_text = out_text.replace(p, " ")
        # NOTE: re.UNICODE occupies `sub`'s positional `count` argument here,
        # so at most 32 periods are ever stripped.  The official evaluator
        # makes the same call, so it is kept for score parity.
        out_text = self.PERIOD_STRIP.sub("", out_text, re.UNICODE)
        return out_text

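    # Behavior sketch (worked example, not from the original file): when the
    # text contains a digit-comma-digit group, the COMMA_STRIP clause is true
    # for every mark, so all punctuation is deleted outright instead of being
    # replaced by a space, e.g.
    #     process_punctuation("1,000,000!") -> "1000000"
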
    def process_digit_article(self, in_text):
        out_text = []
        temp_text = in_text.lower().split()
        for word in temp_text:
            # .get returns the mapped digit without inserting every unseen
            # word into the shared class-level NUMBER_MAP (the upstream
            # .setdefault call mutates it; the output is identical).
            word = self.NUMBER_MAP.get(word, word)
            if word not in self.ARTICLES:
                out_text.append(word)
        for word_id, word in enumerate(out_text):
            if word in self.CONTRACTIONS:
                out_text[word_id] = self.CONTRACTIONS[word]
        out_text = " ".join(out_text)
        return out_text

    def __call__(self, item):
        item = self.word_tokenize(item)
        item = item.replace("\n", " ").replace("\t", " ").strip()
        item = self.process_punctuation(item)
        item = self.process_digit_article(item)
        return item

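
# A minimal usage sketch (illustrative, not part of the original file): the
# processor lower-cases, strips punctuation, maps number words to digits,
# drops articles, and restores contractions, e.g.
#
#     processor = EvalAIAnswerProcessor()
#     processor("Two Dogs!")       # -> "2 dogs"
#     processor("thats the cat")   # -> "that's cat"
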
class TextVQAAccuracyEvaluator:
    def __init__(self):
        self.answer_processor = EvalAIAnswerProcessor()

    def _compute_answer_scores(self, raw_answers):
        """
        Compute the accuracy (soft score) of human answers.
        """
        answers = [self.answer_processor(a) for a in raw_answers]
        assert len(answers) == 10
        gt_answers = list(enumerate(answers))
        unique_answers = set(answers)
        unique_answer_scores = {}

        for unique_answer in unique_answers:
            accs = []
            for gt_answer in gt_answers:
                other_answers = [item for item in gt_answers if item != gt_answer]
                matching_answers = [
                    item for item in other_answers if item[1] == unique_answer
                ]
                acc = min(1, float(len(matching_answers)) / 3)
                accs.append(acc)
            unique_answer_scores[unique_answer] = sum(accs) / len(accs)

        return unique_answer_scores

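    # Soft-score sketch (worked example, not from the original file): with 10
    # annotators, an answer given by k of them scores about min(k / 3, 1),
    # averaged over the 10 leave-one-out subsets above.  E.g. if 2 of 10
    # annotators said "cat", each "cat" annotator sees 1 other "cat"
    # (acc = 1/3) and each of the other 8 sees 2 (acc = 2/3), so
    # score("cat") = (2 * 1/3 + 8 * 2/3) / 10 = 0.6.
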
    def eval_pred_list(self, pred_list):
        pred_scores = []
        for entry in tqdm(pred_list):
            pred_answer = self.answer_processor(entry["pred_answer"])
            unique_answer_scores = self._compute_answer_scores(entry["gt_answers"])
            score = unique_answer_scores.get(pred_answer, 0.0)
            pred_scores.append(score)

        accuracy = sum(pred_scores) / len(pred_scores)
        return accuracy


class STVQAAccuracyEvaluator:
    def __init__(self):
        self.answer_processor = EvalAIAnswerProcessor()

    def eval_pred_list(self, pred_list):
        pred_scores = []
        for entry in pred_list:
            pred_answer = self.answer_processor(entry["pred_answer"])
            gts = [self.answer_processor(a) for a in entry["gt_answers"]]
            score = 1.0 if pred_answer in gts else 0.0
            pred_scores.append(score)

        accuracy = sum(pred_scores) / len(pred_scores)
        return accuracy


class STVQAANLSEvaluator:
    def __init__(self):
        import editdistance  # install with `pip install editdistance`

        self.get_edit_distance = editdistance.eval

    def get_anls(self, s1, s2):
        s1 = s1.lower().strip()
        s2 = s2.lower().strip()
        iou = 1 - self.get_edit_distance(s1, s2) / max(len(s1), len(s2))
        anls = iou if iou >= 0.5 else 0.0
        return anls

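    # ANLS sketch (worked example, not from the original file): the score is
    # normalized edit-distance similarity, zeroed below 0.5 so that near
    # misses still count while unrelated strings score 0.  E.g. for
    # get_anls("helo", "hello"): edit distance 1, max length 5, so
    # 1 - 1/5 = 0.8 >= 0.5 -> ANLS = 0.8.
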
    def eval_pred_list(self, pred_list):
        pred_scores = []
        for entry in pred_list:
            anls = max(
                self.get_anls(entry["pred_answer"], gt) for gt in entry["gt_answers"]
            )
            pred_scores.append(anls)

        accuracy = sum(pred_scores) / len(pred_scores)
        return accuracy


class TextCapsBleu4Evaluator:
    def __init__(self):
        # The following script requires Java 1.8.0 and pycocotools installed.
        # pycocoevalcap can be installed with pip:
        # pip install git+https://github.com/ronghanghu/coco-caption.git@python23
        # The original pycocoevalcap code is at
        # https://github.com/tylin/coco-caption but has no Python 3 support yet.
        try:
            from pycocoevalcap.bleu.bleu import Bleu
            from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer
        except ModuleNotFoundError:
            print(
                "Please install pycocoevalcap module using "
                "pip install git+https://github.com/ronghanghu/coco-caption.git@python23"  # noqa
            )
            raise

        self.tokenizer = PTBTokenizer()
        self.scorer = Bleu(4)

    def eval_pred_list(self, pred_list):
        # Create reference and hypothesis captions.
        gts = {}
        res = {}
        for idx, entry in enumerate(pred_list):
            gts[idx] = [{"caption": a} for a in entry["gt_answers"]]
            res[idx] = [{"caption": entry["pred_answer"]}]

        gts = self.tokenizer.tokenize(gts)
        res = self.tokenizer.tokenize(res)
        score, _ = self.scorer.compute_score(gts, res)

        bleu4 = score[3]  # score is (Bleu-1, Bleu-2, Bleu-3, Bleu-4)
        return bleu4
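

# A minimal driver sketch, assuming entries shaped the way the evaluators
# above index them ("pred_answer": str, "gt_answers": list of str, with
# exactly 10 ground-truth answers for TextVQA).  Illustrative only, not part
# of the original file.
if __name__ == "__main__":
    preds = [{"pred_answer": "2 dogs", "gt_answers": ["two dogs"] * 10}]
    print(TextVQAAccuracyEvaluator().eval_pred_list(preds))  # -> 1.0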