add argument to switch 8bit

Deyao Zhu 2023-04-17 16:54:00 +03:00
parent 3e03c8327f
commit b76d5c5a50
2 changed files with 27 additions and 9 deletions

eval_configs/minigpt4_eval.yaml

@@ -5,6 +5,7 @@ model:
   freeze_qformer: True
   max_txt_len: 160
   end_sym: "###"
+  low_resource: True
   prompt_path: "prompts/alignment.txt"
   prompt_template: '###Human: {} ###Assistant: '
   ckpt: '/path/to/pretrained/ckpt/'
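Setting low_resource: True here is what flips the demo into 8-bit mode; the key is read by MiniGPT4.from_config below and defaults to False when absent. A minimal sketch of the plumbing (PyYAML and the inline config string are illustrative, not from the repo):

import yaml  # PyYAML

cfg_text = """
model:
  low_resource: True
"""
model_cfg = yaml.safe_load(cfg_text)["model"]
# Same default as cfg.get("low_resource", False) in from_config further down.
print(model_cfg.get("low_resource", False))  # True -> 8-bit LLaMA, ViT on CPU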

minigpt4/models/mini_gpt4.py

@@ -36,11 +36,13 @@ class MiniGPT4(Blip2Base):
         prompt_path="",
         prompt_template="",
         max_txt_len=32,
+        low_resource=False,  # use 8 bit and put vit in cpu
         end_sym='\n',
     ):
         super().__init__()

         self.tokenizer = self.init_tokenizer()
+        self.low_resource = low_resource

         print('Loading VIT')
         self.visual_encoder, self.ln_vision = self.init_vision_encoder(
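Because the new keyword defaults to False, existing callers and training configs are untouched; only configs that set the key opt in. A toy illustration of the pattern (the class name is made up):

class Toy:
    def __init__(self, low_resource=False):
        # Flag is stored once and branched on later, as MiniGPT4 does below.
        self.low_resource = low_resource

print(Toy().low_resource)                   # False: old behavior preserved
print(Toy(low_resource=True).low_resource)  # True: opt-in low-resource mode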
@@ -83,10 +85,19 @@ class MiniGPT4(Blip2Base):
         self.llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model, use_fast=False)
         self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token

-        self.llama_model = LlamaForCausalLM.from_pretrained(
-            llama_model, torch_dtype=torch.float16,
-            load_in_8bit=True, device_map="auto"
-        )
+        if self.low_resource:
+            self.llama_model = LlamaForCausalLM.from_pretrained(
+                llama_model,
+                torch_dtype=torch.float16,
+                load_in_8bit=True,
+                device_map="auto"
+            )
+        else:
+            self.llama_model = LlamaForCausalLM.from_pretrained(
+                llama_model,
+                torch_dtype=torch.float16,
+            )
+
         for name, param in self.llama_model.named_parameters():
             param.requires_grad = False
         print('Loading LLAMA Done')
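Before this change the LLaMA weights were always loaded in 8-bit; now full fp16 is the default and 8-bit is opt-in. A standalone sketch of the two paths (the checkpoint path is a placeholder; load_in_8bit requires the bitsandbytes package and a CUDA device, and a model loaded this way cannot be moved with .to() afterwards):

import torch
from transformers import LlamaForCausalLM

def load_llama(llama_path: str, low_resource: bool):
    if low_resource:
        # int8 weights via bitsandbytes; accelerate maps layers onto devices.
        return LlamaForCausalLM.from_pretrained(
            llama_path,
            torch_dtype=torch.float16,
            load_in_8bit=True,
            device_map="auto",
        )
    # Plain fp16 weights; the caller decides where the model lives.
    return LlamaForCausalLM.from_pretrained(llama_path, torch_dtype=torch.float16)

# llama = load_llama("/path/to/llama-hf", low_resource=True)  # placeholder path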
@@ -107,18 +118,22 @@ class MiniGPT4(Blip2Base):
         else:
             self.prompt_list = []

-    def encode_img(self, image):
-        device = image.device
+    def vit_to_cpu(self):
+        self.ln_vision.to("cpu")
+        self.ln_vision.float()
+        self.visual_encoder.to("cpu")
+        self.visual_encoder.float()

-        image = image.to("cpu")
-        image_embeds = self.ln_vision(self.visual_encoder(image)).to(device)
-        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)
+    def encode_img(self, image):
+        device = image.device
+        if self.low_resource:
+            self.vit_to_cpu()
+            image = image.to("cpu")
+
+        with self.maybe_autocast():
+            image_embeds = self.ln_vision(self.visual_encoder(image)).to(device)
+            image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)

         query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
         query_output = self.Qformer.bert(
             query_embeds=query_tokens,
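vit_to_cpu casts the ViT and its LayerNorm back to float32 because fp16 inference is impractical on CPU, while maybe_autocast (inherited from Blip2Base) keeps the GPU path in fp16 and degrades to a no-op on CPU. In the LAVIS-style base class it looks roughly like this (shown standalone here, with the device passed explicitly):

import contextlib
import torch

def maybe_autocast(device: torch.device, dtype=torch.float16):
    # On GPU, run the wrapped forward in fp16 autocast; on CPU, do nothing,
    # which is why the low-resource path first casts the ViT to float32.
    if device != torch.device("cpu"):
        return torch.cuda.amp.autocast(dtype=dtype)
    return contextlib.nullcontext()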
@@ -216,6 +231,7 @@ class MiniGPT4(Blip2Base):
         vit_precision = cfg.get("vit_precision", "fp16")
         freeze_vit = cfg.get("freeze_vit", True)
         freeze_qformer = cfg.get("freeze_qformer", True)
+        low_resource = cfg.get("low_resource", False)

         prompt_path = cfg.get("prompt_path", "")
         prompt_template = cfg.get("prompt_template", "")
@@ -236,6 +252,7 @@ class MiniGPT4(Blip2Base):
             prompt_path=prompt_path,
             prompt_template=prompt_template,
             max_txt_len=max_txt_len,
+            low_resource=low_resource,
             end_sym=end_sym
         )
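The point of the switch is memory: int8 weights take one byte per parameter versus two for fp16. Back-of-the-envelope numbers for a 13B Vicuna-style checkpoint (pure arithmetic; it ignores activations and the small set of outlier channels bitsandbytes keeps in fp16):

params = 13e9  # ~13B parameters
print(f"fp16 weights: ~{params * 2 / 1e9:.0f} GB")  # 2 bytes/param -> ~26 GB
print(f"int8 weights: ~{params * 1 / 1e9:.0f} GB")  # 1 byte/param  -> ~13 GB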