mirror of
https://github.com/Vision-CAIR/MiniGPT-4.git
synced 2025-04-04 01:50:47 +00:00
129 lines
5.3 KiB
Python
129 lines
5.3 KiB
Python
import os
|
|
import re
|
|
import json
|
|
import argparse
|
|
from collections import defaultdict
|
|
import random
|
|
import numpy as np
|
|
from PIL import Image
|
|
from tqdm import tqdm
|
|
import torch
|
|
from torch.utils.data import DataLoader
|
|
from minigpt4.common.config import Config
|
|
from minigpt4.common.eval_utils import prepare_texts, init_model, eval_parser, computeIoU
|
|
from minigpt4.conversation.conversation import CONV_VISION_minigptv2
|
|
|
|
from minigpt4.datasets.datasets.coco_caption import RefCOCOEvalData
|
|
|
|
def list_of_str(arg):
    """Split a comma-separated command-line value into a list of strings.

    Whitespace around each entry is stripped and empty entries (for example
    from a trailing comma) are dropped, so ``"refcoco, refcocog,"`` yields
    ``['refcoco', 'refcocog']``.

    Args:
        arg: The raw string received from the command line.

    Returns:
        The non-empty, stripped entries in their original order.
    """
    return [piece.strip() for piece in arg.split(',') if piece.strip()]
|
|
|
|
# Extend the shared evaluation parser with the options specific to this script.
parser = eval_parser()
# argparse runs `type` on string defaults too, so the default below reaches
# the rest of the script as a list, exactly like a value given on the CLI.
parser.add_argument(
    "--dataset",
    type=list_of_str,
    default='refcoco',
    help="comma-separated list of datasets to evaluate",
)
parser.add_argument("--res", type=float, default=100.0, help="resolution used in refcoco")
parser.add_argument(
    "--resample",
    action='store_true',
    help="query the model again for samples whose answer is not a well-formed bounding box",
)
args = parser.parse_args()
|
|
|
|
# Build the run configuration from the parsed command-line arguments.
cfg = Config(args)

# Splits evaluated for each supported dataset name.
eval_dict = {
    'refcoco': ['val', 'testA', 'testB'],
    'refcoco+': ['val', 'testA', 'testB'],
    'refcocog': ['val', 'test'],
}

# Load the model and its image processor; one eval() call is enough to put
# the model in inference mode.
model, vis_processor = init_model(args)
model.eval()

# Conversation template used to wrap every question, with the system prompt
# cleared.
CONV_VISION = CONV_VISION_minigptv2
conv_temp = CONV_VISION.copy()
conv_temp.system = ""

# Directory that receives one prediction file per evaluated split.
save_path = cfg.run_cfg.save_path
|
|
|
|
|
|
|
|
# A usable answer looks like "{<a><b><c><d>}" where each number has 1-3 digits.
BBOX_PATTERN = re.compile(r'\{<\d{1,3}><\d{1,3}><\d{1,3}><\d{1,3}>\}')
# Prompt text removed from a question to recover the bare referring expression.
REFER_PROMPT = '[refer] give me the location of'
# Number of retry rounds; the last round keeps whatever the model answers.
MAX_RESAMPLE_ROUNDS = 5


def _predict_boxes(samples, img_path, batch_size, max_new_tokens, predictions, accept_all=False):
    """Query the model for every sample and collect its cleaned answers.

    Uses the module-level ``model``, ``vis_processor`` and ``conv_temp``.

    Args:
        samples: Records handed to ``RefCOCOEvalData``.
        img_path: Image location handed to ``RefCOCOEvalData``.
        batch_size: Batch size of the data loader.
        max_new_tokens: Generation budget passed to ``model.generate``.
        predictions: Mapping from image id to a list of answers; accepted
            answers are appended in place.
        accept_all: When true, answers are stored even if they do not match
            ``BBOX_PATTERN``.

    Returns:
        A list of ``{'img_id': ..., 'sents': [...]}`` records for the samples
        whose answer was rejected, ready to be fed back into this function.
    """
    loader = DataLoader(
        RefCOCOEvalData(samples, vis_processor, img_path),
        batch_size=batch_size,
        shuffle=False,
    )
    rejected = []
    for images, questions, img_ids in tqdm(loader):
        texts = prepare_texts(questions, conv_temp)  # wrap the texts with the conversation template
        # NOTE(review): decoding is greedy (do_sample=False), so a retry may
        # well reproduce the same answer - confirm whether sampling was intended.
        answers = model.generate(images, texts, max_new_tokens=max_new_tokens, do_sample=False)
        for answer, img_id, question in zip(answers, img_ids, questions):
            answer = answer.replace("<unk>", "").replace(" ", "").strip()
            if accept_all or BBOX_PATTERN.match(answer):
                predictions[img_id].append(answer)
            else:
                rejected.append({'img_id': img_id, 'sents': [question.replace(REFER_PROMPT, '').strip()]})
    return rejected


