mirror of
https://github.com/Vision-CAIR/MiniGPT-4.git
synced 2025-04-09 04:20:46 +00:00
Add files via upload
This commit is contained in:
parent
999bfa455d
commit
31d38c3e5c
@ -1 +1,16 @@
|
||||
#!/bin/bash --login
|
||||
cfg_path=eval_configs/minigpt4_llama2_eval.yaml
|
||||
CKPT=YOUR_CKPT_PATH
|
||||
NAME=EXP_NAME
|
||||
IMG_PATH=YOUR_IMG_PATH
|
||||
EVAL_FILE_PATH=YOUR_EVAL_FILE_PATH
|
||||
|
||||
torchrun --nproc_per_node 1 eval_ref.py --name ${NAME} \
|
||||
--cfg-path ${cfg_path} \
|
||||
--ckpt ${CKPT} --dataset refcoco,refcoco+,refcocog --lora_r 64 --lora_alpha 16 \
|
||||
--batch_size 64 --max_new_tokens 20 --resample --img_path ${IMG_PATH} --eval_file_path ${EVAL_FILE_PATH}
|
||||
|
||||
torchrun --nproc_per_node 1 eval_vqa.py --name ${NAME} \
|
||||
--cfg-path ${cfg_path} \
|
||||
--ckpt ${CKPT} --split val,test --dataset okvqa,vizwiz,aokvqa,iconqa,gqa,vsr,hm --lora_r 64 --lora_alpha 16 \
|
||||
--batch_size 32 --max_new_tokens 20 --resample
|
123
eval_scripts/eval_ref.py
Normal file
123
eval_scripts/eval_ref.py
Normal file
@ -0,0 +1,123 @@
|
||||
import os
|
||||
import re
|
||||
import json
|
||||
import argparse
|
||||
from collections import defaultdict
|
||||
import random
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from minigpt4.common.eval_utils import prepare_texts, init_model, eval_parser, computeIoU
|
||||
from minigpt4.conversation.conversation import CONV_VISION
|
||||
|
||||
from minigpt4.datasets.datasets.coco_caption import RefCOCOEvalData
|
||||
|
||||
def list_of_str(arg):
|
||||
return list(map(str, arg.split(',')))
|
||||
|
||||
parser = eval_parser()
|
||||
parser.add_argument("--dataset", type=list_of_str, default='refcoco', help="dataset to evaluate")
|
||||
parser.add_argument("--split", type=list_of_str, default='test', help="dataset to evaluate")
|
||||
parser.add_argument("--res", type=float, default=100.0, help="resolution used in refcoco")
|
||||
parser.add_argument("--resample", action='store_true', help="resolution used in refcoco")
|
||||
parser.add_argument("--img_path", type=str)
|
||||
parser.add_argument("--eval_file_path", type=str)
|
||||
args = parser.parse_args()
|
||||
|
||||
print(args.ckpt)
|
||||
print(args.name)
|
||||
|
||||
eval_dict = {'refcoco': args.split,
|
||||
'refcoco+': args.split,
|
||||
'refcocog': args.split}
|
||||
|
||||
model, vis_processor = init_model(args)
|
||||
model.eval()
|
||||
conv_temp = CONV_VISION.copy()
|
||||
conv_temp.system = ""
|
||||
|
||||
model.eval()
|
||||
img_path=f'{args.img_path}/COCO/cocoapi/data/2017/images/jpeg/train'
|
||||
|
||||
for dataset in args.dataset:
|
||||
for split in eval_dict[dataset]:
|
||||
with open(f'{args.eval_file_path}/{dataset}/{dataset}_{split}.json', 'r') as f:
|
||||
refcoco = json.load(f)
|
||||
|
||||
data = RefCOCOEvalData(refcoco, vis_processor, img_path)
|
||||
eval_dataloader = DataLoader(data, batch_size=args.batch_size, shuffle=False)
|
||||
|
||||
minigpt4_predict = defaultdict(list)
|
||||
|
||||
resamples = []
|
||||
|
||||
for images, questions, img_ids in tqdm(eval_dataloader):
|
||||
texts = prepare_texts(questions, conv_temp) # warp the texts with conversation template
|
||||
answers = model.generate(images, texts, max_new_tokens=args.max_new_tokens, do_sample=False)
|
||||
for answer, img_id, question in zip(answers, img_ids, questions):
|
||||
answer = answer.replace("<unk>","").replace(" ","").strip()
|
||||
pattern = r'\{<\d{1,3}><\d{1,3}><\d{1,3}><\d{1,3}>\}'
|
||||
if re.match(pattern, answer):
|
||||
minigpt4_predict[img_id].append(answer)
|
||||
else:
|
||||
resamples.append({'img_id': img_id, 'sents': [question.replace('[refer] where is','').replace('?','').strip()]})
|
||||
|
||||
if args.resample:
|
||||
for i in range(20):
|
||||
data = RefCOCOEvalData(resamples, vis_processor, img_path)
|
||||
resamples = []
|
||||
eval_dataloader = DataLoader(data, batch_size=args.batch_size, shuffle=False)
|
||||
for images, questions, img_ids in tqdm(eval_dataloader):
|
||||
texts = prepare_texts(questions, conv_temp) # warp the texts with conversation template
|
||||
answers = model.generate(images, texts, max_new_tokens=args.max_new_tokens, do_sample=False)
|
||||
for answer, img_id, question in zip(answers, img_ids, questions):
|
||||
answer = answer.replace("<unk>","").replace(" ","").strip()
|
||||
pattern = r'\{<\d{1,3}><\d{1,3}><\d{1,3}><\d{1,3}>\}'
|
||||
if re.match(pattern, answer) or i == 4:
|
||||
minigpt4_predict[img_id].append(answer)
|
||||
else:
|
||||
resamples.append({'img_id': img_id, 'sents': [question.replace('[refer] where is','').replace('?','').strip()]})
|
||||
|
||||
if len(resamples) == 0:
|
||||
break
|
||||
|
||||
with open(f'results/{args.name}_{dataset}_{split}.json','w') as f:
|
||||
json.dump(minigpt4_predict, f)
|
||||
|
||||
count=0
|
||||
total=len(refcoco)
|
||||
res=args.res
|
||||
refcoco_dict = defaultdict()
|
||||
for item in refcoco:
|
||||
refcoco_dict[item['img_id']] = item
|
||||
for img_id in refcoco_dict:
|
||||
item = refcoco_dict[img_id]
|
||||
bbox = item['bbox']
|
||||
outputs = minigpt4_predict[img_id]
|
||||
for output in outputs:
|
||||
try:
|
||||
integers = re.findall(r'\d+', output)
|
||||
pred_bbox = [int(num) for num in integers]
|
||||
height = item['height']
|
||||
width = item['width']
|
||||
pred_bbox[0] = pred_bbox[0] / res * width
|
||||
pred_bbox[1] = pred_bbox[1] / res * height
|
||||
pred_bbox[2] = pred_bbox[2] / res * width
|
||||
pred_bbox[3] = pred_bbox[3] / res * height
|
||||
|
||||
gt_bbox = [0,0,0,0]
|
||||
gt_bbox[0] = bbox[0]
|
||||
gt_bbox[1] = bbox[1]
|
||||
gt_bbox[2] = bbox[0] + bbox[2]
|
||||
gt_bbox[3] = bbox[1] + bbox[3]
|
||||
|
||||
iou_score = computeIoU(pred_bbox, gt_bbox)
|
||||
if iou_score > 0.5:
|
||||
count+=1
|
||||
except:
|
||||
continue
|
||||
|
||||
print(f'{dataset} {split}:', count / total * 100, flush=True)
|
286
eval_scripts/eval_vqa.py
Normal file
286
eval_scripts/eval_vqa.py
Normal file
@ -0,0 +1,286 @@
|
||||
import os
|
||||
import re
|
||||
import json
|
||||
import argparse
|
||||
from collections import defaultdict
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
from datasets import load_dataset
|
||||
|
||||
from minigpt4.datasets.datasets.vqa_datasets import OKVQAEvalData,VizWizEvalData,AOKVQADAEvalData,AOKVQAMCEvalData,IconQAEvalData,GQAEvalData,VSREvalData,HMEvalData
|
||||
|
||||
from minigpt4.common.vqa_tools.VQA.PythonHelperTools.vqaTools.vqa import VQA
|
||||
from minigpt4.common.vqa_tools.VQA.PythonEvaluationTools.vqaEvaluation.vqaEval import VQAEval
|
||||
|
||||
from minigpt4.common.eval_utils import prepare_texts, init_model, eval_parser
|
||||
from minigpt4.conversation.conversation import CONV_VISION
|
||||
import random
|
||||
|
||||
|
||||
def list_of_str(arg):
|
||||
return list(map(str, arg.split(',')))
|
||||
|
||||
|
||||
parser = eval_parser()
|
||||
parser.add_argument("--dataset", type=list_of_str, default='refcoco', help="dataset to evaluate")
|
||||
parser.add_argument("--split", type=list_of_str, default='testB', help="dataset split to evaluate")
|
||||
parser.add_argument("--resample", action='store_true', help="resolution used in refcoco")
|
||||
parser.add_argument("--img_path", type=str)
|
||||
parser.add_argument("--eval_file_path", type=str)
|
||||
args = parser.parse_args()
|
||||
|
||||
print(args.ckpt)
|
||||
print(args.name)
|
||||
|
||||
model, vis_processor = init_model(args)
|
||||
conv_temp = CONV_VISION.copy()
|
||||
conv_temp.system = ""
|
||||
|
||||
model.eval()
|
||||
|
||||
os.makedirs('results', exist_ok=True)
|
||||
|
||||
if 'okvqa' in args.dataset:
|
||||
img_path=f'{args.img_path}/COCO/cocoapi/data/2017/images/jpeg/train'
|
||||
with open(f'{args.eval_file_path}/okvqa/test_split.json', 'r') as f:
|
||||
ok_vqa_test_split = json.load(f)
|
||||
|
||||
data = OKVQAEvalData(ok_vqa_test_split, vis_processor, img_path)
|
||||
eval_dataloader = DataLoader(data, batch_size=args.batch_size, shuffle=False)
|
||||
minigpt4_predict = []
|
||||
|
||||
resamples = []
|
||||
|
||||
for images, questions, question_ids, img_ids in eval_dataloader:
|
||||
texts = prepare_texts(questions, conv_temp) # warp the texts with conversation template
|
||||
answers = model.generate(images, texts, max_new_tokens=args.max_new_tokens, do_sample=False)
|
||||
|
||||
for answer, question_id, question, img_id in zip(answers, question_ids, questions, img_ids):
|
||||
result = dict()
|
||||
answer = answer.lower().replace('<unk>','').strip()
|
||||
result['answer'] = answer
|
||||
result['question_id'] = int(question_id)
|
||||
if answer == "":
|
||||
resamples.append({'image_id': img_id, 'question_id':question_id, 'question': [question.replace('[vqa] Based on the image, respond to this question with a short answer:','').strip()]})
|
||||
else:
|
||||
minigpt4_predict.append(result)
|
||||
|
||||
if args.resample:
|
||||
for i in range(20):
|
||||
data = OKVQAEvalData(resamples, vis_processor, img_path)
|
||||
resamples = []
|
||||
eval_dataloader = DataLoader(data, batch_size=args.batch_size, shuffle=False)
|
||||
for images, questions, question_ids, img_ids in eval_dataloader:
|
||||
texts = prepare_texts(questions, conv_temp) # warp the texts with conversation template
|
||||
answers = model.generate(images, texts, max_new_tokens=args.max_new_tokens, do_sample=False)
|
||||
for answer, question_id, question in zip(answers, question_ids, questions):
|
||||
result = dict()
|
||||
answer = answer.lower().replace('<unk>','').strip()
|
||||
result['answer'] = answer
|
||||
result['question_id'] = int(question_id)
|
||||
minigpt4_predict.append(result)
|
||||
if answer == "":
|
||||
resamples.append({'image_id': img_id, 'question_id':question_id, 'question': [question.replace('[vqa] Based on the image, respond to this question with a short answer:','').strip()]})
|
||||
else:
|
||||
minigpt4_predict.append(result)
|
||||
if len(resamples) == 0:
|
||||
break
|
||||
|
||||
save_path=f'results_correct/{args.name}_okvqa.json'
|
||||
with open(save_path,'w') as f:
|
||||
json.dump(minigpt4_predict, f)
|
||||
|
||||
annFile =f'{args.eval_file_path}/ok_vqa/mscoco_val2014_annotations_clean.json'
|
||||
quesFile =f'{args.eval_file_path}/ok_vqa/OpenEnded_mscoco_val2014_questions_clean.json'
|
||||
|
||||
vqa = VQA(annFile, quesFile)
|
||||
vqaRes = vqa.loadRes(save_path, quesFile)
|
||||
|
||||
vqaEval = VQAEval(vqa, vqaRes, n=2)
|
||||
|
||||
vqaEval.evaluate()
|
||||
|
||||
print ("Overall OKVQA Accuracy is: %.02f\n" %(vqaEval.accuracy['overall']), flush=True)
|
||||
|
||||
if 'vizwiz' in args.dataset:
|
||||
img_path=f'{args.img_path}/vizwiz/val'
|
||||
vizwiz = json.load(open(f'{args.eval_file_path}/vizwiz/val.json', 'r'))
|
||||
|
||||
data = VizWizEvalData(vizwiz, vis_processor, img_path)
|
||||
eval_dataloader = DataLoader(data, batch_size=args.batch_size, shuffle=False)
|
||||
minigpt4_predict = []
|
||||
total_acc = []
|
||||
for images, texts, gt_answers in tqdm(eval_dataloader):
|
||||
texts = prepare_texts(texts, conv_temp) # warp the texts with conversation template
|
||||
with torch.no_grad():
|
||||
answers = model.generate(images, texts, max_new_tokens=args.max_new_tokens, do_sample=False)
|
||||
|
||||
for answer, gt_answer in zip(answers, gt_answers):
|
||||
result = dict()
|
||||
result['answer'] = answer.replace('<unk>','').strip()
|
||||
minigpt4_predict.append(result)
|
||||
count=0
|
||||
gt_answer = gt_answer.split('_')
|
||||
for gt in gt_answer:
|
||||
if gt.lower() == answer.lower():
|
||||
count += 1
|
||||
acc = min(count/3.0, 1.0)
|
||||
total_acc.append(acc)
|
||||
|
||||
save_path=f'results/{args.name}_vizwiz.json'
|
||||
with open(save_path,'w') as f:
|
||||
json.dump(minigpt4_predict, f)
|
||||
|
||||
print('vizwiz Acc: ', np.average(total_acc)* 100.0, flush=True)
|
||||
|
||||
if 'aokvqa' in args.dataset:
|
||||
img_path=f'{args.img_path}/aokvqa/images'
|
||||
|
||||
for split in args.split:
|
||||
with open(f'{args.eval_file_path}/aokvqa/annotations/aokvqa_v1p0_{split}.json','r') as f:
|
||||
aokvqa_v1p0 = json.load(f)
|
||||
|
||||
data = AOKVQADAEvalData(aokvqa_v1p0, vis_processor, img_path)
|
||||
eval_dataloader = DataLoader(data, batch_size=args.batch_size, shuffle=False)
|
||||
|
||||
minigpt4_predict = defaultdict(dict)
|
||||
|
||||
for images, texts, question_ids in tqdm(eval_dataloader):
|
||||
texts = prepare_texts(texts, conv_temp) # warp the texts with conversation template
|
||||
answers = model.generate(images, texts, max_new_tokens=args.max_new_tokens, do_sample=False)
|
||||
|
||||
for answer, question_id in zip(answers, question_ids):
|
||||
minigpt4_predict[question_id]['direct_answer'] = answer.lower().replace('<unk>','').strip()
|
||||
|
||||
data = AOKVQAMCEvalData(aokvqa_v1p0, vis_processor, img_path)
|
||||
eval_dataloader = DataLoader(data, batch_size=args.batch_size, shuffle=False)
|
||||
|
||||
for images, texts, question_ids, answers in tqdm(eval_dataloader):
|
||||
instructions = ["[INST] <Img><ImageHere></Img> {} [/INST]".format(text) for text in texts]
|
||||
answer_ranks = model.multi_select(images, instructions, answers)
|
||||
candidates = [list(x) for x in zip(*answers)]
|
||||
for idx, question_id in enumerate(question_ids):
|
||||
minigpt4_predict[question_id]['multiple_choice'] = candidates[idx][answer_ranks[idx][0]]
|
||||
|
||||
save_path=f'results/{args.name}_a_okvqa_{split}.json'
|
||||
with open(save_path,'w') as f:
|
||||
json.dump(minigpt4_predict, f)
|
||||
|
||||
os.chdir('minigpt4/common/vqa_tools/aokvqa')
|
||||
print(os.system(f'python evaluation/eval_predictions.py --aokvqa-dir {args.eval_file_path}/aokvqa/annotations --split {split} --preds ../../../../{save_path}'), flush=True)
|
||||
os.chdir('../../../../')
|
||||
|
||||
if 'iconqa' in args.dataset:
|
||||
iconqa_text_val = json.load(open(f'{eval_file_path}/iconqa/choose_text_val.json','r'))
|
||||
img_path = f'{args.img_path}/iconqa/val/choose_txt'
|
||||
data = IconQAEvalData(iconqa_text_val, vis_processor, img_path)
|
||||
eval_dataloader = DataLoader(data, batch_size=args.batch_size, shuffle=False)
|
||||
|
||||
count = 0
|
||||
for images, texts, candidates, answers in tqdm(eval_dataloader):
|
||||
candidates = [candidate.split('_') for candidate in candidates]
|
||||
num_cand = [len(candidate) for candidate in candidates]
|
||||
for candidate in candidates:
|
||||
candidate.extend(['none'] * (max(num_cand) - len(candidate)))
|
||||
candidates = [list(x) for x in zip(*candidates)]
|
||||
instructions = ["[INST] <Img><ImageHere></Img> {} [/INST]".format(text) for text in texts]
|
||||
answer_ranks = model.multi_select(images, instructions, candidates, num_cand=num_cand)
|
||||
for idx, answer in enumerate(answers):
|
||||
if answer_ranks[idx][0] == answer:
|
||||
count += 1
|
||||
|
||||
print('iconqa Acc: ', count / len(iconqa_text_val) * 100.0, flush=True)
|
||||
|
||||
|
||||
if 'gqa' in args.dataset:
|
||||
img_path = f'{args.img_path}/gqa/images/val'
|
||||
gqa = json.load(open(f'{args.eval_file_path}/gqa/annotations/testdev_balanced_questions.json', 'r'))
|
||||
data = GQAEvalData(gqa, vis_processor, img_path)
|
||||
eval_dataloader = DataLoader(data, batch_size=args.batch_size, shuffle=False)
|
||||
count=0
|
||||
total=0
|
||||
minigpt4_predict = []
|
||||
for images, texts, labels in tqdm(eval_dataloader):
|
||||
texts = prepare_texts(texts, conv_temp) # warp the texts with conversation template
|
||||
answers = model.generate(images, texts, max_new_tokens=args.max_new_tokens, do_sample=False)
|
||||
|
||||
for answer, label in zip(answers, labels):
|
||||
result = dict()
|
||||
result['pred'] = answer.lower().replace('<unk>','').strip()
|
||||
result['gt'] = label
|
||||
minigpt4_predict.append(result)
|
||||
if answer.lower() == label:
|
||||
count+=1
|
||||
total+=1
|
||||
print('gqa val:', count / total * 100, flush=True)
|
||||
|
||||
save_path=f'results/{args.name}_gqa.json'
|
||||
with open(save_path,'w') as f:
|
||||
json.dump(minigpt4_predict, f)
|
||||
|
||||
if 'vsr' in args.dataset:
|
||||
annotation = load_dataset("cambridgeltl/vsr_zeroshot", split='test')
|
||||
img_path = f'{args.img_path}/vsr/images'
|
||||
data = VSREvalData(annotation, vis_processor, img_path)
|
||||
eval_dataloader = DataLoader(data, batch_size=args.batch_size, shuffle=False)
|
||||
count=0
|
||||
total=0
|
||||
|
||||
minigpt4_predict = []
|
||||
|
||||
for images, texts, labels in tqdm(eval_dataloader):
|
||||
texts = prepare_texts(texts, conv_temp) # warp the texts with conversation template
|
||||
# print("texts",texts)
|
||||
answers = model.generate(images, texts, max_new_tokens=args.max_new_tokens, do_sample=False)
|
||||
|
||||
for answer, label in zip(answers, labels):
|
||||
print(answer)
|
||||
result = dict()
|
||||
result['pred'] = answer.replace('<unk>','').strip()
|
||||
result['gt'] = label
|
||||
minigpt4_predict.append(result)
|
||||
if answer.lower() == label.lower():
|
||||
count+=1
|
||||
total+=1
|
||||
print('vsr test:', count / total * 100, flush=True)
|
||||
save_path=f'results/{args.name}_vsr.json'
|
||||
with open(save_path,'w') as f:
|
||||
json.dump(minigpt4_predict, f)
|
||||
|
||||
if 'hm' in args.dataset:
|
||||
img_path = f'{args.img_path}/hateful_meme'
|
||||
annotation = []
|
||||
with open(f'{args.eval_file_path}/hateful_meme/dev.jsonl', 'r') as jsonl_file:
|
||||
for line in jsonl_file:
|
||||
json_obj = json.loads(line)
|
||||
annotation.append(json_obj)
|
||||
|
||||
data = HMEvalData(annotation, vis_processor, img_path)
|
||||
eval_dataloader = DataLoader(data, batch_size=args.batch_size, shuffle=False)
|
||||
count=0
|
||||
total=0
|
||||
|
||||
minigpt4_predict = []
|
||||
|
||||
for images, texts, labels in tqdm(eval_dataloader):
|
||||
texts = prepare_texts(texts, conv_temp) # warp the texts with conversation template
|
||||
answers = model.generate(images, texts, max_new_tokens=args.max_new_tokens, do_sample=False)
|
||||
|
||||
for answer, label in zip(answers, labels):
|
||||
result = dict()
|
||||
answer = 1 if answer.lower().__contains__('yes') else 0
|
||||
result['pred'] = int(str(answer).replace('<unk>','').strip())
|
||||
result['gt'] = int(label)
|
||||
minigpt4_predict.append(result)
|
||||
if answer == label:
|
||||
count+=1
|
||||
total+=1
|
||||
print('hm val:', count / total * 100, flush=True)
|
||||
|
||||
save_path=f'results/{args.name}_hm.json'
|
||||
with open(save_path,'w') as f:
|
||||
json.dump(minigpt4_predict, f)
|
Loading…
Reference in New Issue
Block a user