mirror of
https://github.com/Vision-CAIR/MiniGPT-4.git
synced 2025-04-05 18:40:46 +00:00
226 lines
8.1 KiB
Python
226 lines
8.1 KiB
Python
|
import os
|
||
|
import re
|
||
|
import json
|
||
|
from collections import defaultdict
|
||
|
from PIL import Image
|
||
|
from tqdm import tqdm
|
||
|
import torch
|
||
|
from torchmetrics.detection.mean_ap import MeanAveragePrecision
|
||
|
from torch.utils.data import DataLoader
|
||
|
from minigpt4.common.config import Config
|
||
|
from minigpt4.common.eval_utils import prepare_texts, init_model, eval_parser
|
||
|
from minigpt4.conversation.conversation import CONV_VISION_minigptv2
|
||
|
|
||
|
def list_of_str(arg):
|
||
|
return list(map(str, arg.split(',')))
|
||
|
|
||
|
parser = eval_parser()
|
||
|
parser.add_argument("--dataset", type=list_of_str, default='refcoco', help="dataset to evaluate")
|
||
|
parser.add_argument("--res", type=float, default=100.0, help="resolution used in refcoco")
|
||
|
parser.add_argument("--resample", action='store_true', help="resolution used in refcoco")
|
||
|
args = parser.parse_args()
|
||
|
|
||
|
cfg = Config(args)
|
||
|
|
||
|
model, vis_processor = init_model(args)
|
||
|
model.eval()
|
||
|
CONV_VISION = CONV_VISION_minigptv2
|
||
|
conv_temp = CONV_VISION.copy()
|
||
|
conv_temp.system = ""
|
||
|
save_path = cfg.run_cfg.save_path
|
||
|
|
||
|
os.makedirs(save_path, exist_ok=True)
|
||
|
|
||
|
|
||
|
class RefADEvalData(torch.utils.data.Dataset):
|
||
|
def __init__(self, loaded_data, vis_processor):
|
||
|
self.loaded_data = loaded_data
|
||
|
self.vis_processor = vis_processor
|
||
|
|
||
|
def __len__(self):
|
||
|
return len(self.loaded_data)
|
||
|
|
||
|
def __getitem__(self, idx):
|
||
|
data = self.loaded_data[idx]
|
||
|
sent = "[detection] a defect or not-defect object and return the bounding boxes and its label. If not, bound around the object."
|
||
|
img_id = (data["class"] + "." + os.path.basename(data["image_path"]).split(".")[0])
|
||
|
fix_path = os.path.join("./data/MVTEC_det/images", "/".join(data["image_path"].split("/")[1:4]))
|
||
|
image = Image.open(fix_path).convert("RGB")
|
||
|
image = self.vis_processor(image)
|
||
|
return image, sent, img_id, data["class"], data["image_path"]
|
||
|
|
||
|
|
||
|
eval_file_path = cfg.evaluation_datasets_cfg["mvtec_ad"]["eval_file_path"]
|
||
|
batch_size = cfg.evaluation_datasets_cfg["mvtec_ad"]["batch_size"]
|
||
|
max_new_tokens = cfg.evaluation_datasets_cfg["mvtec_ad"]["max_new_tokens"]
|
||
|
|
||
|
|
||
|
# Adapt the data loading to the RefCOCO format
|
||
|
with open(eval_file_path, "r") as f:
|
||
|
mvtec_ad_data_for_regression = json.load(f)
|
||
|
|
||
|
# mvtec_ad_data_for_regression = mvtec_ad_data_for_regression[:10]
|
||
|
data = RefADEvalData(mvtec_ad_data_for_regression, vis_processor)
|
||
|
eval_dataloader = DataLoader(data, batch_size=batch_size, shuffle=False)
|
||
|
|
||
|
minigpt4_predict = defaultdict(list)
|
||
|
resamples = []
|
||
|
|
||
|
for images, questions, img_ids, labels, image_paths in tqdm(eval_dataloader):
|
||
|
texts = prepare_texts(questions, conv_temp)
|
||
|
answers = model.generate(images,
|
||
|
texts,
|
||
|
max_new_tokens=max_new_tokens,
|
||
|
do_sample=False)
|
||
|
for answer, img_id, question, labels, image_paths in zip(
|
||
|
answers, img_ids, questions, labels, image_paths
|
||
|
):
|
||
|
answer = answer.replace("<unk>", "").replace(" ", "").strip()
|
||
|
pattern = r"<p>(.*?)<\/p>\{<\d{1,3}><\d{1,3}><\d{1,3}><\d{1,3}>\}"
|
||
|
if re.match(pattern, answer):
|
||
|
minigpt4_predict[img_id].append(answer)
|
||
|
else:
|
||
|
resamples.append(
|
||
|
{"img_id": img_id, "class": labels, "img_path": image_paths}
|
||
|
)
|
||
|
|
||
|
if args.resample:
|
||
|
for i in range(20):
|
||
|
data = RefADEvalData(resamples, vis_processor)
|
||
|
resamples = []
|
||
|
eval_dataloader = DataLoader(data, batch_size=batch_size, shuffle=False)
|
||
|
|
||
|
for images, questions, img_ids, labels, image_paths in tqdm(
|
||
|
eval_dataloader
|
||
|
):
|
||
|
texts = prepare_texts(questions, conv_temp)
|
||
|
answers = model.generate(images,
|
||
|
texts,
|
||
|
max_new_tokens=max_new_tokens,
|
||
|
do_sample=False)
|
||
|
for answer, img_id, question, labels, image_paths in zip(
|
||
|
answers, img_ids, questions, labels, image_paths
|
||
|
):
|
||
|
answer = answer.replace("<unk>", "").replace(" ", "").strip()
|
||
|
pattern = r"<p>(.*?)<\/p>\{<\d{1,3}><\d{1,3}><\d{1,3}><\d{1,3}>\}"
|
||
|
if re.match(pattern, answer) or i == 4:
|
||
|
minigpt4_predict[img_id].append(answer)
|
||
|
else:
|
||
|
resamples.append({
|
||
|
"img_id": img_id,
|
||
|
"class": labels,
|
||
|
"img_path": image_paths
|
||
|
})
|
||
|
|
||
|
if len(resamples) == 0:
|
||
|
break
|
||
|
|
||
|
|
||
|
# Save predictions
|
||
|
file_save_path = os.path.join(save_path, "mvtec_ad_regression.json")
|
||
|
with open(file_save_path, "w") as f:
|
||
|
json.dump(minigpt4_predict, f, indent=4)
|
||
|
|
||
|
metric = MeanAveragePrecision(iou_type="bbox", class_metrics=True)
|
||
|
|
||
|
# Calculate metrics
|
||
|
count = 0
|
||
|
total = 0
|
||
|
res = args.res
|
||
|
for item in mvtec_ad_data_for_regression:
|
||
|
img_id = (item["class"] + "." + os.path.basename(item["image_path"]).split(".")[0])
|
||
|
label = item["class"]
|
||
|
is_broken = item["is_broken"]
|
||
|
outputs = minigpt4_predict[img_id]
|
||
|
|
||
|
# Determine ground truth bounding box and class
|
||
|
if not is_broken:
|
||
|
# If not broken, the bounding box is the whole image
|
||
|
gt_bbox = [0, 0, item["width"], item["height"]]
|
||
|
gt_class = 0 # Class 0 for "not-defect"
|
||
|
else:
|
||
|
gt_bbox = [
|
||
|
item["bbox"][0],
|
||
|
item["bbox"][1],
|
||
|
item["bbox"][2],
|
||
|
item["bbox"][3],
|
||
|
]
|
||
|
gt_class = 1 # Class 1 for "defect"
|
||
|
|
||
|
# Ground truth data for torchmetrics
|
||
|
gt_boxes = torch.tensor([gt_bbox], dtype=torch.float)
|
||
|
gt_labels = torch.tensor([gt_class], dtype=torch.int)
|
||
|
|
||
|
outputs = minigpt4_predict[img_id]
|
||
|
|
||
|
pred_boxes = []
|
||
|
pred_scores = []
|
||
|
pred_labels = []
|
||
|
|
||
|
for output in outputs:
|
||
|
match = re.search(
|
||
|
r"<p>(.*?)<\/p>\{<(\d{1,3})><(\d{1,3})><(\d{1,3})><(\d{1,3})>\}",
|
||
|
output,
|
||
|
)
|
||
|
if match:
|
||
|
try:
|
||
|
pred_class_str = match.group(1).strip()
|
||
|
|
||
|
# Determine predicted class based on the string
|
||
|
if "not-defect" in pred_class_str:
|
||
|
pred_class = 0
|
||
|
elif "defect" in pred_class_str:
|
||
|
pred_class = 1
|
||
|
else:
|
||
|
pred_class = -1 # Or some other value to indicate uncertain
|
||
|
|
||
|
pred_bbox = [
|
||
|
int(match.group(2)),
|
||
|
int(match.group(3)),
|
||
|
int(match.group(4)),
|
||
|
int(match.group(5)),
|
||
|
]
|
||
|
height = item["height"]
|
||
|
width = item["width"]
|
||
|
pred_bbox[0] = pred_bbox[0] / res * width
|
||
|
pred_bbox[1] = pred_bbox[1] / res * height
|
||
|
pred_bbox[2] = pred_bbox[2] / res * width
|
||
|
pred_bbox[3] = pred_bbox[3] / res * height
|
||
|
|
||
|
pred_boxes.append(pred_bbox)
|
||
|
pred_scores.append(1.0) # Assuming confidence score of 1
|
||
|
pred_labels.append(pred_class)
|
||
|
except Exception as e:
|
||
|
print(f"Error processing output: {output}, Error: {e}")
|
||
|
continue
|
||
|
|
||
|
# Convert lists to tensors
|
||
|
if pred_boxes:
|
||
|
pred_boxes = torch.tensor(pred_boxes, dtype=torch.float)
|
||
|
pred_scores = torch.tensor(pred_scores, dtype=torch.float)
|
||
|
pred_labels = torch.tensor(pred_labels, dtype=torch.int)
|
||
|
else:
|
||
|
# Create empty tensors if no predictions were made
|
||
|
pred_boxes = torch.empty((0, 4), dtype=torch.float)
|
||
|
pred_scores = torch.empty(0, dtype=torch.float)
|
||
|
pred_labels = torch.empty(0, dtype=torch.int)
|
||
|
|
||
|
# Update metric
|
||
|
metric.update(
|
||
|
[dict(boxes=pred_boxes, scores=pred_scores, labels=pred_labels)],
|
||
|
[dict(boxes=gt_boxes, labels=gt_labels)],
|
||
|
)
|
||
|
|
||
|
# Compute metric
|
||
|
result = metric.compute()
|
||
|
map_value = result["map"].item()
|
||
|
|
||
|
# Print class-wise metrics
|
||
|
# for i, class_map in enumerate(result["map_per_class"]):
|
||
|
# class_name = "defect" if i == 1 else "not-defect"
|
||
|
# print( f"mAP for {class_name}: {class_map.item() * 100 if not torch.isnan(class_map) else 0:.4f}",
|
||
|
# flush=True
|
||
|
# )
|
||
|
|
||
|
print(f"mAP: {map_value * 100}", flush=True)
|