def _is_hit(output, bbox, item, res):
    """Return whether the box encoded in *output* has IoU > 0.5 with *bbox*.

    Args:
        output: Model answer; every run of digits in it is read as one number.
        bbox: Ground-truth box; treated as (x, y, width, height) and converted
            to corner form before the comparison.
        item: Ground-truth record providing the image ``width`` and ``height``.
        res: Scale of the predicted numbers, which are divided by ``res`` and
            multiplied by the image size to obtain pixel coordinates.

    Raises:
        IndexError: If *output* contains fewer than four numbers.
    """
    pred_bbox = [int(num) for num in re.findall(r'\d+', output)]
    height = item['height']
    width = item['width']
    pred_bbox[0] = pred_bbox[0] / res * width
    pred_bbox[1] = pred_bbox[1] / res * height
    pred_bbox[2] = pred_bbox[2] / res * width
    pred_bbox[3] = pred_bbox[3] / res * height

    gt_bbox = [bbox[0], bbox[1], bbox[0] + bbox[2], bbox[1] + bbox[3]]
    return computeIoU(pred_bbox, gt_bbox) > 0.5


for dataset in args.dataset:
    for split in eval_dict[dataset]:
        dataset_cfg = cfg.evaluation_datasets_cfg[dataset]
        eval_file_path = dataset_cfg["eval_file_path"]
        img_path = dataset_cfg["img_path"]
        batch_size = dataset_cfg["batch_size"]
        max_new_tokens = dataset_cfg["max_new_tokens"]

        with open(os.path.join(eval_file_path, f"{dataset}/{dataset}_{split}.json"), 'r') as f:
            refcoco = json.load(f)

        # First pass over the whole split; malformed answers are queued for retry.
        minigpt4_predict = defaultdict(list)
        resamples = _predict_boxes(
            refcoco, img_path, batch_size, max_new_tokens, minigpt4_predict
        )

        if args.resample:
            for attempt in range(MAX_RESAMPLE_ROUNDS):
                if not resamples:
                    break
                resamples = _predict_boxes(
                    resamples,
                    img_path,
                    batch_size,
                    max_new_tokens,
                    minigpt4_predict,
                    accept_all=attempt == MAX_RESAMPLE_ROUNDS - 1,
                )

        # Name the file after the dataset being evaluated, not the whole
        # --dataset list, so that each dataset/split pair gets its own file.
        file_save_path = os.path.join(save_path, f"{dataset}_{split}.json")
        with open(file_save_path, 'w') as f:
            json.dump(minigpt4_predict, f)

        count = 0
        total = len(refcoco)
        res = args.res
        # Index the ground truth by image id (a later duplicate id wins).
        refcoco_dict = {item['img_id']: item for item in refcoco}
        for img_id, item in refcoco_dict.items():
            bbox = item['bbox']
            for output in minigpt4_predict.get(img_id, []):
                try:
                    hit = _is_hit(output, bbox, item, res)
                except Exception:
                    # Best effort: an answer that cannot be parsed or scored
                    # counts as a miss, but Ctrl-C is no longer swallowed.
                    continue
                if hit:
                    count += 1

        # An empty split reports 0 instead of dividing by zero.
        accuracy = count / total * 100 if total else 0.0
        print(f'{dataset} {split}:', accuracy, flush=True)
|