Load the LLaMA model in 8-bit and run ViT-L on the CPU to reduce VRAM use

This commit is contained in:
152334H 2023-04-17 13:52:09 +08:00
parent 017fa43d03
commit 6a8e6e49ff

View File

@ -84,7 +84,8 @@ class MiniGPT4(Blip2Base):
self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token
self.llama_model = LlamaForCausalLM.from_pretrained( self.llama_model = LlamaForCausalLM.from_pretrained(
llama_model, torch_dtype=torch.float16 llama_model, torch_dtype=torch.float16,
load_in_8bit=True, device_map="auto"
) )
for name, param in self.llama_model.named_parameters(): for name, param in self.llama_model.named_parameters():
param.requires_grad = False param.requires_grad = False
@ -107,12 +108,17 @@ class MiniGPT4(Blip2Base):
self.prompt_list = [] self.prompt_list = []
def encode_img(self, image): def encode_img(self, image):
with self.maybe_autocast(): device = image.device
image_embeds = self.ln_vision(self.visual_encoder(image)) self.ln_vision.to("cpu")
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to( self.ln_vision.float()
image.device self.visual_encoder.to("cpu")
) self.visual_encoder.float()
image = image.to("cpu")
image_embeds = self.ln_vision(self.visual_encoder(image)).to(device)
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)
with self.maybe_autocast():
query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
query_output = self.Qformer.bert( query_output = self.Qformer.bert(
query_embeds=query_tokens, query_embeds=query_tokens,