import numpy as np
import torch
import torch.nn as nn
from PIL import Image, ImageDraw, ImageFont

import groundingdino.datasets.transforms as T
from groundingdino.models import build_model
from groundingdino.util import box_ops
from groundingdino.util.slconfig import SLConfig
from groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
from groundingdino.util.inference import annotate, load_image, predict

from segment_anything import build_sam, sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
from segment_anything.utils.amg import remove_small_regions


def load_groundingdino_model(model_config_path, model_checkpoint_path):
    """Build a GroundingDINO model from its config file and load a checkpoint."""
    args = SLConfig.fromfile(model_config_path)
    model = build_model(args)
    checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
    # strict=False tolerates key mismatches between the checkpoint and the
    # freshly built model; the load report is printed for inspection
    load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
    print('loading GroundingDINO:', load_res)
    model.eval()
    return model

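# A standalone sketch of the loader, using the same config/checkpoint paths
# that GroundingModule assumes below:
#
#   model = load_groundingdino_model(
#       "./groundingdino/config/GroundingDINO_SwinT_OGC.py",
#       "./checkpoints/groundingdino_swint_ogc.pth",
#   )

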
class GroundingModule(nn.Module):
    def __init__(self, device='cpu'):
        super().__init__()
        self.device = device
        # Both checkpoints are expected locally: the SAM ViT-H weights and the
        # GroundingDINO Swin-T (OGC) weights.
        sam_checkpoint = "./checkpoints/sam_vit_h_4b8939.pth"
        groundingdino_checkpoint = "./checkpoints/groundingdino_swint_ogc.pth"
        groundingdino_config_file = "./groundingdino/config/GroundingDINO_SwinT_OGC.py"

        self.grounding_model = load_groundingdino_model(groundingdino_config_file, groundingdino_checkpoint).to(device)
        self.sam_predictor = SamPredictor(build_sam(checkpoint=sam_checkpoint).to(device))

    def prompt2mask(self, original_image, prompt, box_threshold=0.25, text_threshold=0.25, num_boxes=10):
        def image_transform_grounding(init_image):
            # With a single size choice, RandomResize deterministically rescales
            # the shorter side to 800 px (longer side capped at 1333)
            transform = T.Compose([
                T.RandomResize([800], max_size=1333),
                T.ToTensor(),
                T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ])
            image, _ = transform(init_image, None)  # 3, h, w
            return init_image, image

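        # GroundingDINO expects a lowercase caption in which phrases are
        # separated and terminated by periods (e.g. "a dog. a frisbee.");
        # the normalization below enforces that convention.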
        image_np = np.array(original_image, dtype=np.uint8)
        prompt = prompt.lower().strip()
        if not prompt.endswith("."):
            prompt = prompt + "."
        _, image_tensor = image_transform_grounding(original_image)
        # run GroundingDINO detection on the device the model was loaded onto
        boxes, logits, phrases = predict(self.grounding_model, image_tensor, prompt,
                                         box_threshold, text_threshold, device=self.device)
        print(phrases)

        W, H = original_image.size  # PIL size is (width, height)

        draw_img = original_image.copy()
        draw = ImageDraw.Draw(draw_img)
        for box in boxes:
            # scale normalized cxcywh coordinates to pixels
            box = box * torch.Tensor([W, H, W, H])
            # convert cxcywh to xyxy: top-left = center - wh / 2,
            # bottom-right = top-left + wh
            box[:2] -= box[2:] / 2
            box[2:] += box[:2]
            # draw the box in a random color
            color = tuple(np.random.randint(0, 255, size=3).tolist())
            x0, y0, x1, y1 = [int(v) for v in box]
            draw.rectangle([x0, y0, x1, y1], outline=color, width=6)

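        # Worked example of the cxcywh -> xyxy conversion above: on a 100x100
        # image, a normalized box (0.5, 0.5, 0.5, 0.5) scales to (50, 50, 50, 50),
        # giving corners (25, 25) and (75, 75).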
        if boxes.size(0) > 0:
            # same scaling and cxcywh -> xyxy conversion, vectorized over all boxes
            boxes = boxes * torch.Tensor([W, H, W, H])
            boxes[:, :2] = boxes[:, :2] - boxes[:, 2:] / 2
            boxes[:, 2:] = boxes[:, 2:] + boxes[:, :2]

            self.sam_predictor.set_image(image_np)

            # map the boxes into SAM's input frame and use them as prompts
            transformed_boxes = self.sam_predictor.transform.apply_boxes_torch(boxes, image_np.shape[:2])
            masks, _, _ = self.sam_predictor.predict_torch(
                point_coords=None,
                point_labels=None,
                boxes=transformed_boxes.to(self.device),
                multimask_output=False,
            )

            # remove small disconnected regions and holes
            fine_masks = []
            for mask in masks.to('cpu').numpy():  # masks: [num_masks, 1, h, w]
                fine_masks.append(remove_small_regions(mask[0], 400, mode="holes")[0])
            masks = np.stack(fine_masks, axis=0)[:, np.newaxis]
            masks = torch.from_numpy(masks)

            num_obj = min(len(logits), num_boxes)
            mask_map = None
            full_img = None
            for obj_ind in range(num_obj):
                m = masks[obj_ind][0]

                if full_img is None:
                    full_img = np.zeros((m.shape[0], m.shape[1], 3))
                    mask_map = np.zeros((m.shape[0], m.shape[1]), dtype=np.uint16)

                # label map: pixels of the i-th object get the value i + 1
                mask_map[m != 0] = obj_ind + 1
                # paint the object's mask region in a random color
                color_mask = np.random.random((1, 3)).tolist()[0]
                full_img[m != 0] = color_mask
            full_img = (full_img * 255).astype(np.uint8)
            full_img = Image.fromarray(full_img)
            # Image.blend returns a new image, so keep the result
            draw_img = Image.blend(draw_img, full_img, 0.5)
        return draw_img
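

if __name__ == "__main__":
    # Minimal usage sketch, assuming the checkpoints referenced in __init__
    # are present; "example.jpg" and the prompt are illustrative placeholders.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    grounding = GroundingModule(device=device)
    image = Image.open("example.jpg").convert("RGB")
    overlay = grounding.prompt2mask(image, "a dog")
    overlay.save("grounded_masks.jpg")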