mirror of
https://github.com/Vision-CAIR/MiniGPT-4.git
synced 2025-04-05 10:30:45 +00:00
264 lines
9.0 KiB
Python
264 lines
9.0 KiB
Python
import os
|
|
import sys
|
|
import json
|
|
import argparse
|
|
import pathlib
|
|
import random
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from torch.utils.data import Dataset, DataLoader
|
|
|
|
# https://github.com/PyTorchLightning/pytorch-lightning/issues/11663
|
|
import sentencepiece; import pytorch_lightning as pl
|
|
|
|
import torchmetrics.functional as MF
|
|
|
|
from load_aokvqa import load_aokvqa
|
|
|
|
|
|
def main():
    """Train a linear probe on precomputed A-OKVQA embeddings from the command line.

    Parses CLI arguments, builds the data module and model, and runs
    ``Trainer.fit`` with a TensorBoard logger and a checkpoint callback that
    keeps the checkpoint with the highest ``val_acc``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir')
    parser.add_argument('--vocab', type=argparse.FileType('r'), required=True)
    parser.add_argument('--log-dir', type=pathlib.Path, dest='log_dir', required=True)
    #
    parser.add_argument('--backbone', type=str, choices=['clip', 'resnet', 'bert'], required=True)
    # Required only with --backbone clip; enforced after parsing (see below).
    parser.add_argument('--clip-model-type', type=str,
                        choices=['RN50', 'RN50x4', 'RN50x16', 'RN50x64', 'RN101', 'ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px'],
                        dest='clip_model_type')
    parser.add_argument('--train-features', type=pathlib.Path, required=True, dest='train_features')
    parser.add_argument('--val-features', type=pathlib.Path, required=True, dest='val_features')
    # Required only with --objective contrastive; enforced after parsing (see below).
    parser.add_argument('--vocab-features', type=pathlib.Path, dest='vocab_features')
    #
    parser.add_argument('--objective', type=str, choices=['classifier', 'contrastive'], required=True)
    parser.add_argument('--inputs', nargs='+', type=str, choices=['question', 'image'], required=True)
    # Defaults
    parser.add_argument('--bs', type=int, default=128, dest='batch_size')
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--epochs', type=int, default=500)
    parser.add_argument('--gpus', type=int, default=1)
    # Total DataLoader worker budget (split 80/20 between train and val loaders).
    parser.add_argument('--num-workers', type=int, default=16, dest='num_workers')
    args = parser.parse_args()

    # Conditional requirements are checked on the parsed values. Probing sys.argv
    # for the literal tokens 'clip' / 'contrastive' misses the '--flag=value'
    # spelling and misfires when those words appear as some other argument's value.
    if args.backbone == 'clip' and args.clip_model_type is None:
        parser.error('--clip-model-type is required when --backbone is clip')
    if args.objective == 'contrastive' and args.vocab_features is None:
        parser.error('--vocab-features is required when --objective is contrastive')

    pl.seed_everything(1)

    # argparse.FileType opened the file for us; make sure it gets closed.
    with args.vocab as vocab_file:
        vocab = vocab_file.read().splitlines()

    ## Data loading

    dm = AokvqaEmbeddingsDataModule(
        args.aokvqa_dir,
        args.train_features,
        args.val_features,
        args.objective,
        args.backbone,
        args.inputs,
        vocab,
        args.vocab_features,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
    )

    ## Model definition

    model = LinearClassifier(
        args.objective,
        args.backbone,
        args.clip_model_type,
        args.inputs,
        len(vocab),
        args.lr,
    )

    ## Training and testing loops

    logger = pl.loggers.TensorBoardLogger(
        args.log_dir,
        name=f'{args.backbone}-{args.objective}',
        version=f"inputs:{'+'.join(args.inputs)}"
    )

    trainer = pl.Trainer(
        logger=logger,
        gpus=args.gpus,
        max_epochs=args.epochs,
        callbacks=[
            pl.callbacks.ModelCheckpoint(
                monitor="val_acc",
                filename="{epoch:02d}-{val_acc:.2f}",
                mode="max"
            )
        ],
    )

    trainer.fit(model, dm)
|
|
|
|
|
|
class AokvqaEmbeddingsDataset(Dataset):
    """Precomputed A-OKVQA embeddings paired with their correct answers.

    Each item is ``(embedding, answer)``:

    * ``objective == 'classifier'``: ``answer`` is a multi-hot vector of length
      ``len(vocab)`` marking every correct answer that appears in ``vocab``.
    * ``objective == 'contrastive'``: ``answer`` is the (L2-normalised) vocab
      embedding of one correct answer, drawn at random on every access.

    Questions with no correct answer in ``vocab`` are skipped.
    """

    # Feature keys in the order they are concatenated when both are requested.
    _INPUT_ORDER = ('question', 'image')

    def __init__(self, aokvqa_dir, split, input_features, objective, backbone, inputs, vocab, vocab_features):
        """
        Args:
            aokvqa_dir: dataset directory, passed to ``load_aokvqa``.
            split: split name passed to ``load_aokvqa`` ('train', 'val', ...).
            input_features: file loaded with ``torch.load``; indexed as
                ``embeddings[question_id]['question' | 'image']``.
            objective: 'classifier' or 'contrastive'.
            backbone: 'clip', 'resnet' or 'bert'.
            inputs: subset of 'question' / 'image'; when both are present the
                question and image features are concatenated, question first.
            vocab: list of answer strings.
            vocab_features: file loaded with ``torch.load`` (contrastive only);
                indexed by vocab position to get an answer's embedding.

        Raises:
            ValueError: if the (backbone, inputs, objective) combination is not
                supported. Checked before anything is loaded.
        """
        self._check_config(objective, backbone, inputs)

        aokvqa_set = load_aokvqa(aokvqa_dir, split)

        embeddings = torch.load(input_features)
        if backbone == 'clip':
            # L2-normalise CLIP features in place.
            for q in embeddings:
                for key in self._INPUT_ORDER:
                    embeddings[q][key] /= embeddings[q][key].norm(dim=-1, keepdim=True)

        vocab_embeddings = None
        if objective == 'contrastive':
            # _check_config guarantees backbone == 'clip' for this objective.
            vocab_embeddings = torch.load(vocab_features)
            vocab_embeddings /= vocab_embeddings.norm(dim=-1, keepdim=True)

        self.objective = objective
        self.vocab_len = len(vocab)

        self.embeddings = []
        self.answers = []

        vocab_index = self._build_vocab_index(vocab)
        keys = [k for k in self._INPUT_ORDER if k in inputs]

        for o in aokvqa_set:
            correct_answers = self._correct_answer_indices(o, vocab_index)
            if not correct_answers:
                continue
            if objective == 'contrastive':
                correct_answers = [vocab_embeddings[a] for a in correct_answers]
            self.answers.append(correct_answers)

            parts = [embeddings[o['question_id']][k] for k in keys]
            self.embeddings.append(parts[0] if len(parts) == 1 else torch.cat(parts))

    @classmethod
    def _check_config(cls, objective, backbone, inputs):
        """Raise ValueError unless (backbone, inputs, objective) is supported.

        Explicit raises instead of ``assert`` so the check survives ``python -O``.
        """
        supported = (
            (backbone == 'resnet' and inputs == ['image'] and objective == 'classifier')
            or (backbone == 'bert' and inputs == ['question'] and objective == 'classifier')
            or backbone == 'clip'
        )
        if not supported:
            raise ValueError(
                f'unsupported combination: backbone={backbone!r}, '
                f'inputs={inputs!r}, objective={objective!r}'
            )
        if not any(k in inputs for k in cls._INPUT_ORDER):
            raise ValueError(f"inputs must include 'question' and/or 'image', got {inputs!r}")

    @staticmethod
    def _build_vocab_index(vocab):
        """Map each answer string to the index of its first occurrence in ``vocab``.

        Same result as ``vocab.index(a)``, but O(1) per lookup.
        """
        index = {}
        for i, answer in enumerate(vocab):
            index.setdefault(answer, i)
        return index

    @staticmethod
    def _correct_answer_indices(o, vocab_index):
        """Vocab indices of ``o``'s correct answers, deduplicated in first-seen order.

        The correct multiple-choice option comes first, then the direct answers.
        Answers missing from the vocabulary are dropped. ``dict.fromkeys`` is used
        instead of ``set`` so the order does not depend on the per-process string
        hash seed. That order feeds ``random.sample`` in ``__getitem__``, so this
        keeps seeded runs reproducible.
        """
        candidates = [o['choices'][o['correct_choice_idx']]] + o['direct_answers']
        return [vocab_index[a] for a in dict.fromkeys(candidates) if a in vocab_index]

    def __getitem__(self, index):
        """Return ``(embedding, answer)``; see the class docstring for ``answer``."""
        e = self.embeddings[index]
        a = self.answers[index]
        if self.objective == 'classifier':
            # Answers are deduplicated, so the summed one-hots form a 0/1 multi-hot vector.
            a = torch.sum(F.one_hot(torch.tensor(a), num_classes=self.vocab_len), dim=0)
        elif self.objective == 'contrastive':
            a = random.sample(a, 1)[0]
        return e, a

    def __len__(self):
        """Number of kept questions (those with at least one in-vocab answer)."""
        return len(self.embeddings)
|
|
|
|
|
|
class AokvqaEmbeddingsDataModule(pl.LightningDataModule):
    """LightningDataModule serving the 'train' and 'val' AokvqaEmbeddingsDataset splits.

    ``num_workers`` is a total budget. The training loader gets 80% of it and the
    validation loader 20%, each rounded down via ``int``.
    """

    def __init__(self, aokvqa_dir, train_features, val_features, objective, backbone, inputs, vocab, vocab_features, batch_size=1, num_workers=0):
        super().__init__()
        # Dataset location and the per-split precomputed feature files.
        self.aokvqa_dir = aokvqa_dir
        self.train_features = train_features
        self.val_features = val_features
        # Options forwarded unchanged to every AokvqaEmbeddingsDataset.
        self.objective = objective
        self.backbone = backbone
        self.inputs = inputs
        self.vocab = vocab
        self.vocab_features = vocab_features
        # DataLoader settings.
        self.batch_size = batch_size
        self.num_workers = num_workers

    def _make_split(self, split_name, feature_file):
        """Build the dataset for ``split_name`` from ``feature_file``."""
        return AokvqaEmbeddingsDataset(
            self.aokvqa_dir, split_name, feature_file, self.objective,
            self.backbone, self.inputs, self.vocab, self.vocab_features
        )

    def setup(self, stage=None):
        """(Re)build both datasets; ``stage`` is accepted but not used."""
        self.train_dataset = self._make_split('train', self.train_features)
        self.val_dataset = self._make_split('val', self.val_features)

    def _make_loader(self, dataset, shuffle, worker_fraction):
        """DataLoader over ``dataset`` using ``worker_fraction`` of the worker budget."""
        workers = int(worker_fraction * self.num_workers)
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=shuffle, num_workers=workers)

    def train_dataloader(self):
        """Shuffled loader over the training split."""
        return self._make_loader(self.train_dataset, shuffle=True, worker_fraction=0.8)

    def val_dataloader(self):
        """Unshuffled loader over the validation split."""
        return self._make_loader(self.val_dataset, shuffle=False, worker_fraction=0.2)
|
|
|
|
|
|
class LinearClassifier(pl.LightningModule):
    """A single linear layer trained on precomputed embeddings.

    * ``objective == 'classifier'``: outputs ``vocab_len`` independent answer
      scores. ``forward`` returns their sigmoid, and training uses a multi-label
      binary cross-entropy loss.
    * ``objective == 'contrastive'`` (CLIP backbone only): outputs a ``clip_dim``
      vector that is scored against the batch's target answer embeddings by dot
      product. Training uses an in-batch cross-entropy where sample i's target is
      row i of ``y``.
    """

    # Embedding width of each supported CLIP model.
    _CLIP_DIMS = {
        'RN50' : 1024,
        'RN50x4' : 640,
        'RN50x16' : 768,
        'RN50x64' : 1024,
        'RN101' : 512,
        'ViT-B/32' : 512,
        'ViT-B/16' : 512,
        'ViT-L/14' : 768,
        'ViT-L/14@336px' : 768,
    }

    def __init__(self, objective, backbone, clip_model_type, inputs, vocab_len, lr=0.001):
        """
        Args:
            objective: 'classifier' or 'contrastive'.
            backbone: 'clip', 'resnet' or 'bert'; determines the input width.
            clip_model_type: key of ``_CLIP_DIMS``; used only for the clip backbone.
            inputs: feature kinds that were concatenated into each input (for
                clip, input width is ``clip_dim * len(inputs)``).
            vocab_len: number of answer classes (classifier output width).
            lr: Adam learning rate (kept out of saved hyperparameters).

        Raises:
            ValueError: on an unknown backbone/objective, or 'contrastive' without
                the 'clip' backbone.
        """
        super().__init__()
        self.save_hyperparameters(ignore=['lr'])
        self.lr = lr

        if self.hparams.backbone == 'clip':
            clip_dim = self._CLIP_DIMS[clip_model_type]
            emb_dim = clip_dim * len(inputs)
        elif self.hparams.backbone == 'resnet':
            emb_dim = 2048
        elif self.hparams.backbone == 'bert':
            emb_dim = 768
        else:
            raise ValueError(f'unknown backbone: {self.hparams.backbone!r}')

        if self.hparams.objective == 'classifier':
            out_dim = vocab_len
        elif self.hparams.objective == 'contrastive':
            if self.hparams.backbone != 'clip':
                raise ValueError("the 'contrastive' objective requires the 'clip' backbone")
            out_dim = clip_dim
        else:
            raise ValueError(f'unknown objective: {self.hparams.objective!r}')

        self.linear = nn.Linear(emb_dim, out_dim)

    def forward(self, x):
        """Linear projection; sigmoid-activated for the classifier objective."""
        x = self.linear(x)
        if self.hparams.objective == 'classifier':
            x = torch.sigmoid(x)
        return x

    def compute_loss(self, batch):
        """Return ``(loss, acc)`` for a batch ``(x, y)``.

        For the classifier, ``acc`` is the F1 score of the sigmoid outputs against
        the multi-hot targets. For the contrastive objective, it is the mean
        softmax probability each sample assigns to its own target.
        """
        x, y = batch

        # Raw (pre-activation) output of the linear layer.
        out = self.linear(x)

        if self.hparams.objective == 'classifier':
            # Fuses sigmoid + BCE in a numerically stable form. Same loss as
            # F.binary_cross_entropy(torch.sigmoid(out), y), without its log clamping.
            loss = F.binary_cross_entropy_with_logits(out, y.float())
            acc = MF.f1_score(torch.sigmoid(out), y)
        elif self.hparams.objective == 'contrastive':
            indices = torch.arange(0, x.shape[0], dtype=torch.int64, device=self.device)
            # scores[i, j] = <prediction i, target embedding j>. Every other row of
            # y in the batch acts as a negative for sample i.
            scores = out @ y.T
            # cross_entropy applies log_softmax itself and therefore needs raw
            # scores. Softmax-ing first (as before) squashed loss and gradients.
            loss = F.cross_entropy(scores, indices)
            acc = torch.mean(scores.softmax(dim=-1)[indices, indices])

        return loss, acc

    def training_step(self, batch, batch_idx):
        """Log ``train_loss`` / ``train_acc`` and return the loss."""
        loss, acc = self.compute_loss(batch)
        self.log("train_loss", loss)
        self.log("train_acc", acc)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log ``val_loss`` / ``val_acc`` (monitored for checkpointing) and return the loss."""
        loss, acc = self.compute_loss(batch)
        self.log("val_loss", loss)
        self.log("val_acc", acc)
        return loss

    def configure_optimizers(self):
        """Adam over all parameters at the configured learning rate."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer
|
|
|
|
|
|
if __name__ == '__main__':
    # Run the CLI only when executed as a script, not when imported.
    main()
|