"""
|
|
Copyright (c) 2022, salesforce.com, inc.
|
|
All rights reserved.
|
|
SPDX-License-Identifier: BSD-3-Clause
|
|
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
|
"""
|
|
|
|
import logging
import json
import os

import torch.distributed as dist
from collections import defaultdict

from minigpt4.common.registry import registry
from minigpt4.tasks.base_task import BaseTask
import minigpt4.common.dist_utils as dist_utils
from minigpt4.common.logger import MetricLogger
from minigpt4.datasets.data_utils import prepare_sample
from minigpt4.common.dist_utils import is_dist_avail_and_initialized
from minigpt4.common.vqa_tools.vqa import VQA
from minigpt4.common.vqa_tools.vqa_eval import VQAEval


@registry.register_task("instruction_tuning")
class InstructionTask(BaseTask):
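    """
    Instruction-tuning task: generate answers with model.predict_answers over
    mixed instruction datasets (VQAv2, OK-VQA, GQA, A-OKVQA, COCO captioning)
    and report the matching metric for each source.
    """
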
    def __init__(
        self,
        num_beams,
        max_len,
        min_len,
        evaluate,
        num_ans_candidates,
        inference_method="rank",
        prompt="",
    ):
        super().__init__()

        self.num_beams = num_beams
        self.max_len = max_len
        self.min_len = min_len

        self.evaluate = evaluate
        self.inference_method = inference_method
        self.num_ans_candidates = num_ans_candidates
        self.prompt = prompt

        self.answer_list = None

        self.ques_files = defaultdict(dict)
        self.anno_files = defaultdict(dict)

    @classmethod
    def setup_task(cls, cfg):
        run_cfg = cfg.run_cfg

        num_beams = run_cfg.get("num_beams", 3)
        max_len = run_cfg.get("max_len", 30)
        min_len = run_cfg.get("min_len", 1)

        evaluate = run_cfg.get("evaluate", False)

        inference_method = run_cfg.get("inference_method", "rank")
        num_ans_candidates = run_cfg.get("num_ans_candidates", 128)
        prompt = run_cfg.get("prompt", "")

        return cls(
            num_beams=num_beams,
            max_len=max_len,
            min_len=min_len,
            evaluate=evaluate,
            num_ans_candidates=num_ans_candidates,
            inference_method=inference_method,
            prompt=prompt,
        )
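
    # Illustrative run config consumed by setup_task above; the key names match
    # the run_cfg lookups, the values are only example defaults:
    #
    #   run:
    #     task: instruction_tuning
    #     evaluate: True
    #     num_beams: 3
    #     max_len: 30
    #     min_len: 1
    #     inference_method: "rank"
    #     num_ans_candidates: 128
    #     prompt: ""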

    def build_datasets(self, cfg):
        datasets = super().build_datasets(cfg)

        # get question file, annotation file and answer list in COCO format
        for dataset in datasets.values():
            for split in dataset:
                source = dataset[split].source
                if (
                    hasattr(dataset[split], "coco_fmt_qust_file")
                    and dataset[split].coco_fmt_qust_file is not None
                ):
                    self.ques_files[split][source] = dataset[split].coco_fmt_qust_file
                    self.anno_files[split][source] = dataset[split].coco_fmt_anno_file

                # try:
                #     self.answer_list = dataset[split].answer_list
                # except AttributeError:
                #     # if answer_list is not provided, then set it to None
                #     pass

        if len(self.ques_files) > 0:
            assert len(self.ques_files) == len(
                self.anno_files
            ), "Only support one split for evaluation."

        return datasets

    def valid_step(self, model, samples):
        answers = model.predict_answers(
            samples=samples,
            answer_list=self.answer_list,
            inference_method=self.inference_method,
            num_beams=self.num_beams,
            max_len=self.max_len,
            min_len=self.min_len,
            num_ans_candidates=self.num_ans_candidates,
            prompt=self.prompt,
        )
        pred_qa_pairs = []

        text_inputs = samples["text_input"]

        sources = samples["source"]
        source = samples["source"][0]

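        # Batches are assumed to come from a single dataset source, so the first
        # entry of samples["source"] picks the id handling: VQAv2/OK-VQA/GQA ids
        # arrive as tensors, A-OKVQA ids are plain strings, and COCO captioning
        # keys on image_id. Any other source would leave sample_ids undefined.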
        if source in ['vqav2', 'okvqa', 'gqa']:
            sample_ids = [int(sample_id.item()) for sample_id in samples["question_id"]]
        elif source in ['aokvqa']:
            sample_ids = [sample_id for sample_id in samples["question_id"]]
        elif source in ['coco_cap']:
            sample_ids = samples["image_id"]

        # For GQA
        full_answers = samples.get("fullAnswer", ["" for i in range(len(sample_ids))])
        gt_answers = samples.get("gt_answers", ["" for i in range(len(sample_ids))])

        # For AOKVQA
        choices = samples.get("choices", ["" for i in range(len(sample_ids))])

        for answer, sample_id, text_input, full_answer, gt_answer, choice, source in zip(
            answers, sample_ids, text_inputs, full_answers, gt_answers, choices, sources
        ):
            pred_qa_pairs.append({
                "question_id": sample_id,
                "question": text_input,
                "full_answer": full_answer,
                "answer": answer,
                "gt_ans": gt_answer,
                "choice": choice,
                "source": source,
            })

        return pred_qa_pairs
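
    # A single valid_step entry looks roughly like this (values are illustrative):
    #   {"question_id": 42, "question": "What is the cat doing?",
    #    "full_answer": "", "answer": "sleeping", "gt_ans": "sleeping",
    #    "choice": "", "source": "vqav2"}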

    def evaluation(self, model, data_loader, cuda_enabled=True):
        metric_logger = MetricLogger(delimiter=" ")
        header = "Evaluation"
        # TODO make it configurable
        print_freq = 10

        total_results = list()
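        # The loader is assumed to be a multi-dataset wrapper exposing one
        # sub-loader per source via .loaders; results are kept in per-source
        # lists so after_evaluation can score each source separately.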
        for sub_data_loader in data_loader.loaders:
            results = []
            for samples in metric_logger.log_every(sub_data_loader, print_freq, header):
                samples = prepare_sample(samples, cuda_enabled=cuda_enabled)
                eval_output = self.valid_step(model=model, samples=samples)
                results.extend(eval_output)
            total_results.append(results)

        if is_dist_avail_and_initialized():
            dist.barrier()

        return total_results

    def after_evaluation(self, val_result, split_name, **kwargs):
        final_metrics = dict()
        for i in range(len(val_result)):
            source = val_result[i][0]["source"]
            result_file = self.save_result(
                val_result[i],
                result_dir=registry.get_path("result_dir"),
                filename=f"{split_name}_vqa_result_{source}",
                remove_duplicate="question_id",
            )

            if source in ['vqav2', 'okvqa']:
                try:
                    metrics = self._report_metrics_coco_vqa(result_file=result_file, split=split_name, source=source)
                except Exception as e:
                    metrics = None
                    print(f"Report Metrics {source} Error: {e}")
            elif source in ['gqa']:
                try:
                    metrics = self._report_metrics_gqa(result_file=result_file, source=source)
                except Exception as e:
                    metrics = None
                    print(f"Report Metrics {source} Error: {e}")
            elif source in ['aokvqa']:
                try:
                    metrics = self._report_metrics_aokvqa(result_file=result_file, source=source)
                except Exception as e:
                    metrics = None
                    print(f"Report Metrics {source} Error: {e}")
            elif source in ['coco_cap']:
                try:
                    metrics = self._report_metrics_caption(result_file=result_file, split_name=split_name, source=source)
                except Exception as e:
                    metrics = None
                    print(f"Report Metrics {source} Error: {e}")
            else:
                metrics = None
            final_metrics[source] = metrics
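
        # The overall agg_metrics is the unweighted mean of the per-source
        # agg_metrics; if any source above yielded None (or a dict without
        # "agg_metrics"), the lookup below raises and final_metrics becomes None.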
        try:
            agg_metrics_lst = [v["agg_metrics"] for k, v in final_metrics.items()]
            final_metrics["agg_metrics"] = sum(agg_metrics_lst) / len(agg_metrics_lst)
        except Exception as e:
            print("Calculate agg metrics error... ", e)
            final_metrics = None

        return final_metrics

    @dist_utils.main_process
    def _report_metrics_coco_vqa(self, result_file, split, source='vqav2'):
        """
        Use official VQA evaluation script to report metrics.
        """
        metrics = {}

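        # If no COCO-format question/annotation files were registered for this
        # split in build_datasets, the block below is skipped and an empty dict
        # is returned.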
        if split in self.ques_files and split in self.anno_files:
            vqa = VQA(self.anno_files[split][source], self.ques_files[split][source])
            vqa_result = vqa.loadRes(
                resFile=result_file, quesFile=self.ques_files[split][source]
            )

            # create vqaEval object by taking vqa and vqaRes
            # n is precision of accuracy (number of places after decimal), default is 2
            vqa_scorer = VQAEval(vqa, vqa_result, n=2)
            logging.info("Start VQA evaluation.")
            vqa_scorer.evaluate()

            # print accuracies
            overall_acc = vqa_scorer.accuracy["overall"]
            metrics["agg_metrics"] = overall_acc

            logging.info("Overall Accuracy is: %.02f\n" % overall_acc)
            logging.info("Per Answer Type Accuracy is the following:")

            for ans_type in vqa_scorer.accuracy["perAnswerType"]:
                logging.info(
                    "%s : %.02f"
                    % (ans_type, vqa_scorer.accuracy["perAnswerType"][ans_type])
                )
                metrics[ans_type] = vqa_scorer.accuracy["perAnswerType"][ans_type]

            with open(
                os.path.join(registry.get_path("output_dir"), f"evaluate_{source}.txt"), "a"
            ) as f:
                f.write(json.dumps(metrics) + "\n")

        return metrics

    @dist_utils.main_process
    def _report_metrics_aokvqa(self, result_file, source='aokvqa'):
        """
        Validation of A-OKVQA.
        """
        # measuring accuracy compared to answer
        results = json.load(open(result_file, "r"))
        acc = []
        vqa_tool = VQAEval()

        for res in results:
            gt_ans = res["choice"]
            pred = res["answer"]

            pred = vqa_tool.processPunctuation(pred)
            pred = vqa_tool.processDigitArticle(pred)

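            # Scoring is substring membership rather than exact match: the
            # normalized prediction counts as correct if it appears anywhere in
            # the stored choice string.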
            # vqa_acc = 1 if pred == gt_ans else 0
            vqa_acc = 1 if pred in gt_ans else 0

            acc.append(vqa_acc)

        accuracy = sum(acc) / len(acc) * 100
        metrics = {"agg_metrics": accuracy, "acc": accuracy}

        with open(
            os.path.join(registry.get_path("output_dir"), f"evaluate_{source}.txt"), "a"
        ) as f:
            f.write(json.dumps(metrics) + "\n")

        logging.info(metrics)

        return metrics

    @dist_utils.main_process
    def _report_metrics_gqa(self, result_file, source='gqa'):
        """
        Validation of GQA
        """
        # measuring accuracy compared to answer
        results = json.load(open(result_file, "r"))
        acc = []
        vqa_tool = VQAEval()

        for res in results:
            # if res["gt_ans"] is None:
            #     # prepare test results for leaderboard evaluation
            #     self._save_result_leaderboard(results)
            #     return

            gt_ans = res["gt_ans"]
            pred = res["answer"]

            # if self.inference_method == "generate":
            pred = vqa_tool.processPunctuation(pred)
            pred = vqa_tool.processDigitArticle(pred)

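            # Containment is reversed relative to A-OKVQA above: the answer
            # counts as correct if the ground-truth string appears anywhere in
            # the normalized prediction.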
            # vqa_acc = 1 if pred == gt_ans else 0
            vqa_acc = 1 if gt_ans in pred else 0

            acc.append(vqa_acc)

        accuracy = sum(acc) / len(acc) * 100
        metrics = {"agg_metrics": accuracy, "acc": accuracy}

        with open(
            os.path.join(registry.get_path("output_dir"), f"evaluate_{source}.txt"), "a"
        ) as f:
            f.write(json.dumps(metrics) + "\n")

        logging.info(metrics)

        return metrics

    @dist_utils.main_process
    def _report_metrics_caption(self, result_file, split_name, source='coco_cap'):
        """
        Use the official COCO caption evaluation script to report metrics.
        """
        coco_gt_root = os.path.join(registry.get_path("cache_root"), "coco_gt")
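        # Note: coco_caption_eval below currently ignores coco_gt_root and loads
        # its ground-truth files from the hard-coded paths in the helper at the
        # bottom of this file.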
        coco_val = coco_caption_eval(coco_gt_root, result_file, split_name)

        agg_metrics = coco_val.eval["CIDEr"] + coco_val.eval["Bleu_4"]
        log_stats = {split_name: {k: v for k, v in coco_val.eval.items()}}

        with open(
            os.path.join(registry.get_path("output_dir"), "evaluate.txt"), "a"
        ) as f:
            f.write(json.dumps(log_stats) + "\n")

        coco_res = {k: v for k, v in coco_val.eval.items()}
        coco_res["agg_metrics"] = agg_metrics

        return coco_res


from pycocoevalcap.eval import COCOEvalCap
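

# COCO_Annotation and COCO_Result are minimal stand-ins for the pycocotools
# COCO object: COCOEvalCap only reads imgToAnns and getImgIds() from the
# ground truth and imgToAnns from the results, so they duck-type just that.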
class COCO_Annotation:
    def __init__(self, annotation_file):
        self.coco_cn_file = annotation_file
        self.imgToAnns = self.build_imgToAnns()

    def build_imgToAnns(self):
        imgToAnns = defaultdict(list)
        with open(self.coco_cn_file, "r", encoding="UTF-8") as fin:
            for line in fin:
                line = line.strip()
                temp = eval(line)
                annotations = temp['annotations']
                for ann in annotations:
                    image_id = str(ann['image_id']).zfill(6)
                    imgToAnns[image_id].append({'image_id': image_id, 'caption': ann['caption'], 'image': ann['image_id']})
        return imgToAnns

    def getImgIds(self):
        return self.imgToAnns.keys()


class COCO_Result:
    def __init__(self, result_file):
        self.coco_cn_file = result_file
        self.imgToAnns = self.build_imgToAnns()

    def build_imgToAnns(self):
        imgToAnns = dict()
        data = json.load(open(self.coco_cn_file, "r"))
        for d in data:
            tmp = {
                'image_id': d['question_id'][-6:],
                'caption': d['answer']
            }
            imgToAnns[d['question_id'][-6:]] = [tmp]
        return imgToAnns


def coco_caption_eval(coco_gt_root, results_file, split_name):
    files = {
        "val": "/mnt/pfs-guan-ssai/nlu/wanghanzi/data/COCO_Cap/coco_karpathy_val_gt.json",
        "test": "/mnt/pfs-guan-ssai/nlu/wanghanzi/data/COCO_Cap/coco_karpathy_test_gt.json",
    }

    # create coco object and coco_result object
    annotation_file = files[split_name]
    coco = COCO_Annotation(annotation_file)
    coco_result = COCO_Result(results_file)

    # create coco_eval object by taking coco and coco_result
    coco_eval = COCOEvalCap(coco, coco_result)

    # evaluate on a subset of images by setting
    # coco_eval.params['image_id'] = coco_result.getImgIds()
    # please remove this line when evaluating the full validation set

    # evaluate results
    # SPICE will take a few minutes the first time, but speeds up due to caching
    coco_eval.evaluate()

    # print output evaluation scores
    for metric, score in coco_eval.eval.items():
        print(f"{metric}: {score:.3f}")

    return coco_eval
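

# Illustrative use of the helper above (the result filename is hypothetical and
# coco_gt_root is unused, so any string works):
#
#   scorer = coco_caption_eval("", "val_vqa_result_coco_cap.json", "val")
#   print(scorer.eval["CIDEr"], scorer.eval["Bleu_4"])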