66 lines
2.2 KiB
Python
66 lines
2.2 KiB
Python
|
import os
|
||
|
import argparse
|
||
|
import json
|
||
|
import re
|
||
|
|
||
|
from llava.eval.m4c_evaluator import TextVQAAccuracyEvaluator
|
||
|
|
||
|
|
||
|
def get_args():
    """Parse the command-line options of the TextVQA evaluation script.

    Returns:
        An ``argparse.Namespace`` exposing ``annotation_file``, ``result_file``
        and ``result_dir``; each is a string, or ``None`` when not supplied.
    """
    parser = argparse.ArgumentParser()
    # All three options are optional free-form string paths.
    for flag in ('--annotation-file', '--result-file', '--result-dir'):
        parser.add_argument(flag, type=str)
    return parser.parse_args()
|
||
|
|
||
|
|
||
|
def prompt_processor(prompt):
    """Recover the lower-cased question text embedded in a TextVQA prompt.

    Three prompt layouts are recognised:

    * prompts starting with ``'OCR tokens: '`` -- the question is the text
      between ``'Question: '`` and ``' Short answer:'``;
    * three-line prompts containing ``'Reference OCR token: '`` -- the question
      is the second line when the prompt starts with the reference line,
      otherwise the first line;
    * two-line prompts -- the question is the first line.

    Args:
        prompt: The prompt string recorded alongside a model answer.

    Returns:
        The extracted question, lower-cased.

    Raises:
        ValueError: If the prompt matches none of the layouts above.
    """
    lines = prompt.split('\n')
    if prompt.startswith('OCR tokens: '):
        match = re.search(r"Question: (.*?) Short answer:", prompt, re.DOTALL)
        # Fail with a clear message instead of AttributeError on None.group().
        if match is None:
            raise ValueError(f'Unrecognized prompt format: {prompt!r}')
        question = match.group(1)
    elif 'Reference OCR token: ' in prompt and len(lines) == 3:
        question = lines[1] if prompt.startswith('Reference OCR token:') else lines[0]
    elif len(lines) == 2:
        question = lines[0]
    else:
        # An explicit raise (not `assert False`) survives `python -O`, which
        # would otherwise leave `question` unbound.
        raise ValueError(f'Unrecognized prompt format: {prompt!r}')
    return question.lower()
|
||
|
|
||
|
|
||
|
def eval_single(annotation_file, result_file):
    """Score one result file against the TextVQA annotations and print accuracy.

    Args:
        annotation_file: Path to a JSON file whose top-level ``'data'`` list
            holds dicts with ``'image_id'``, ``'question'`` and ``'answers'``.
        result_file: Path to a JSON-lines file; each non-blank line is a dict
            with ``'question_id'``, ``'prompt'`` and ``'text'`` (the prediction).

    Prints the experiment name (result file basename without extension),
    followed by the sample count and the accuracy as a percentage.

    Raises:
        KeyError: If a result's (question_id, question) pair has no annotation.
    """
    experiment_name = os.path.splitext(os.path.basename(result_file))[0]
    print(experiment_name)

    # Context managers close the handles promptly; JSON text is UTF-8.
    with open(annotation_file, encoding='utf-8') as f:
        annotations = json.load(f)['data']
    # Index by (image_id, lower-cased question). A result's 'question_id' is
    # matched against 'image_id', and its question is recovered from the prompt.
    annotations = {
        (annotation['image_id'], annotation['question'].lower()): annotation
        for annotation in annotations
    }

    with open(result_file, encoding='utf-8') as f:
        # Skip blank lines (e.g. a trailing newline) rather than crash json.loads.
        results = [json.loads(line) for line in f if line.strip()]

    pred_list = []
    for result in results:
        key = (result['question_id'], prompt_processor(result['prompt']))
        annotation = annotations[key]
        pred_list.append({
            "pred_answer": result['text'],
            "gt_answers": annotation['answers'],
        })

    evaluator = TextVQAAccuracyEvaluator()
    print('Samples: {}\nAccuracy: {:.2f}%\n'.format(len(pred_list), 100. * evaluator.eval_pred_list(pred_list)))
|
||
|
|
||
|
|
||
|
if __name__ == "__main__":
    args = get_args()

    # A single explicitly named result file is evaluated first.
    if args.result_file is not None:
        eval_single(args.annotation_file, args.result_file)

    # Then every .jsonl entry of the result directory, in sorted name order;
    # anything else is reported and skipped.
    if args.result_dir is not None:
        for entry in sorted(os.listdir(args.result_dir)):
            if entry.endswith('.jsonl'):
                eval_single(args.annotation_file, os.path.join(args.result_dir, entry))
            else:
                print(f'Skipping {entry}')
|