MiniGPT-4/GroundingModel.py
import numpy as np
import PIL
import torch
import torch.nn as nn
from PIL import ImageDraw

import groundingdino.datasets.transforms as T
from groundingdino.models import build_model
from groundingdino.util.slconfig import SLConfig
from groundingdino.util.utils import clean_state_dict
from groundingdino.util.inference import predict
from segment_anything import build_sam, SamPredictor
from segment_anything.utils.amg import remove_small_regions


def load_groundingdino_model(model_config_path, model_checkpoint_path):
    """Build a GroundingDINO model from a config file and load its checkpoint."""
    args = SLConfig.fromfile(model_config_path)
    model = build_model(args)
    checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
    load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
    print('loading GroundingDINO:', load_res)
    _ = model.eval()
    return model


class GroundingModule(nn.Module):
    def __init__(self, device='cpu'):
        super().__init__()
        self.device = device
        sam_checkpoint = "./checkpoints/sam_vit_h_4b8939.pth"
        groundingdino_checkpoint = "./checkpoints/groundingdino_swint_ogc.pth"
        groundingdino_config_file = "./groundingdino/config/GroundingDINO_SwinT_OGC.py"
        self.grounding_model = load_groundingdino_model(
            groundingdino_config_file, groundingdino_checkpoint).to(device)
        self.sam_predictor = SamPredictor(build_sam(checkpoint=sam_checkpoint).to(device))

    def prompt2mask(self, original_image, prompt, box_threshold=0.25, text_threshold=0.25, num_boxes=10):
        def image_transform_grounding(init_image):
            # standard GroundingDINO eval transform: resize + ImageNet normalization
            transform = T.Compose([
                T.RandomResize([800], max_size=1333),
                T.ToTensor(),
                T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ])
            image, _ = transform(init_image, None)  # 3, h, w
            return init_image, image

        image_np = np.array(original_image, dtype=np.uint8)
        # GroundingDINO expects a lower-case text prompt terminated by a period
        prompt = prompt.lower().strip()
        if not prompt.endswith("."):
            prompt = prompt + "."
        _, image_tensor = image_transform_grounding(original_image)
        boxes, logits, phrases = predict(self.grounding_model, image_tensor, prompt,
                                         box_threshold, text_threshold, device=self.device)
        print(phrases)
        H, W = original_image.size[1], original_image.size[0]
        draw_img = original_image.copy()
        draw = ImageDraw.Draw(draw_img)
        for box in boxes:
            # scale from normalized [0, 1] to pixel coordinates
            box = box * torch.Tensor([W, H, W, H])
            # convert from (cx, cy, w, h) to (x0, y0, x1, y1)
            box[:2] -= box[2:] / 2
            box[2:] += box[:2]
            # draw the box in a random color
            color = tuple(np.random.randint(0, 255, size=3).tolist())
            x0, y0, x1, y1 = box
            x0, y0, x1, y1 = int(x0), int(y0), int(x1), int(y1)
            draw.rectangle([x0, y0, x1, y1], outline=color, width=6)
        if boxes.size(0) > 0:
            # same (cx, cy, w, h) -> (x0, y0, x1, y1) conversion, vectorized over all boxes
            boxes = boxes * torch.Tensor([W, H, W, H])
            boxes[:, :2] = boxes[:, :2] - boxes[:, 2:] / 2
            boxes[:, 2:] = boxes[:, 2:] + boxes[:, :2]
            # prompt SAM with the detected boxes to get one mask per box
            self.sam_predictor.set_image(image_np)
            transformed_boxes = self.sam_predictor.transform.apply_boxes_torch(boxes, image_np.shape[:2])
            masks, _, _ = self.sam_predictor.predict_torch(
                point_coords=None,
                point_labels=None,
                boxes=transformed_boxes.to(self.device),
                multimask_output=False,
            )
            # fill small holes in each mask (mode="holes" only patches holes;
            # it does not drop small disconnected regions)
            fine_masks = []
            for mask in masks.to('cpu').numpy():  # masks: [num_masks, 1, h, w]
                fine_masks.append(remove_small_regions(mask[0], 400, mode="holes")[0])
            masks = np.stack(fine_masks, axis=0)[:, np.newaxis]
            masks = torch.from_numpy(masks)
            num_obj = min(len(logits), num_boxes)
            mask_map = None
            full_img = None
            for obj_ind in range(num_obj):
                m = masks[obj_ind][0]
                if full_img is None:
                    full_img = np.zeros((m.shape[0], m.shape[1], 3))
                    mask_map = np.zeros((m.shape[0], m.shape[1]), dtype=np.uint16)
                # label each object's pixels with its 1-based index
                # (mask_map is built here but not returned by this method)
                mask_map[m != 0] = obj_ind + 1
                # paint the object's pixels with a random color
                color_mask = np.random.random((1, 3)).tolist()[0]
                full_img[m != 0] = color_mask
            full_img = (full_img * 255).astype(np.uint8)
            full_img = PIL.Image.fromarray(full_img)
            # keep the blend result; the original call discarded it, so the
            # mask overlay never reached the returned image
            draw_img = PIL.Image.blend(draw_img, full_img, 0.5)
        return draw_img
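

# A minimal usage sketch (illustrative, not part of the original module): it
# assumes the GroundingDINO and SAM checkpoints referenced in
# GroundingModule.__init__ exist under ./checkpoints, and that "example.jpg"
# is a hypothetical local test image.
if __name__ == "__main__":
    from PIL import Image

    device = "cuda" if torch.cuda.is_available() else "cpu"
    grounder = GroundingModule(device=device)
    image = Image.open("example.jpg").convert("RGB")
    # returns the image with detected boxes drawn and, when any box is found,
    # the SAM mask overlay blended in at 50% opacity
    result = grounder.prompt2mask(image, "dog")
    result.save("example_grounded.jpg")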