diff --git a/GroundingModel.py b/GroundingModel.py index f14b844..555e12b 100644 --- a/GroundingModel.py +++ b/GroundingModel.py @@ -54,7 +54,7 @@ class GroundingModule(nn.Module): prompt = prompt + "." _, image_tensor = image_transform_grounding(original_image) boxes, logits, phrases = predict(self.grounding_model, - image_tensor, prompt, box_threshold, text_threshold, device=self.device) + image_tensor, prompt, box_threshold, text_threshold, device='cpu') print(phrases) # from PIL import Image, ImageDraw, ImageFont H, W = original_image.size[1], original_image.size[0]