# Mirror of https://github.com/Vision-CAIR/MiniGPT-4.git
# (synced 2025-04-05 18:40:46 +00:00) - 127 lines, 4.7 KiB, Python.
import sys
|
|
import os
|
|
import argparse
|
|
import pathlib
|
|
from tqdm import tqdm
|
|
import json
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
# https://github.com/PyTorchLightning/pytorch-lightning/issues/11663
|
|
import sentencepiece; import pytorch_lightning as pl; import clip
|
|
|
|
from transfer_experiments.train import LinearClassifier
|
|
from load_aokvqa import load_aokvqa
|
|
from evaluation.remap_predictions import map_to_choices
|
|
|
|
|
|
parser = argparse.ArgumentParser()

# Data and output
parser.add_argument('--split', type=str, choices=['train', 'val', 'test'], required=True)
parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir')
parser.add_argument('--features', type=pathlib.Path, required=True)
# Predictions are only written after the whole inference loop, so a missing
# --out must be rejected up front instead of crashing in json.dump at the end.
parser.add_argument('--out', type=argparse.FileType('w'), dest='output_file', required=True)

# Model weights: a trained checkpoint or CLIP zero-shot (exactly one).
parser_weights = parser.add_mutually_exclusive_group(required=True)
parser_weights.add_argument('--ckpt', type=pathlib.Path, dest='checkpoint_path')
parser_weights.add_argument('--zero-shot', action='store_true', dest='clip_zero_shot')
# Only read for --zero-shot (with --ckpt the inputs come from the checkpoint's
# hparams); whether it is required is validated after parsing, below.
parser.add_argument('--inputs', nargs='+', type=str, choices=['question', 'image'])

# Answer vocabulary / multiple choice
parser.add_argument('--vocab', type=argparse.FileType('r'))
parser.add_argument('--vocab-features', type=pathlib.Path, dest='vocab_features')
parser.add_argument('--mc', action='store_true', dest='multiple_choice')
parser.add_argument('--clip-model-type', type=str,
                    choices=['RN50', 'RN50x4', 'RN50x16', 'RN50x64', 'RN101', 'ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px'],
                    dest='clip_model_type')

args = parser.parse_args()

# Conditional requirements are checked on the parsed namespace rather than by
# searching sys.argv for the exact flag spellings: argparse also accepts unique
# prefixes such as `--zero` or `--m`, which a sys.argv membership test misses.
if args.clip_zero_shot and args.inputs is None:
    parser.error('--inputs is required with --zero-shot')
if args.clip_zero_shot and args.multiple_choice and args.clip_model_type is None:
    parser.error('--clip-model-type is required with --zero-shot --mc')
## Load dataset

aokvqa_set = load_aokvqa(args.aokvqa_dir, args.split)

## Load models

device = "cuda" if torch.cuda.is_available() else "cpu"

if args.checkpoint_path is not None:
    # Trained head; its hparams (backbone, objective, inputs, ...) drive the
    # rest of the script.
    classifier = LinearClassifier.load_from_checkpoint(args.checkpoint_path)
    classifier.to(device)
    hp = classifier.hparams
elif args.clip_zero_shot:
    # Zero-shot: the precomputed CLIP embeddings are passed through unchanged.
    classifier = nn.Identity().to(device)
    hp = pl.utilities.AttributeDict(backbone='clip', clip_model_type=args.clip_model_type, objective='zero-shot', inputs=args.inputs)

# Inference only: torch.no_grad() does not switch off training-mode behaviour
# such as dropout, so put the model in eval mode explicitly (Lightning's docs
# likewise call model.eval() after load_from_checkpoint). No-op for Identity.
classifier.eval()
# Load input features

# NOTE(review): torch.load can unpickle arbitrary objects - only point
# --features at trusted files.
embeddings = torch.load(args.features)

