"""Evaluate MiniGPT-v2 grounded defect detection on MVTec-AD.

Prompts the model in [detection] mode, parses "<p>label</p>{<x1><y1><x2><y2>}"
answers, and scores them against ground-truth boxes with torchmetrics mAP.
"""
import os
import re
import json
from collections import defaultdict

from PIL import Image
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
from torchmetrics.detection.mean_ap import MeanAveragePrecision

from minigpt4.common.config import Config
from minigpt4.common.eval_utils import prepare_texts, init_model, eval_parser
from minigpt4.conversation.conversation import CONV_VISION_minigptv2


def list_of_str(arg):
    return list(map(str, arg.split(',')))


parser = eval_parser()
parser.add_argument("--dataset", type=list_of_str, default='refcoco', help="dataset to evaluate")
parser.add_argument("--res", type=float, default=100.0, help="coordinate grid resolution for bounding boxes (as in refcoco)")
parser.add_argument("--resample", action='store_true', help="re-query the model when an answer does not match the expected detection format")
args = parser.parse_args()

cfg = Config(args)
model, vis_processor = init_model(args)
model.eval()
CONV_VISION = CONV_VISION_minigptv2
conv_temp = CONV_VISION.copy()
conv_temp.system = ""

save_path = cfg.run_cfg.save_path
os.makedirs(save_path, exist_ok=True)


class RefADEvalData(torch.utils.data.Dataset):
    def __init__(self, loaded_data, vis_processor):
        self.loaded_data = loaded_data
        self.vis_processor = vis_processor

    def __len__(self):
        return len(self.loaded_data)

    def __getitem__(self, idx):
        data = self.loaded_data[idx]
        sent = "[detection] a defect or not-defect object and return the bounding boxes and its label. If not, bound around the object."
        img_id = data["class"] + "." + os.path.basename(data["image_path"]).split(".")[0]
        fix_path = os.path.join("./data/MVTEC_det/images", "/".join(data["image_path"].split("/")[1:4]))
        image = Image.open(fix_path).convert("RGB")
        image = self.vis_processor(image)
        return image, sent, img_id, data["class"], data["image_path"]


eval_file_path = cfg.evaluation_datasets_cfg["mvtec_ad"]["eval_file_path"]
batch_size = cfg.evaluation_datasets_cfg["mvtec_ad"]["batch_size"]
max_new_tokens = cfg.evaluation_datasets_cfg["mvtec_ad"]["max_new_tokens"]

# Adapt the data loading to the RefCOCO format
with open(eval_file_path, "r") as f:
    mvtec_ad_data_for_regression = json.load(f)
# mvtec_ad_data_for_regression = mvtec_ad_data_for_regression[:10]

data = RefADEvalData(mvtec_ad_data_for_regression, vis_processor)
eval_dataloader = DataLoader(data, batch_size=batch_size, shuffle=False)

minigpt4_predict = defaultdict(list)
resamples = []

for images, questions, img_ids, labels, image_paths in tqdm(eval_dataloader):
    texts = prepare_texts(questions, conv_temp)
    answers = model.generate(images, texts, max_new_tokens=max_new_tokens, do_sample=False)
    for answer, img_id, question, label, image_path in zip(answers, img_ids, questions, labels, image_paths):
        answer = answer.replace("<unk>", "").replace(" ", "").strip()
        pattern = r"<p>(.*?)<\/p>\{<\d{1,3}><\d{1,3}><\d{1,3}><\d{1,3}>\}"
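        # A well-formed grounded answer looks like "<p>defect</p>{<x1><y1><x2><y2>}",
        # with coordinates on the model's res x res grid (100 x 100 by default).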
        if re.match(pattern, answer):
            minigpt4_predict[img_id].append(answer)
        else:
            # Keep the "image_path" key so RefADEvalData.__getitem__ can read the entry back
            resamples.append({"img_id": img_id, "class": label, "image_path": image_path})

if args.resample:
    for i in range(20):
        data = RefADEvalData(resamples, vis_processor)
        resamples = []
        eval_dataloader = DataLoader(data, batch_size=batch_size, shuffle=False)
        for images, questions, img_ids, labels, image_paths in tqdm(eval_dataloader):
            texts = prepare_texts(questions, conv_temp)
            answers = model.generate(images, texts, max_new_tokens=max_new_tokens, do_sample=False)
            for answer, img_id, question, label, image_path in zip(answers, img_ids, questions, labels, image_paths):
                answer = answer.replace("<unk>", "").replace(" ", "").strip()
                pattern = r"<p>(.*?)<\/p>\{<\d{1,3}><\d{1,3}><\d{1,3}><\d{1,3}>\}"
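                # After five resampling rounds (i == 4) accept the answer even if it
                # does not match the expected format, so the retry loop terminates.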
                if re.match(pattern, answer) or i == 4:
                    minigpt4_predict[img_id].append(answer)
                else:
                    resamples.append({"img_id": img_id, "class": label, "image_path": image_path})
        if len(resamples) == 0:
            break

# Save predictions
file_save_path = os.path.join(save_path, "mvtec_ad_regression.json")
with open(file_save_path, "w") as f:
    json.dump(minigpt4_predict, f, indent=4)

metric = MeanAveragePrecision(iou_type="bbox", class_metrics=True)

# Calculate metrics
res = args.res
for item in mvtec_ad_data_for_regression:
    img_id = item["class"] + "." + os.path.basename(item["image_path"]).split(".")[0]
    label = item["class"]
    is_broken = item["is_broken"]
    outputs = minigpt4_predict[img_id]

    # Determine ground-truth bounding box and class
    if not is_broken:
        # If not broken, the bounding box is the whole image
        gt_bbox = [0, 0, item["width"], item["height"]]
        gt_class = 0  # Class 0 for "not-defect"
    else:
        gt_bbox = list(item["bbox"])  # assumed [x1, y1, x2, y2] in pixels
        gt_class = 1  # Class 1 for "defect"

    # Ground-truth data for torchmetrics
    gt_boxes = torch.tensor([gt_bbox], dtype=torch.float)
    gt_labels = torch.tensor([gt_class], dtype=torch.int)

    pred_boxes = []
    pred_scores = []
    pred_labels = []
    for output in outputs:
        match = re.search(
            r"<p>(.*?)<\/p>\{<(\d{1,3})><(\d{1,3})><(\d{1,3})><(\d{1,3})>\}",
            output,
        )
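        # group(1) is the predicted label string; groups 2-5 are the box
        # corners on the res x res coordinate grid.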
        if match:
            try:
                pred_class_str = match.group(1).strip()
                # Determine the predicted class from the label string;
                # check "not-defect" first since it contains "defect" as a substring
                if "not-defect" in pred_class_str:
                    pred_class = 0
                elif "defect" in pred_class_str:
                    pred_class = 1
                else:
                    pred_class = -1  # Sentinel for an unrecognized label

                pred_bbox = [
                    int(match.group(2)),
                    int(match.group(3)),
                    int(match.group(4)),
                    int(match.group(5)),
                ]
                # Rescale from the res x res grid to pixel coordinates
                height = item["height"]
                width = item["width"]
                pred_bbox[0] = pred_bbox[0] / res * width
                pred_bbox[1] = pred_bbox[1] / res * height
                pred_bbox[2] = pred_bbox[2] / res * width
                pred_bbox[3] = pred_bbox[3] / res * height

                pred_boxes.append(pred_bbox)
                pred_scores.append(1.0)  # The model emits no confidence, so assume 1.0
                pred_labels.append(pred_class)
            except Exception as e:
                print(f"Error processing output: {output}, Error: {e}")
                continue

    # Convert lists to tensors
    if pred_boxes:
        pred_boxes = torch.tensor(pred_boxes, dtype=torch.float)
        pred_scores = torch.tensor(pred_scores, dtype=torch.float)
        pred_labels = torch.tensor(pred_labels, dtype=torch.int)
    else:
        # Create empty tensors if no predictions were made
        pred_boxes = torch.empty((0, 4), dtype=torch.float)
        pred_scores = torch.empty(0, dtype=torch.float)
        pred_labels = torch.empty(0, dtype=torch.int)

    # Update metric
    metric.update(
        [dict(boxes=pred_boxes, scores=pred_scores, labels=pred_labels)],
        [dict(boxes=gt_boxes, labels=gt_labels)],
    )

# Compute metric
result = metric.compute()
map_value = result["map"].item()

# Print class-wise metrics
# for i, class_map in enumerate(result["map_per_class"]):
#     class_name = "defect" if i == 1 else "not-defect"
#     print(
#         f"mAP for {class_name}: {class_map.item() * 100 if not torch.isnan(class_map) else 0:.4f}",
#         flush=True,
#     )

print(f"mAP: {map_value * 100}", flush=True)
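# Example invocation (the script name and config path are illustrative; the
# remaining flags, such as --cfg-path, come from eval_parser()):
#   python eval_mvtec_ad.py --cfg-path eval_configs/minigptv2_evaluation.yaml \
#       --dataset mvtec_ad --res 100 --resample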