mirror of https://github.com/Vision-CAIR/MiniGPT-4.git
synced 2025-04-05 02:20:47 +00:00

feat: ✨ Add MVTech dataset evaluation function.

This commit is contained in:
parent 93da3d0f53
commit 2f90a5866a

eval.py (new file, 153 lines)
@@ -0,0 +1,153 @@
import os
import re
import json
import argparse
import cv2
from collections import defaultdict

import numpy as np
from PIL import Image
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
from datasets import load_dataset

from minigpt4.datasets.datasets.vqa_datasets import RefADEvalData
from minigpt4.common.vqa_tools.VQA.PythonHelperTools.vqaTools.vqa import VQA
from minigpt4.common.vqa_tools.VQA.PythonEvaluationTools.vqaEvaluation.vqaEval import VQAEval

from minigpt4.common.eval_utils import prepare_texts, init_model, eval_parser, computeIoU
from minigpt4.conversation.conversation import CONV_VISION_minigptv2
from minigpt4.common.config import Config


def list_of_str(arg):
    return list(map(str, arg.split(',')))


parser = eval_parser()
parser.add_argument("--dataset", type=list_of_str, default='refcoco', help="dataset to evaluate")
parser.add_argument("--resample", action="store_true", default=False, help="resample failed samples")
parser.add_argument("--res", type=float, default=100.0, help="resolution used in regression")
args = parser.parse_args()
cfg = Config(args)

model, vis_processor = init_model(args)
conv_temp = CONV_VISION_minigptv2.copy()
conv_temp.system = ""
model.eval()
save_path = cfg.run_cfg.save_path

if 'mvtech' in args.dataset:
    eval_file_path = cfg.evaluation_datasets_cfg["mvtech_ad"]["eval_file_path"]
    batch_size = cfg.evaluation_datasets_cfg["mvtech_ad"]["batch_size"]
    max_new_tokens = cfg.evaluation_datasets_cfg["mvtech_ad"]["max_new_tokens"]

    mvtech_ad = []

    # Adapt the data loading to the RefCOCO format: walk category/test/defect
    # folders and derive a ground-truth bounding box from each defect mask.
    mvtech_ad_data_for_regression = []
    for category in os.listdir(eval_file_path):
        category_path = os.path.join(eval_file_path, category)
        if os.path.isdir(category_path):
            for split in ["test"]:
                split_path = os.path.join(category_path, split)
                if os.path.isdir(split_path):
                    for defect in os.listdir(split_path):
                        defect_path = os.path.join(split_path, defect)
                        for img_file in os.listdir(defect_path):
                            img_path = os.path.join(defect_path, img_file)
                            mask_path = os.path.join(category_path, "ground_truth", defect, img_file.replace(".png", "_mask.png"))

                            if os.path.exists(mask_path):
                                # Get bounding box from mask (only the first external contour is used)
                                mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
                                contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
                                if contours:
                                    x, y, w, h = cv2.boundingRect(contours[0])
                                    bbox = [x, y, x + w, y + h]  # [x1, y1, x2, y2] in pixels

                                    img_id = f"{category}_{split}_{defect}_{img_file}"

                                    mvtech_ad_data_for_regression.append({
                                        "img_id": img_id,
                                        "img_path": img_path,
                                        "category": category,
                                        "defect": defect,
                                        "bbox": bbox,
                                        "height": mask.shape[0],  # Assuming mask dimensions match image
                                        "width": mask.shape[1],
                                        "sents": '[refer] give me the location of ' + defect.replace('_', ' ') + ' defect'
                                    })

    data = RefADEvalData(mvtech_ad_data_for_regression, vis_processor)
    data = list(data)[:len(data)//10]  # Limit to 10% of the data
    eval_dataloader = DataLoader(data, batch_size=batch_size, shuffle=False)

    minigpt4_predict = defaultdict(list)
    resamples = []

    for images, questions, img_ids in tqdm(eval_dataloader):
        texts = prepare_texts(questions, conv_temp)
        answers = model.generate(images, texts, max_new_tokens=max_new_tokens, do_sample=False)
        for answer, img_id, question in zip(answers, img_ids, questions):
            answer = answer.replace("<unk>", "").replace(" ", "").strip()
            pattern = r'\{<\d{1,3}><\d{1,3}><\d{1,3}><\d{1,3}>\}'
            if re.match(pattern, answer):
                minigpt4_predict[img_id].append(answer)
            else:
                resamples.append({'img_id': img_id, 'sents': question.replace('[refer] give me the location of', '').strip()})

    if args.resample:
        for i in range(20):
            data = RefADEvalData(resamples, vis_processor)
            resamples = []
            eval_dataloader = DataLoader(data, batch_size=batch_size, shuffle=False)

            for images, questions, img_ids in tqdm(eval_dataloader):
                texts = prepare_texts(questions, conv_temp)
                answers = model.generate(images, texts, max_new_tokens=max_new_tokens, do_sample=False)
                for answer, img_id, question in zip(answers, img_ids, questions):
                    answer = answer.replace("<unk>", "").replace(" ", "").strip()
                    pattern = r'\{<\d{1,3}><\d{1,3}><\d{1,3}><\d{1,3}>\}'
                    if re.match(pattern, answer) or i == 4:
                        minigpt4_predict[img_id].append(answer)
                    else:
                        resamples.append({'img_id': img_id, 'sents': question.replace('[refer] give me the location of', '').strip()})

            if len(resamples) == 0:
                break

    # Save predictions
    file_save_path = os.path.join(save_path, "mvtech_ad_regression.json")
    with open(file_save_path, 'w') as f:
        json.dump(minigpt4_predict, f)

    # Calculate metrics: accuracy at IoU > 0.5 between predicted and ground-truth boxes
    count = 0
    total = len(mvtech_ad_data_for_regression)
    res = args.res
    for item in mvtech_ad_data_for_regression:
        img_id = item['img_id']
        bbox = item['bbox']
        outputs = minigpt4_predict[img_id]

        for output in outputs:
            try:
                integers = re.findall(r'\d+', output)
                pred_bbox = [int(num) for num in integers]
                height = item['height']
                width = item['width']
                pred_bbox[0] = pred_bbox[0] / res * width
                pred_bbox[1] = pred_bbox[1] / res * height
                pred_bbox[2] = pred_bbox[2] / res * width
                pred_bbox[3] = pred_bbox[3] / res * height

                # bbox is already stored as [x1, y1, x2, y2], so it can be used directly
                gt_bbox = bbox

                iou_score = computeIoU(pred_bbox, gt_bbox)
                if iou_score > 0.5:
                    count += 1
            except Exception:
                continue

    print('MVTech AD (Regression):', count / total * 100, flush=True)
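
For reference (not part of the commit): the metric block above rescales each predicted box from the res x res coordinate grid back to pixel space. A minimal standalone sketch of that mapping; the helper name, answer string, and image size below are made up:

import re

def answer_to_pixel_bbox(answer, width, height, res=100.0):
    # Extract the four integers from a "{<x1><y1><x2><y2>}" style answer
    # and rescale them from the res x res grid to pixel coordinates.
    x1, y1, x2, y2 = [int(n) for n in re.findall(r'\d+', answer)[:4]]
    return [x1 / res * width, y1 / res * height,
            x2 / res * width, y2 / res * height]

# Hypothetical 900x900 image and model answer:
print(answer_to_pixel_bbox("{<12><34><56><78>}", width=900, height=900))
# -> [108.0, 306.0, 504.0, 702.0]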
Changes to the evaluation config YAML (the file name is not shown in this view; evaluate.sh below points at eval_configs/minigptv2_benchmark_evaluation.yaml):

@@ -3,14 +3,13 @@ model:
   model_type: pretrain
   max_txt_len: 500
   end_sym: "</s>"
-  low_resource: False
+  low_resource: True
   prompt_template: '[INST] {} [/INST]'
-  llama_model: ""
+  llama_model: "meta-llama/Llama-2-7b-chat-hf"
-  ckpt: ""
+  ckpt: "./checkpoint_stage3.pth"
   lora_r: 64
   lora_alpha: 16
-
 
 datasets:
   cc_sbu_align:
     vis_processor:
@@ -22,6 +21,10 @@ datasets:
         name: "blip_caption"
 
 evaluation_datasets:
+  mvtech_ad:
+    batch_size: 6
+    eval_file_path: /mnt/data/MVTEC
+    max_new_tokens: 20
   refcoco:
     eval_file_path: /path/to/eval/annotation/path
     img_path: /path/to/eval/image/path
@@ -71,7 +74,7 @@ evaluation_datasets:
 run:
   task: image_text_pretrain
   name: minigptv2_evaluation
-  save_path: /path/to/save/folder_path
+  save_path: ./outputs
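
For eval_file_path, the loader in eval.py above assumes the standard MVTec AD directory layout, roughly as below (category and defect names are illustrative):

/mnt/data/MVTEC/
    bottle/
        test/
            broken_large/000.png
            good/000.png
        ground_truth/
            broken_large/000_mask.png
    cable/
        ...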
evaluate.sh (new file, 3 lines)
@@ -0,0 +1,3 @@
export CUDA_VISIBLE_DEVICES=3
torchrun --master-port 16016 --nproc_per_node 1 eval.py \
    --cfg-path /home/VLAI/thinhpn/STA/MiniGPT-4/eval_configs/minigptv2_benchmark_evaluation.yaml --dataset mvtech
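
Note that the GPU index and the absolute --cfg-path in evaluate.sh are machine-specific. Assuming the config lives under eval_configs/ as in the upstream repository, a more portable invocation would be:

export CUDA_VISIBLE_DEVICES=0
torchrun --master-port 16016 --nproc_per_node 1 eval.py \
    --cfg-path eval_configs/minigptv2_benchmark_evaluation.yaml --dataset mvtech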
minigpt4/datasets/datasets/vqa_datasets.py (file inferred from the RefADEvalData import in eval.py):

@@ -11,6 +11,22 @@ import os
 
 from minigpt4.datasets.datasets.base_dataset import BaseDataset
 
+class RefADEvalData(torch.utils.data.Dataset):
+    def __init__(self, loaded_data, vis_processor):
+        self.loaded_data = loaded_data
+        self.vis_processor = vis_processor
+
+    def __len__(self):
+        return len(self.loaded_data)
+
+    def __getitem__(self, idx):
+        data = self.loaded_data[idx]
+        sent = data["sents"]
+        img_id = data["img_id"]
+        image_path = data["img_path"]
+        image = Image.open(image_path).convert("RGB")
+        image = self.vis_processor(image)
+        return image, sent, img_id
+
 class VQADataset(BaseDataset):
     def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
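
A minimal usage sketch for the new RefADEvalData class, mirroring how eval.py drives it; the sample record, file path, and the stand-in transform are hypothetical (eval.py uses the vis_processor returned by init_model instead):

from torch.utils.data import DataLoader
from torchvision import transforms
from minigpt4.datasets.datasets.vqa_datasets import RefADEvalData

# Stand-in for the MiniGPT-v2 visual processor returned by init_model(args).
vis_processor = transforms.Compose([transforms.Resize((448, 448)), transforms.ToTensor()])

# Hypothetical record in the format built by eval.py.
samples = [{
    "img_id": "bottle_test_broken_large_000.png",
    "img_path": "/mnt/data/MVTEC/bottle/test/broken_large/000.png",
    "sents": "[refer] give me the location of broken large defect",
}]

data = RefADEvalData(samples, vis_processor)
loader = DataLoader(data, batch_size=1, shuffle=False)
for images, sents, img_ids in loader:
    print(img_ids, sents, images.shape)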