feat: Add MVTech dataset evaluation function.

This commit is contained in:
xxnithicxx 2025-01-10 15:59:09 +07:00
parent 93da3d0f53
commit 2f90a5866a
4 changed files with 180 additions and 5 deletions

153
eval.py Normal file
View File

@ -0,0 +1,153 @@
import os
import re
import json
import argparse
import cv2
from collections import defaultdict
import numpy as np
from PIL import Image
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
from datasets import load_dataset
from minigpt4.datasets.datasets.vqa_datasets import RefADEvalData
from minigpt4.common.vqa_tools.VQA.PythonHelperTools.vqaTools.vqa import VQA
from minigpt4.common.vqa_tools.VQA.PythonEvaluationTools.vqaEvaluation.vqaEval import VQAEval
from minigpt4.common.eval_utils import prepare_texts, init_model, eval_parser, computeIoU
from minigpt4.conversation.conversation import CONV_VISION_minigptv2
from minigpt4.common.config import Config
def list_of_str(arg):
return list(map(str, arg.split(',')))
parser = eval_parser()
parser.add_argument("--dataset", type=list_of_str, default='refcoco', help="dataset to evaluate")
parser.add_argument("--resample", action="store_true", default=False, help="resample failed samples")
parser.add_argument("--res", type=float, default=100.0, help="resolution used in regresion")
args = parser.parse_args()
cfg = Config(args)
model, vis_processor = init_model(args)
conv_temp = CONV_VISION_minigptv2.copy()
conv_temp.system = ""
model.eval()
save_path = cfg.run_cfg.save_path
if 'mvtech' in args.dataset:
eval_file_path = cfg.evaluation_datasets_cfg["mvtech_ad"]["eval_file_path"]
batch_size = cfg.evaluation_datasets_cfg["mvtech_ad"]["batch_size"]
max_new_tokens = cfg.evaluation_datasets_cfg["mvtech_ad"]["max_new_tokens"]
mvtech_ad = []
# Adapt the data loading to the RefCOCO format
mvtech_ad_data_for_regression = []
for category in os.listdir(eval_file_path):
category_path = os.path.join(eval_file_path, category)
if os.path.isdir(category_path):
for split in ["test"]:
split_path = os.path.join(category_path, split)
if os.path.isdir(split_path):
for defect in os.listdir(split_path):
defect_path = os.path.join(split_path, defect)
for img_file in os.listdir(defect_path):
img_path = os.path.join(defect_path, img_file)
mask_path = os.path.join(category_path, "ground_truth", defect, img_file.replace(".png", "_mask.png"))
if os.path.exists(mask_path):
# Get bounding box from mask
mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
if contours:
x, y, w, h = cv2.boundingRect(contours[0])
bbox = [x, y, x + w, y + h]
img_id = f"{category}_{split}_{defect}_{img_file}"
mvtech_ad_data_for_regression.append({
"img_id": img_id,
"img_path": img_path,
"category": category,
"defect": defect,
"bbox": bbox,
"height": mask.shape[0], # Assuming mask dimensions match image
"width": mask.shape[1],
"sents": '[refer] give me the location of ' + defect.replace('_', ' ') + ' defect'
})
data = RefADEvalData(mvtech_ad_data_for_regression, vis_processor)
data = list(data)[:len(data)//10] # Limit to 10% of the data
eval_dataloader = DataLoader(data, batch_size=batch_size, shuffle=False)
minigpt4_predict = defaultdict(list)
resamples = []
for images, questions, img_ids in tqdm(eval_dataloader):
texts = prepare_texts(questions, conv_temp)
answers = model.generate(images, texts, max_new_tokens=max_new_tokens, do_sample=False)
for answer, img_id, question in zip(answers, img_ids, questions):
answer = answer.replace("<unk>","").replace(" ","").strip()
pattern = r'\{<\d{1,3}><\d{1,3}><\d{1,3}><\d{1,3}>\}'
if re.match(pattern, answer):
minigpt4_predict[img_id].append(answer)
else:
resamples.append({'img_id': img_id, 'sents': question.replace('[refer] give me the location of','').strip()})
if args.resample:
for i in range(20):
data = RefADEvalData(resamples, vis_processor)
resamples = []
eval_dataloader = DataLoader(data, batch_size=batch_size, shuffle=False)
for images, questions, img_ids in tqdm(eval_dataloader):
texts = prepare_texts(questions, conv_temp)
answers = model.generate(images, texts, max_new_tokens=max_new_tokens, do_sample=False)
for answer, img_id, question in zip(answers, img_ids, questions):
answer = answer.replace("<unk>","").replace(" ","").strip()
pattern = r'\{<\d{1,3}><\d{1,3}><\d{1,3}><\d{1,3}>\}'
if re.match(pattern, answer) or i == 4:
minigpt4_predict[img_id].append(answer)
else:
resamples.append({'img_id': img_id, 'sents': question.replace('[refer] give me the location of','').strip()})
if len(resamples) == 0:
break
# Save predictions
file_save_path = os.path.join(save_path, "mvtech_ad_regression.json")
with open(file_save_path, 'w') as f:
json.dump(minigpt4_predict, f)
# Calculate metrics
count = 0
total = len(mvtech_ad_data_for_regression)
res = args.res
for item in mvtech_ad_data_for_regression:
img_id = item['img_id']
bbox = item['bbox']
outputs = minigpt4_predict[img_id]
for output in outputs:
try:
integers = re.findall(r'\d+', output)
pred_bbox = [int(num) for num in integers]
height = item['height']
width = item['width']
pred_bbox[0] = pred_bbox[0] / res * width
pred_bbox[1] = pred_bbox[1] / res * height
pred_bbox[2] = pred_bbox[2] / res * width
pred_bbox[3] = pred_bbox[3] / res * height
gt_bbox = [bbox[0], bbox[1], bbox[0] + bbox[2], bbox[1] + bbox[3]]
iou_score = computeIoU(pred_bbox, gt_bbox)
if iou_score > 0.5:
count += 1
except:
continue
print(f'MVTech AD (Regression):', count / total * 100, flush=True)

View File

@ -3,14 +3,13 @@ model:
model_type: pretrain model_type: pretrain
max_txt_len: 500 max_txt_len: 500
end_sym: "</s>" end_sym: "</s>"
low_resource: False low_resource: True
prompt_template: '[INST] {} [/INST]' prompt_template: '[INST] {} [/INST]'
llama_model: "" llama_model: "meta-llama/Llama-2-7b-chat-hf"
ckpt: "" ckpt: "./checkpoint_stage3.pth"
lora_r: 64 lora_r: 64
lora_alpha: 16 lora_alpha: 16
datasets: datasets:
cc_sbu_align: cc_sbu_align:
vis_processor: vis_processor:
@ -22,6 +21,10 @@ datasets:
name: "blip_caption" name: "blip_caption"
evaluation_datasets: evaluation_datasets:
mvtech_ad:
batch_size: 6
eval_file_path: /mnt/data/MVTEC
max_new_tokens: 20
refcoco: refcoco:
eval_file_path: /path/to/eval/annotation/path eval_file_path: /path/to/eval/annotation/path
img_path: /path/to/eval/image/path img_path: /path/to/eval/image/path
@ -71,7 +74,7 @@ evaluation_datasets:
run: run:
task: image_text_pretrain task: image_text_pretrain
name: minigptv2_evaluation name: minigptv2_evaluation
save_path: /path/to/save/folder_path save_path: ./outputs

3
evaluate.sh Normal file
View File

@ -0,0 +1,3 @@
export CUDA_VISIBLE_DEVICES=3
torchrun --master-port 16016 --nproc_per_node 1 eval.py \
--cfg-path /home/VLAI/thinhpn/STA/MiniGPT-4/eval_configs/minigptv2_benchmark_evaluation.yaml --dataset mvtech

View File

@ -11,6 +11,22 @@ import os
from minigpt4.datasets.datasets.base_dataset import BaseDataset from minigpt4.datasets.datasets.base_dataset import BaseDataset
class RefADEvalData(torch.utils.data.Dataset):
def __init__(self, loaded_data, vis_processor):
self.loaded_data = loaded_data
self.vis_processor = vis_processor
def __len__(self):
return len(self.loaded_data)
def __getitem__(self, idx):
data = self.loaded_data[idx]
sent = data["sents"]
img_id = data["img_id"]
image_path = data["img_path"]
image = Image.open(image_path).convert("RGB")
image = self.vis_processor(image)
return image, sent, img_id
class VQADataset(BaseDataset): class VQADataset(BaseDataset):
def __init__(self, vis_processor, text_processor, vis_root, ann_paths): def __init__(self, vis_processor, text_processor, vis_root, ann_paths):