if hp.backbone == 'clip':
    # Rescale every question and image embedding to unit L2 norm along its
    # last dimension (a new tensor replaces each stored one).
    for entry in embeddings.values():
        for modality in ('question', 'image'):
            entry[modality] = entry[modality] / entry[modality].norm(dim=-1, keepdim=True)
# Load vocab, vocab features, clip

# A fixed answer vocabulary is needed whenever predictions index into it:
# always for the classifier objective, and for contrastive / zero-shot when not
# in multiple-choice mode (in that mode each question's own choices are used).
if (hp.objective == 'classifier') or \
   (hp.objective in ['contrastive', 'zero-shot'] and not args.multiple_choice):
    if args.vocab is None:
        parser.error(f'--vocab is required for objective {hp.objective!r} (without --mc)'
                     if hp.objective != 'classifier'
                     else f'--vocab is required for objective {hp.objective!r}')
    with args.vocab as vocab_file:
        vocab = vocab_file.read().splitlines()

if hp.objective in ['contrastive', 'zero-shot']:
    if not args.multiple_choice:
        if args.vocab_features is None:
            parser.error(f'--vocab-features is required for objective {hp.objective!r} without --mc')
        # map_location='cpu': the features are used on CPU anyway, and this also
        # lets features saved from a GPU load on a CPU-only machine.
        vocab_features = torch.load(args.vocab_features, map_location='cpu')
        vocab_features /= vocab_features.norm(dim=-1, keepdim=True)
        # The prediction loop multiplies similarities by `logit_scale`, which
        # used to be bound only in the multiple-choice branch (NameError here).
        # Predictions are an argmax, which any positive scale leaves unchanged.
        logit_scale = 1.0
    else:
        # Choices are encoded per question with CLIP's text encoder; CLIP's
        # learned temperature scales the similarities.
        clip_model = clip.load(hp.clip_model_type, device=device)[0]
        logit_scale = clip_model.logit_scale.exp().cpu()
## Prediction loop

# Run-wide settings, fixed before the loop: which precomputed embeddings feed
# the model, and whether its output is scored against text features.
use_question = 'question' in hp.inputs
use_image = 'image' in hp.inputs
scores_against_text = hp.objective in ['contrastive', 'zero-shot']

predictions = {}

with torch.no_grad():
    for o in tqdm(aokvqa_set):
        q = o['question_id']

        # Assemble the input embedding from the selected modalities.
        if use_question and use_image:
            if hp.objective == 'zero-shot':
                # Zero-shot: add the question and image embeddings.
                e = embeddings[q]['question'] + embeddings[q]['image']
            else:
                e = torch.cat((embeddings[q]['question'], embeddings[q]['image']))
        elif use_question:
            e = embeddings[q]['question']
        elif use_image:
            e = embeddings[q]['image']

        # Batch of one through the model; keep the single output row on CPU.
        x = classifier(e.unsqueeze(0).to(device))[0].cpu()

        if scores_against_text:
            if args.multiple_choice:
                # Candidate answers are this question's own choices, each
                # encoded separately by CLIP and rescaled to unit norm.
                vocab = o['choices']
                choice_tokens = clip.tokenize(vocab).to(device)
                encoded = [clip_model.encode_text(tok.unsqueeze(0)) for tok in choice_tokens]
                vocab_features = torch.cat(encoded, dim=0)
                vocab_features /= vocab_features.norm(dim=-1, keepdim=True)
                vocab_features = vocab_features.float().cpu()

            similarities = logit_scale * x @ vocab_features.t()
            x = similarities.softmax(dim=-1)

        predictions[q] = vocab[x.argmax().item()]
## Save predictions

# With the classifier objective in multiple-choice mode, predicted answers
# are mapped to a choice by map_to_choices (nearest neighbour by word
# embeddings) before saving.
if hp.objective == 'classifier' and args.multiple_choice:
    predictions = map_to_choices(aokvqa_set, predictions)

json.dump(predictions, args.output_file)