From 2f90a5866a6444e6cd375d694deb52dd5cb96321 Mon Sep 17 00:00:00 2001
From: xxnithicxx
Date: Fri, 10 Jan 2025 15:59:09 +0700
Subject: [PATCH] feat: :sparkles: Add MVTec AD dataset evaluation function.

---
 eval.py                                        | 153 ++++++++++++++++++
 .../minigptv2_benchmark_evaluation.yaml        |  13 +-
 evaluate.sh                                    |   3 +
 minigpt4/datasets/datasets/vqa_datasets.py     |  16 ++
 4 files changed, 180 insertions(+), 5 deletions(-)
 create mode 100644 eval.py
 create mode 100644 evaluate.sh

diff --git a/eval.py b/eval.py
new file mode 100644
index 0000000..633e3d6
--- /dev/null
+++ b/eval.py
@@ -0,0 +1,153 @@
+import os
+import re
+import json
+import argparse
+import cv2
+from collections import defaultdict
+
+import numpy as np
+from PIL import Image
+from tqdm import tqdm
+import torch
+from torch.utils.data import DataLoader
+from datasets import load_dataset
+
+from minigpt4.datasets.datasets.vqa_datasets import RefADEvalData
+from minigpt4.common.vqa_tools.VQA.PythonHelperTools.vqaTools.vqa import VQA
+from minigpt4.common.vqa_tools.VQA.PythonEvaluationTools.vqaEvaluation.vqaEval import VQAEval
+
+from minigpt4.common.eval_utils import prepare_texts, init_model, eval_parser, computeIoU
+from minigpt4.conversation.conversation import CONV_VISION_minigptv2
+from minigpt4.common.config import Config
+
+
+def list_of_str(arg):
+    return list(map(str, arg.split(',')))
+
+parser = eval_parser()
+parser.add_argument("--dataset", type=list_of_str, default='refcoco', help="dataset to evaluate")
+parser.add_argument("--resample", action="store_true", default=False, help="resample failed samples")
+parser.add_argument("--res", type=float, default=100.0, help="resolution used in regression")
+args = parser.parse_args()
+cfg = Config(args)
+
+model, vis_processor = init_model(args)
+conv_temp = CONV_VISION_minigptv2.copy()
+conv_temp.system = ""
+model.eval()
+save_path = cfg.run_cfg.save_path
+
+if 'mvtech' in args.dataset:
+    eval_file_path = cfg.evaluation_datasets_cfg["mvtech_ad"]["eval_file_path"]
+    batch_size = cfg.evaluation_datasets_cfg["mvtech_ad"]["batch_size"]
+    max_new_tokens = cfg.evaluation_datasets_cfg["mvtech_ad"]["max_new_tokens"]
+
+    mvtech_ad = []
+
+    # Adapt the MVTec AD directory layout to the RefCOCO referring-expression format
+    mvtech_ad_data_for_regression = []
+    for category in os.listdir(eval_file_path):
+        category_path = os.path.join(eval_file_path, category)
+        if os.path.isdir(category_path):
+            for split in ["test"]:
+                split_path = os.path.join(category_path, split)
+                if os.path.isdir(split_path):
+                    for defect in os.listdir(split_path):
+                        defect_path = os.path.join(split_path, defect)
+                        for img_file in os.listdir(defect_path):
+                            img_path = os.path.join(defect_path, img_file)
+                            mask_path = os.path.join(category_path, "ground_truth", defect, img_file.replace(".png", "_mask.png"))
+
+                            if os.path.exists(mask_path):
+                                # Get the ground-truth bounding box from the defect mask
+                                mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
+                                contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+                                if contours:
+                                    x, y, w, h = cv2.boundingRect(contours[0])
+                                    bbox = [x, y, x + w, y + h]
+
+                                    img_id = f"{category}_{split}_{defect}_{img_file}"
+
+                                    mvtech_ad_data_for_regression.append({
+                                        "img_id": img_id,
+                                        "img_path": img_path,
+                                        "category": category,
+                                        "defect": defect,
+                                        "bbox": bbox,
+                                        "height": mask.shape[0],  # assuming mask dimensions match the image
+                                        "width": mask.shape[1],
+                                        "sents": '[refer] give me the location of ' + defect.replace('_', ' ') + ' defect'
+                                    })
+
+    data = RefADEvalData(mvtech_ad_data_for_regression, vis_processor)
+    data = list(data)[:len(data) // 10]  # limit to 10% of the data
+    eval_dataloader = DataLoader(data, batch_size=batch_size, shuffle=False)
+
+    minigpt4_predict = defaultdict(list)
+    resamples = []
+
+    for images, questions, img_ids in tqdm(eval_dataloader):
+        texts = prepare_texts(questions, conv_temp)
+        answers = model.generate(images, texts, max_new_tokens=max_new_tokens, do_sample=False)
+        for answer, img_id, question in zip(answers, img_ids, questions):
+            answer = answer.replace("<unk>", "").replace(" ", "").strip()
+            pattern = r'\{<\d{1,3}><\d{1,3}><\d{1,3}><\d{1,3}>\}'
+            if re.match(pattern, answer):
+                minigpt4_predict[img_id].append(answer)
+            else:
+                resamples.append({'img_id': img_id, 'sents': question.replace('[refer] give me the location of', '').strip()})
+
+    if args.resample:
+        for i in range(20):
+            data = RefADEvalData(resamples, vis_processor)
+            resamples = []
+            eval_dataloader = DataLoader(data, batch_size=batch_size, shuffle=False)
+
+            for images, questions, img_ids in tqdm(eval_dataloader):
+                texts = prepare_texts(questions, conv_temp)
+                answers = model.generate(images, texts, max_new_tokens=max_new_tokens, do_sample=False)
+                for answer, img_id, question in zip(answers, img_ids, questions):
+                    answer = answer.replace("<unk>", "").replace(" ", "").strip()
+                    pattern = r'\{<\d{1,3}><\d{1,3}><\d{1,3}><\d{1,3}>\}'
+                    if re.match(pattern, answer) or i == 4:
+                        minigpt4_predict[img_id].append(answer)
+                    else:
+                        resamples.append({'img_id': img_id, 'sents': question.replace('[refer] give me the location of', '').strip()})
+
+            if len(resamples) == 0:
+                break
+
+    # Save predictions
+    file_save_path = os.path.join(save_path, "mvtech_ad_regression.json")
+    with open(file_save_path, 'w') as f:
+        json.dump(minigpt4_predict, f)
+
+    # Calculate metrics
+    count = 0
+    total = len(mvtech_ad_data_for_regression)
+    res = args.res
+    for item in mvtech_ad_data_for_regression:
+        img_id = item['img_id']
+        bbox = item['bbox']
+        outputs = minigpt4_predict[img_id]
+
+        for output in outputs:
+            try:
+                integers = re.findall(r'\d+', output)
+                pred_bbox = [int(num) for num in integers]
+                height = item['height']
+                width = item['width']
+                pred_bbox[0] = pred_bbox[0] / res * width
+                pred_bbox[1] = pred_bbox[1] / res * height
+                pred_bbox[2] = pred_bbox[2] / res * width
+                pred_bbox[3] = pred_bbox[3] / res * height
+
+                gt_bbox = bbox  # bbox is already stored as [x1, y1, x2, y2] corners
+
+                iou_score = computeIoU(pred_bbox, gt_bbox)
+                if iou_score > 0.5:
+                    count += 1
+            except:
+                continue
+
+    print('MVTec AD (Regression):', count / total * 100, flush=True)
\ No newline at end of file
diff --git a/eval_configs/minigptv2_benchmark_evaluation.yaml b/eval_configs/minigptv2_benchmark_evaluation.yaml
index c0e3a26..bb3a6a7 100644
--- a/eval_configs/minigptv2_benchmark_evaluation.yaml
+++ b/eval_configs/minigptv2_benchmark_evaluation.yaml
@@ -3,14 +3,13 @@ model:
   model_type: pretrain
   max_txt_len: 500
   end_sym: "</s>"
-  low_resource: False
+  low_resource: True
   prompt_template: '[INST] {} [/INST]'
-  llama_model: ""
-  ckpt: ""
+  llama_model: "meta-llama/Llama-2-7b-chat-hf"
+  ckpt: "./checkpoint_stage3.pth"
   lora_r: 64
   lora_alpha: 16
-
 datasets:
   cc_sbu_align:
     vis_processor:
@@ -22,6 +21,10 @@ datasets:
         name: "blip_caption"
 
 evaluation_datasets:
+  mvtech_ad:
+    batch_size: 6
+    eval_file_path: /mnt/data/MVTEC
+    max_new_tokens: 20
   refcoco:
     eval_file_path: /path/to/eval/annotation/path
     img_path: /path/to/eval/image/path
@@ -71,7 +74,7 @@ evaluation_datasets:
 run:
   task: image_text_pretrain
   name: minigptv2_evaluation
-  save_path: /path/to/save/folder_path
+  save_path: ./outputs
diff --git a/evaluate.sh b/evaluate.sh
new file mode 100644
index 0000000..9bd50ce
--- /dev/null
+++ b/evaluate.sh
@@ -0,0 +1,3 @@
+export CUDA_VISIBLE_DEVICES=3
+torchrun --master-port 16016 --nproc_per_node 1 eval.py \
+  --cfg-path /home/VLAI/thinhpn/STA/MiniGPT-4/eval_configs/minigptv2_benchmark_evaluation.yaml --dataset mvtech
\ No newline at end of file
diff --git a/minigpt4/datasets/datasets/vqa_datasets.py b/minigpt4/datasets/datasets/vqa_datasets.py
index a8df3cc..93f46e7 100755
--- a/minigpt4/datasets/datasets/vqa_datasets.py
+++ b/minigpt4/datasets/datasets/vqa_datasets.py
@@ -11,6 +11,22 @@
 import os
 
 from minigpt4.datasets.datasets.base_dataset import BaseDataset
+class RefADEvalData(torch.utils.data.Dataset):
+    def __init__(self, loaded_data, vis_processor):
+        self.loaded_data = loaded_data
+        self.vis_processor = vis_processor
+
+    def __len__(self):
+        return len(self.loaded_data)
+
+    def __getitem__(self, idx):
+        data = self.loaded_data[idx]
+        sent = data["sents"]
+        img_id = data["img_id"]
+        image_path = data["img_path"]
+        image = Image.open(image_path).convert("RGB")
+        image = self.vis_processor(image)
+        return image, sent, img_id
 
 class VQADataset(BaseDataset):
     def __init__(self, vis_processor, text_processor, vis_root, ann_paths):