mirror of
https://github.com/Vision-CAIR/MiniGPT-4.git
synced 2025-04-05 18:40:46 +00:00
load in 8bit and run ViT-L in cpu to reduce vram use
This commit is contained in:
parent
017fa43d03
commit
6a8e6e49ff
@ -84,7 +84,8 @@ class MiniGPT4(Blip2Base):
|
||||
self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token
|
||||
|
||||
self.llama_model = LlamaForCausalLM.from_pretrained(
|
||||
llama_model, torch_dtype=torch.float16
|
||||
llama_model, torch_dtype=torch.float16,
|
||||
load_in_8bit=True, device_map="auto"
|
||||
)
|
||||
for name, param in self.llama_model.named_parameters():
|
||||
param.requires_grad = False
|
||||
@ -107,12 +108,17 @@ class MiniGPT4(Blip2Base):
|
||||
self.prompt_list = []
|
||||
|
||||
def encode_img(self, image):
|
||||
with self.maybe_autocast():
|
||||
image_embeds = self.ln_vision(self.visual_encoder(image))
|
||||
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
|
||||
image.device
|
||||
)
|
||||
device = image.device
|
||||
self.ln_vision.to("cpu")
|
||||
self.ln_vision.float()
|
||||
self.visual_encoder.to("cpu")
|
||||
self.visual_encoder.float()
|
||||
image = image.to("cpu")
|
||||
|
||||
image_embeds = self.ln_vision(self.visual_encoder(image)).to(device)
|
||||
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)
|
||||
|
||||
with self.maybe_autocast():
|
||||
query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
|
||||
query_output = self.Qformer.bert(
|
||||
query_embeds=query_tokens,
|
||||
|
Loading…
Reference in New Issue
Block a user