diff --git a/minigpt4/common/vqa_tools/VQA/PythonEvaluationTools/vqaEvalDemo.py b/minigpt4/common/vqa_tools/VQA/PythonEvaluationTools/vqaEvalDemo.py
new file mode 100644
index 0000000..07ca21d
--- /dev/null
+++ b/minigpt4/common/vqa_tools/VQA/PythonEvaluationTools/vqaEvalDemo.py
@@ -0,0 +1,89 @@
+# coding: utf-8
+
+import sys
+dataDir = '../../VQA'
+sys.path.insert(0, '%s/PythonHelperTools/vqaTools' %(dataDir))
+from vqa import VQA
+from vqaEvaluation.vqaEval import VQAEval
+import matplotlib.pyplot as plt
+import skimage.io as io
+import json
+import random
+import os
+
+# set up file names and paths
+versionType ='v2_' # this should be '' when using VQA v2.0 dataset
+taskType ='OpenEnded' # 'OpenEnded' only for v2.0. 'OpenEnded' or 'MultipleChoice' for v1.0
+dataType ='mscoco' # 'mscoco' only for v1.0. 'mscoco' for real and 'abstract_v002' for abstract for v1.0.
+dataSubType ='train2014'
+annFile ='%s/Annotations/%s%s_%s_annotations.json'%(dataDir, versionType, dataType, dataSubType)
+quesFile ='%s/Questions/%s%s_%s_%s_questions.json'%(dataDir, versionType, taskType, dataType, dataSubType)
+imgDir ='%s/Images/%s/%s/' %(dataDir, dataType, dataSubType)
+resultType ='fake'
+fileTypes = ['results', 'accuracy', 'evalQA', 'evalQuesType', 'evalAnsType']
+
+# An example result json file has been provided in './Results' folder.
+
+[resFile, accuracyFile, evalQAFile, evalQuesTypeFile, evalAnsTypeFile] = ['%s/Results/%s%s_%s_%s_%s_%s.json'%(dataDir, versionType, taskType, dataType, dataSubType, \
+resultType, fileType) for fileType in fileTypes]
+
+# create vqa object and vqaRes object
+vqa = VQA(annFile, quesFile)
+vqaRes = vqa.loadRes(resFile, quesFile)
+
+# create vqaEval object by taking vqa and vqaRes
+vqaEval = VQAEval(vqa, vqaRes, n=2) #n is precision of accuracy (number of places after decimal), default is 2
+
+# evaluate results
+"""
+If you have a list of question ids on which you would like to evaluate your results, pass it as a list to below function
+By default it uses all the question ids in annotation file
+"""
+vqaEval.evaluate()
+
+# print accuracies
+print "\n"
+print "Overall Accuracy is: %.02f\n" %(vqaEval.accuracy['overall'])
+print "Per Question Type Accuracy is the following:"
+for quesType in vqaEval.accuracy['perQuestionType']:
+ print "%s : %.02f" %(quesType, vqaEval.accuracy['perQuestionType'][quesType])
+print "\n"
+print "Per Answer Type Accuracy is the following:"
+for ansType in vqaEval.accuracy['perAnswerType']:
+ print "%s : %.02f" %(ansType, vqaEval.accuracy['perAnswerType'][ansType])
+print "\n"
+# demo how to use evalQA to retrieve low score result
+evals = [quesId for quesId in vqaEval.evalQA if vqaEval.evalQA[quesId]<35] #35 is per question percentage accuracy
+if len(evals) > 0:
+ print 'ground truth answers'
+ randomEval = random.choice(evals)
+ randomAnn = vqa.loadQA(randomEval)
+ vqa.showQA(randomAnn)
+
+ print '\n'
+ print 'generated answer (accuracy %.02f)'%(vqaEval.evalQA[randomEval])
+ ann = vqaRes.loadQA(randomEval)[0]
+ print "Answer: %s\n" %(ann['answer'])
+
+ imgId = randomAnn[0]['image_id']
+ imgFilename = 'COCO_' + dataSubType + '_'+ str(imgId).zfill(12) + '.jpg'
+ if os.path.isfile(imgDir + imgFilename):
+ I = io.imread(imgDir + imgFilename)
+ plt.imshow(I)
+ plt.axis('off')
+ plt.show()
+
+# plot accuracy for various question types
+plt.bar(range(len(vqaEval.accuracy['perQuestionType'])), vqaEval.accuracy['perQuestionType'].values(), align='center')
+plt.xticks(range(len(vqaEval.accuracy['perQuestionType'])), vqaEval.accuracy['perQuestionType'].keys(), rotation='0',fontsize=10)
+plt.title('Per Question Type Accuracy', fontsize=10)
+plt.xlabel('Question Types', fontsize=10)
+plt.ylabel('Accuracy', fontsize=10)
+plt.show()
+
+# save evaluation results to ./Results folder
+json.dump(vqaEval.accuracy, open(accuracyFile, 'w'))
+json.dump(vqaEval.evalQA, open(evalQAFile, 'w'))
+json.dump(vqaEval.evalQuesType, open(evalQuesTypeFile, 'w'))
+json.dump(vqaEval.evalAnsType, open(evalAnsTypeFile, 'w'))
+
diff --git a/minigpt4/common/vqa_tools/VQA/PythonEvaluationTools/vqaEvaluation/__init__.py b/minigpt4/common/vqa_tools/VQA/PythonEvaluationTools/vqaEvaluation/__init__.py
new file mode 100644
index 0000000..148424d
--- /dev/null
+++ b/minigpt4/common/vqa_tools/VQA/PythonEvaluationTools/vqaEvaluation/__init__.py
@@ -0,0 +1 @@
+author='aagrawal'
diff --git a/minigpt4/common/vqa_tools/VQA/PythonEvaluationTools/vqaEvaluation/__pycache__/__init__.cpython-39.pyc b/minigpt4/common/vqa_tools/VQA/PythonEvaluationTools/vqaEvaluation/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000..f09b0c2
Binary files /dev/null and b/minigpt4/common/vqa_tools/VQA/PythonEvaluationTools/vqaEvaluation/__pycache__/__init__.cpython-39.pyc differ
diff --git a/minigpt4/common/vqa_tools/VQA/PythonEvaluationTools/vqaEvaluation/__pycache__/vqaEval.cpython-39.pyc b/minigpt4/common/vqa_tools/VQA/PythonEvaluationTools/vqaEvaluation/__pycache__/vqaEval.cpython-39.pyc
new file mode 100644
index 0000000..8dd0808
Binary files /dev/null and b/minigpt4/common/vqa_tools/VQA/PythonEvaluationTools/vqaEvaluation/__pycache__/vqaEval.cpython-39.pyc differ
diff --git a/minigpt4/common/vqa_tools/VQA/PythonEvaluationTools/vqaEvaluation/vqaEval.py b/minigpt4/common/vqa_tools/VQA/PythonEvaluationTools/vqaEvaluation/vqaEval.py
new file mode 100644
index 0000000..8a65604
--- /dev/null
+++ b/minigpt4/common/vqa_tools/VQA/PythonEvaluationTools/vqaEvaluation/vqaEval.py
@@ -0,0 +1,192 @@
+# coding=utf-8
+
+__author__='aagrawal'
+
+import re
+# This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link:
+# (https://github.com/tylin/coco-caption/blob/master/pycocoevalcap/eval.py).
+import sys
+
+
+class VQAEval:
+ def __init__(self, vqa, vqaRes, n=2):
+ self.n = n
+ self.accuracy = {}
+ self.evalQA = {}
+ self.evalQuesType = {}
+ self.evalAnsType = {}
+ self.vqa = vqa
+ self.vqaRes = vqaRes
+ self.params = {'question_id': vqa.getQuesIds()}
+ self.contractions = {"aint": "ain't", "arent": "aren't", "cant": "can't", "couldve": "could've", "couldnt": "couldn't", \
+ "couldn'tve": "couldn't've", "couldnt've": "couldn't've", "didnt": "didn't", "doesnt": "doesn't", "dont": "don't", "hadnt": "hadn't", \
+ "hadnt've": "hadn't've", "hadn'tve": "hadn't've", "hasnt": "hasn't", "havent": "haven't", "hed": "he'd", "hed've": "he'd've", \
+ "he'dve": "he'd've", "hes": "he's", "howd": "how'd", "howll": "how'll", "hows": "how's", "Id've": "I'd've", "I'dve": "I'd've", \
+ "Im": "I'm", "Ive": "I've", "isnt": "isn't", "itd": "it'd", "itd've": "it'd've", "it'dve": "it'd've", "itll": "it'll", "let's": "let's", \
+ "maam": "ma'am", "mightnt": "mightn't", "mightnt've": "mightn't've", "mightn'tve": "mightn't've", "mightve": "might've", \
+ "mustnt": "mustn't", "mustve": "must've", "neednt": "needn't", "notve": "not've", "oclock": "o'clock", "oughtnt": "oughtn't", \
+ "ow's'at": "'ow's'at", "'ows'at": "'ow's'at", "'ow'sat": "'ow's'at", "shant": "shan't", "shed've": "she'd've", "she'dve": "she'd've", \
+ "she's": "she's", "shouldve": "should've", "shouldnt": "shouldn't", "shouldnt've": "shouldn't've", "shouldn'tve": "shouldn't've", \
+ "somebody'd": "somebodyd", "somebodyd've": "somebody'd've", "somebody'dve": "somebody'd've", "somebodyll": "somebody'll", \
+ "somebodys": "somebody's", "someoned": "someone'd", "someoned've": "someone'd've", "someone'dve": "someone'd've", \
+ "someonell": "someone'll", "someones": "someone's", "somethingd": "something'd", "somethingd've": "something'd've", \
+ "something'dve": "something'd've", "somethingll": "something'll", "thats": "that's", "thered": "there'd", "thered've": "there'd've", \
+ "there'dve": "there'd've", "therere": "there're", "theres": "there's", "theyd": "they'd", "theyd've": "they'd've", \
+ "they'dve": "they'd've", "theyll": "they'll", "theyre": "they're", "theyve": "they've", "twas": "'twas", "wasnt": "wasn't", \
+ "wed've": "we'd've", "we'dve": "we'd've", "weve": "we've", "werent": "weren't", "whatll": "what'll", "whatre": "what're", \
+ "whats": "what's", "whatve": "what've", "whens": "when's", "whered": "where'd", "wheres": "where's", "whereve": "where've", \
+ "whod": "who'd", "whod've": "who'd've", "who'dve": "who'd've", "wholl": "who'll", "whos": "who's", "whove": "who've", "whyll": "why'll", \
+ "whyre": "why're", "whys": "why's", "wont": "won't", "wouldve": "would've", "wouldnt": "wouldn't", "wouldnt've": "wouldn't've", \
+ "wouldn'tve": "wouldn't've", "yall": "y'all", "yall'll": "y'all'll", "y'allll": "y'all'll", "yall'd've": "y'all'd've", \
+ "y'alld've": "y'all'd've", "y'all'dve": "y'all'd've", "youd": "you'd", "youd've": "you'd've", "you'dve": "you'd've", \
+ "youll": "you'll", "youre": "you're", "youve": "you've"}
+ self.manualMap = { 'none': '0',
+ 'zero': '0',
+ 'one': '1',
+ 'two': '2',
+ 'three': '3',
+ 'four': '4',
+ 'five': '5',
+ 'six': '6',
+ 'seven': '7',
+ 'eight': '8',
+ 'nine': '9',
+ 'ten': '10'
+ }
+ self.articles = ['a',
+ 'an',
+ 'the'
+ ]
+
+
+ self.periodStrip = re.compile("(?!<=\d)(\.)(?!\d)")
+ self.commaStrip = re.compile("(\d)(\,)(\d)")
+ self.punct = [';', r"/", '[', ']', '"', '{', '}',
+ '(', ')', '=', '+', '\\', '_', '-',
+ '>', '<', '@', '`', ',', '?', '!']
+
+
+ def evaluate(self, quesIds=None):
+ if quesIds == None:
+ quesIds = [quesId for quesId in self.params['question_id']]
+ gts = {}
+ res = {}
+ for quesId in quesIds:
+ gts[quesId] = self.vqa.qa[quesId]
+ res[quesId] = self.vqaRes.qa[quesId]
+
+ # =================================================
+ # Compute accuracy
+ # =================================================
+ accQA = []
+ accQuesType = {}
+ accAnsType = {}
+ # print "computing accuracy"
+ step = 0
+ for quesId in quesIds:
+ for ansDic in gts[quesId]['answers']:
+ ansDic['answer'] = ansDic['answer'].replace('\n', ' ')
+ ansDic['answer'] = ansDic['answer'].replace('\t', ' ')
+ ansDic['answer'] = ansDic['answer'].strip()
+ resAns = res[quesId]['answer']
+ resAns = resAns.replace('\n', ' ')
+ resAns = resAns.replace('\t', ' ')
+ resAns = resAns.strip()
+ gtAcc = []
+ gtAnswers = [ans['answer'] for ans in gts[quesId]['answers']]
+
+ if len(set(gtAnswers)) > 1:
+ for ansDic in gts[quesId]['answers']:
+ ansDic['answer'] = self.processPunctuation(ansDic['answer'])
+ ansDic['answer'] = self.processDigitArticle(ansDic['answer'])
+ resAns = self.processPunctuation(resAns)
+ resAns = self.processDigitArticle(resAns)
+
+ for gtAnsDatum in gts[quesId]['answers']:
+ otherGTAns = [item for item in gts[quesId]['answers'] if item!=gtAnsDatum]
+ matchingAns = [item for item in otherGTAns if item['answer'].lower()==resAns.lower()]
+ acc = min(1, float(len(matchingAns))/3)
+ gtAcc.append(acc)
+ quesType = gts[quesId]['question_type']
+ ansType = gts[quesId]['answer_type']
+ avgGTAcc = float(sum(gtAcc))/len(gtAcc)
+ accQA.append(avgGTAcc)
+ if quesType not in accQuesType:
+ accQuesType[quesType] = []
+ accQuesType[quesType].append(avgGTAcc)
+ if ansType not in accAnsType:
+ accAnsType[ansType] = []
+ accAnsType[ansType].append(avgGTAcc)
+ self.setEvalQA(quesId, avgGTAcc)
+ self.setEvalQuesType(quesId, quesType, avgGTAcc)
+ self.setEvalAnsType(quesId, ansType, avgGTAcc)
+ if step%100 == 0:
+ self.updateProgress(step/float(len(quesIds)))
+ step = step + 1
+
+ self.setAccuracy(accQA, accQuesType, accAnsType)
+ # print "Done computing accuracy"
+
+ def processPunctuation(self, inText):
+ outText = inText
+ for p in self.punct:
+ if (p + ' ' in inText or ' ' + p in inText) or (re.search(self.commaStrip, inText) != None):
+ outText = outText.replace(p, '')
+ else:
+ outText = outText.replace(p, ' ')
+ outText = self.periodStrip.sub("",
+ outText,
+ re.UNICODE)
+ return outText
+
+ def processDigitArticle(self, inText):
+ outText = []
+ tempText = inText.lower().split()
+ for word in tempText:
+ word = self.manualMap.setdefault(word, word)
+ if word not in self.articles:
+ outText.append(word)
+ else:
+ pass
+ for wordId, word in enumerate(outText):
+ if word in self.contractions:
+ outText[wordId] = self.contractions[word]
+ outText = ' '.join(outText)
+ return outText
+
+ def setAccuracy(self, accQA, accQuesType, accAnsType):
+ self.accuracy['overall'] = round(100*float(sum(accQA))/len(accQA), self.n)
+ self.accuracy['perQuestionType'] = {quesType: round(100*float(sum(accQuesType[quesType]))/len(accQuesType[quesType]), self.n) for quesType in accQuesType}
+ self.accuracy['perAnswerType'] = {ansType: round(100*float(sum(accAnsType[ansType]))/len(accAnsType[ansType]), self.n) for ansType in accAnsType}
+
+ def setEvalQA(self, quesId, acc):
+ self.evalQA[quesId] = round(100*acc, self.n)
+
+ def setEvalQuesType(self, quesId, quesType, acc):
+ if quesType not in self.evalQuesType:
+ self.evalQuesType[quesType] = {}
+ self.evalQuesType[quesType][quesId] = round(100*acc, self.n)
+
+ def setEvalAnsType(self, quesId, ansType, acc):
+ if ansType not in self.evalAnsType:
+ self.evalAnsType[ansType] = {}
+ self.evalAnsType[ansType][quesId] = round(100*acc, self.n)
+
+ def updateProgress(self, progress):
+ barLength = 20
+ status = ""
+ if isinstance(progress, int):
+ progress = float(progress)
+ if not isinstance(progress, float):
+ progress = 0
+ status = "error: progress var must be float\r\n"
+ if progress < 0:
+ progress = 0
+ status = "Halt...\r\n"
+ if progress >= 1:
+ progress = 1
+ status = "Done...\r\n"
+ block = int(round(barLength*progress))
+ text = "\rFinshed Percent: [{0}] {1}% {2}".format( "#"*block + "-"*(barLength-block), int(progress*100), status)
+ sys.stdout.write(text)
+ sys.stdout.flush()
diff --git a/minigpt4/common/vqa_tools/VQA/PythonHelperTools/vqaDemo.py b/minigpt4/common/vqa_tools/VQA/PythonHelperTools/vqaDemo.py
new file mode 100644
index 0000000..406b596
--- /dev/null
+++ b/minigpt4/common/vqa_tools/VQA/PythonHelperTools/vqaDemo.py
@@ -0,0 +1,73 @@
+# coding: utf-8
+
+from vqaTools.vqa import VQA
+import random
+import skimage.io as io
+import matplotlib.pyplot as plt
+import os
+
+dataDir ='../../VQA'
+versionType ='v2_' # this should be '' when using VQA v2.0 dataset
+taskType ='OpenEnded' # 'OpenEnded' only for v2.0. 'OpenEnded' or 'MultipleChoice' for v1.0
+dataType ='mscoco' # 'mscoco' only for v1.0. 'mscoco' for real and 'abstract_v002' for abstract for v1.0.
+dataSubType ='train2014'
+annFile ='%s/Annotations/%s%s_%s_annotations.json'%(dataDir, versionType, dataType, dataSubType)
+quesFile ='%s/Questions/%s%s_%s_%s_questions.json'%(dataDir, versionType, taskType, dataType, dataSubType)
+imgDir = '%s/Images/%s/%s/' %(dataDir, dataType, dataSubType)
+
+# initialize VQA api for QA annotations
+vqa=VQA(annFile, quesFile)
+
+# load and display QA annotations for given question types
+"""
+All possible quesTypes for abstract and mscoco has been provided in respective text files in ../QuestionTypes/ folder.
+"""
+annIds = vqa.getQuesIds(quesTypes='how many');
+anns = vqa.loadQA(annIds)
+randomAnn = random.choice(anns)
+vqa.showQA([randomAnn])
+imgId = randomAnn['image_id']
+imgFilename = 'COCO_' + dataSubType + '_'+ str(imgId).zfill(12) + '.jpg'
+if os.path.isfile(imgDir + imgFilename):
+ I = io.imread(imgDir + imgFilename)
+ plt.imshow(I)
+ plt.axis('off')
+ plt.show()
+
+# load and display QA annotations for given answer types
+"""
+ansTypes can be one of the following
+yes/no
+number
+other
+"""
+annIds = vqa.getQuesIds(ansTypes='yes/no');
+anns = vqa.loadQA(annIds)
+randomAnn = random.choice(anns)
+vqa.showQA([randomAnn])
+imgId = randomAnn['image_id']
+imgFilename = 'COCO_' + dataSubType + '_'+ str(imgId).zfill(12) + '.jpg'
+if os.path.isfile(imgDir + imgFilename):
+ I = io.imread(imgDir + imgFilename)
+ plt.imshow(I)
+ plt.axis('off')
+ plt.show()
+
+# load and display QA annotations for given images
+"""
+Usage: vqa.getImgIds(quesIds=[], quesTypes=[], ansTypes=[])
+Above method can be used to retrieve imageIds for given question Ids or given question types or given answer types.
+"""
+ids = vqa.getImgIds()
+annIds = vqa.getQuesIds(imgIds=random.sample(ids,5));
+anns = vqa.loadQA(annIds)
+randomAnn = random.choice(anns)
+vqa.showQA([randomAnn])
+imgId = randomAnn['image_id']
+imgFilename = 'COCO_' + dataSubType + '_'+ str(imgId).zfill(12) + '.jpg'
+if os.path.isfile(imgDir + imgFilename):
+ I = io.imread(imgDir + imgFilename)
+ plt.imshow(I)
+ plt.axis('off')
+ plt.show()
+
diff --git a/minigpt4/common/vqa_tools/VQA/PythonHelperTools/vqaTools/__init__.py b/minigpt4/common/vqa_tools/VQA/PythonHelperTools/vqaTools/__init__.py
new file mode 100644
index 0000000..072d8d9
--- /dev/null
+++ b/minigpt4/common/vqa_tools/VQA/PythonHelperTools/vqaTools/__init__.py
@@ -0,0 +1 @@
+__author__ = 'aagrawal'
diff --git a/minigpt4/common/vqa_tools/VQA/PythonHelperTools/vqaTools/__pycache__/__init__.cpython-39.pyc b/minigpt4/common/vqa_tools/VQA/PythonHelperTools/vqaTools/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000..8b60212
Binary files /dev/null and b/minigpt4/common/vqa_tools/VQA/PythonHelperTools/vqaTools/__pycache__/__init__.cpython-39.pyc differ
diff --git a/minigpt4/common/vqa_tools/VQA/PythonHelperTools/vqaTools/__pycache__/vqa.cpython-39.pyc b/minigpt4/common/vqa_tools/VQA/PythonHelperTools/vqaTools/__pycache__/vqa.cpython-39.pyc
new file mode 100644
index 0000000..0bb487a
Binary files /dev/null and b/minigpt4/common/vqa_tools/VQA/PythonHelperTools/vqaTools/__pycache__/vqa.cpython-39.pyc differ
diff --git a/minigpt4/common/vqa_tools/VQA/PythonHelperTools/vqaTools/vqa.py b/minigpt4/common/vqa_tools/VQA/PythonHelperTools/vqaTools/vqa.py
new file mode 100644
index 0000000..4f76961
--- /dev/null
+++ b/minigpt4/common/vqa_tools/VQA/PythonHelperTools/vqaTools/vqa.py
@@ -0,0 +1,179 @@
+__author__ = 'aagrawal'
+__version__ = '0.9'
+
+# Interface for accessing the VQA dataset.
+
+# This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link:
+# (https://github.com/pdollar/coco/blob/master/PythonAPI/pycocotools/coco.py).
+
+# The following functions are defined:
+# VQA - VQA class that loads VQA annotation file and prepares data structures.
+# getQuesIds - Get question ids that satisfy given filter conditions.
+# getImgIds - Get image ids that satisfy given filter conditions.
+# loadQA - Load questions and answers with the specified question ids.
+# showQA - Display the specified questions and answers.
+# loadRes - Load result file and create result object.
+
+# Help on each function can be accessed by: "help(COCO.function)"
+
+import json
+import datetime
+import copy
+
+
+class VQA:
+ def __init__(self, annotation_file=None, question_file=None):
+ """
+ Constructor of VQA helper class for reading and visualizing questions and answers.
+ :param annotation_file (str): location of VQA annotation file
+ :return:
+ """
+ # load dataset
+ self.dataset = {}
+ self.questions = {}
+ self.qa = {}
+ self.qqa = {}
+ self.imgToQA = {}
+ if not annotation_file == None and not question_file == None:
+ # print 'loading VQA annotations and questions into memory...'
+ time_t = datetime.datetime.utcnow()
+ dataset = json.load(open(annotation_file, 'r'))
+ questions = json.load(open(question_file, 'r'))
+ # print datetime.datetime.utcnow() - time_t
+ self.dataset = dataset
+ self.questions = questions
+ self.createIndex()
+
+ def createIndex(self):
+ imgToQA = {ann['image_id']: [] for ann in self.dataset['annotations']}
+ qa = {ann['question_id']: [] for ann in self.dataset['annotations']}
+ qqa = {ann['question_id']: [] for ann in self.dataset['annotations']}
+ for ann in self.dataset['annotations']:
+ imgToQA[ann['image_id']] += [ann]
+ qa[ann['question_id']] = ann
+ for ques in self.questions['questions']:
+ qqa[ques['question_id']] = ques
+ # print 'index created!'
+
+ # create class members
+ self.qa = qa
+ self.qqa = qqa
+ self.imgToQA = imgToQA
+
+ def info(self):
+ """
+ Print information about the VQA annotation file.
+ :return:
+ """
+
+ # for key, value in self.datset['info'].items():
+ # print '%s: %s'%(key, value)
+
+ def getQuesIds(self, imgIds=[], quesTypes=[], ansTypes=[]):
+ """
+ Get question ids that satisfy given filter conditions. default skips that filter
+ :param imgIds (int array) : get question ids for given imgs
+ quesTypes (str array) : get question ids for given question types
+ ansTypes (str array) : get question ids for given answer types
+ :return: ids (int array) : integer array of question ids
+ """
+ imgIds = imgIds if type(imgIds) == list else [imgIds]
+ quesTypes = quesTypes if type(quesTypes) == list else [quesTypes]
+ ansTypes = ansTypes if type(ansTypes) == list else [ansTypes]
+
+ if len(imgIds) == len(quesTypes) == len(ansTypes) == 0:
+ anns = self.dataset['annotations']
+ else:
+ if not len(imgIds) == 0:
+ anns = sum([self.imgToQA[imgId] for imgId in imgIds if imgId in self.imgToQA], [])
+ else:
+ anns = self.dataset['annotations']
+ anns = anns if len(quesTypes) == 0 else [ann for ann in anns if ann['question_type'] in quesTypes]
+ anns = anns if len(ansTypes) == 0 else [ann for ann in anns if ann['answer_type'] in ansTypes]
+ ids = [ann['question_id'] for ann in anns]
+ return ids
+
+ def getImgIds(self, quesIds=[], quesTypes=[], ansTypes=[]):
+ """
+ Get image ids that satisfy given filter conditions. default skips that filter
+ :param quesIds (int array) : get image ids for given question ids
+ quesTypes (str array) : get image ids for given question types
+ ansTypes (str array) : get image ids for given answer types
+ :return: ids (int array) : integer array of image ids
+ """
+ quesIds = quesIds if type(quesIds) == list else [quesIds]
+ quesTypes = quesTypes if type(quesTypes) == list else [quesTypes]
+ ansTypes = ansTypes if type(ansTypes) == list else [ansTypes]
+
+ if len(quesIds) == len(quesTypes) == len(ansTypes) == 0:
+ anns = self.dataset['annotations']
+ else:
+ if not len(quesIds) == 0:
+ anns = sum([self.qa[quesId] for quesId in quesIds if quesId in self.qa], [])
+ else:
+ anns = self.dataset['annotations']
+ anns = anns if len(quesTypes) == 0 else [ann for ann in anns if ann['question_type'] in quesTypes]
+ anns = anns if len(ansTypes) == 0 else [ann for ann in anns if ann['answer_type'] in ansTypes]
+ ids = [ann['image_id'] for ann in anns]
+ return ids
+
+ def loadQA(self, ids=[]):
+ """
+ Load questions and answers with the specified question ids.
+ :param ids (int array) : integer ids specifying question ids
+ :return: qa (object array) : loaded qa objects
+ """
+ if type(ids) == list:
+ return [self.qa[id] for id in ids]
+ elif type(ids) == int:
+ return [self.qa[ids]]
+
+ def showQA(self, anns):
+ """
+ Display the specified annotations.
+ :param anns (array of object): annotations to display
+ :return: None
+ """
+ if len(anns) == 0:
+ return 0
+ for ann in anns:
+ quesId = ann['question_id']
+ print("Question: %s" % (self.qqa[quesId]['question']))
+ for ans in ann['answers']:
+ print("Answer %d: %s" % (ans['answer_id'], ans['answer']))
+
+ def loadRes(self, resFile, quesFile):
+ """
+ Load result file and return a result object.
+ :param resFile (str) : file name of result file
+ :return: res (obj) : result api object
+ """
+ res = VQA()
+ res.questions = json.load(open(quesFile))
+ res.dataset['info'] = copy.deepcopy(self.questions['info'])
+ res.dataset['task_type'] = copy.deepcopy(self.questions['task_type'])
+ res.dataset['data_type'] = copy.deepcopy(self.questions['data_type'])
+ res.dataset['data_subtype'] = copy.deepcopy(self.questions['data_subtype'])
+ res.dataset['license'] = copy.deepcopy(self.questions['license'])
+
+ # print 'Loading and preparing results... '
+ time_t = datetime.datetime.utcnow()
+ anns = json.load(open(resFile))
+ assert type(anns) == list, 'results is not an array of objects'
+ annsQuesIds = [ann['question_id'] for ann in anns]
+ assert set(annsQuesIds) == set(self.getQuesIds()), \
+ 'Results do not correspond to current VQA set. Either the results do not have predictions for all question ids in annotation file or there is atleast one question id that does not belong to the question ids in the annotation file.'
+ for ann in anns:
+ quesId = ann['question_id']
+ if res.dataset['task_type'] == 'Multiple Choice':
+ assert ann['answer'] in self.qqa[quesId][
+ 'multiple_choices'], 'predicted answer is not one of the multiple choices'
+ qaAnn = self.qa[quesId]
+ ann['image_id'] = qaAnn['image_id']
+ ann['question_type'] = qaAnn['question_type']
+ ann['answer_type'] = qaAnn['answer_type']
+ # print 'DONE (t=%0.2fs)'%((datetime.datetime.utcnow() - time_t).total_seconds())
+
+ res.dataset['annotations'] = anns
+ res.createIndex()
+ return res
diff --git a/minigpt4/common/vqa_tools/VQA/QuestionTypes/abstract_v002_question_types.txt b/minigpt4/common/vqa_tools/VQA/QuestionTypes/abstract_v002_question_types.txt
new file mode 100644
index 0000000..44304fc
--- /dev/null
+++ b/minigpt4/common/vqa_tools/VQA/QuestionTypes/abstract_v002_question_types.txt
@@ -0,0 +1,81 @@
+how many
+what color is the
+is the
+where is the
+what
+what is
+are the
+what is the
+is there a
+does the
+is the woman
+is the man
+what is on the
+is it
+is the girl
+is the boy
+is the dog
+are they
+who is
+what kind of
+what color are the
+what is in the
+what is the man
+is there
+what is the woman
+what are the
+what is the boy
+are there
+what is the girl
+is this
+how
+which
+how many people are
+is the cat
+why is the
+are
+will the
+what type of
+what is the dog
+do
+is she
+does
+do the
+is
+is the baby
+are there any
+is the lady
+can
+what animal is
+where are the
+is the sun
+what are they
+did the
+what is the cat
+what is the lady
+how many clouds are
+is that
+is the little girl
+is he
+are these
+how many trees are
+how many pillows
+are the people
+why
+is the young
+how many windows are
+is this a
+what is the little
+is the tv
+how many animals are
+who
+how many pictures
+how many plants are
+how many birds are
+what color is
+what is the baby
+is anyone
+what color
+how many bushes
+is the old man
+none of the above
diff --git a/minigpt4/common/vqa_tools/VQA/QuestionTypes/mscoco_question_types.txt b/minigpt4/common/vqa_tools/VQA/QuestionTypes/mscoco_question_types.txt
new file mode 100644
index 0000000..9559050
--- /dev/null
+++ b/minigpt4/common/vqa_tools/VQA/QuestionTypes/mscoco_question_types.txt
@@ -0,0 +1,65 @@
+how many
+is the
+what
+what color is the
+what is the
+is this
+is this a
+what is
+are the
+what kind of
+is there a
+what type of
+is it
+what are the
+where is the
+is there
+does the
+what color are the
+are these
+are there
+which
+is
+what is the man
+is the man
+are
+how
+does this
+what is on the
+what does the
+how many people are
+what is in the
+what is this
+do
+what are
+are they
+what time
+what sport is
+are there any
+is he
+what color is
+why
+where are the
+what color
+who is
+what animal is
+is the woman
+is this an
+do you
+how many people are in
+what room is
+has
+is this person
+what is the woman
+can you
+why is the
+is the person
+what is the color of the
+what is the person
+could
+was
+is that a
+what number is
+what is the name
+what brand
+none of the above
diff --git a/minigpt4/common/vqa_tools/VQA/README.md b/minigpt4/common/vqa_tools/VQA/README.md
new file mode 100644
index 0000000..439d59d
--- /dev/null
+++ b/minigpt4/common/vqa_tools/VQA/README.md
@@ -0,0 +1,80 @@
+Python API and Evaluation Code for v2.0 and v1.0 releases of the VQA dataset.
+===================
+## VQA v2.0 release ##
+This release consists of
+- Real
+ - 82,783 MS COCO training images, 40,504 MS COCO validation images and 81,434 MS COCO testing images (images are obtained from [MS COCO website] (http://mscoco.org/dataset/#download))
+ - 443,757 questions for training, 214,354 questions for validation and 447,793 questions for testing
+ - 4,437,570 answers for training and 2,143,540 answers for validation (10 per question)
+
+There is only one type of task
+- Open-ended task
+
+## VQA v1.0 release ##
+This release consists of
+- Real
+ - 82,783 MS COCO training images, 40,504 MS COCO validation images and 81,434 MS COCO testing images (images are obtained from [MS COCO website] (http://mscoco.org/dataset/#download))
+ - 248,349 questions for training, 121,512 questions for validation and 244,302 questions for testing (3 per image)
+ - 2,483,490 answers for training and 1,215,120 answers for validation (10 per question)
+- Abstract
+ - 20,000 training images, 10,000 validation images and 20,000 MS COCO testing images
+ - 60,000 questions for training, 30,000 questions for validation and 60,000 questions for testing (3 per image)
+ - 600,000 answers for training and 300,000 answers for validation (10 per question)
+
+There are two types of tasks
+- Open-ended task
+- Multiple-choice task (18 choices per question)
+
+## Requirements ##
+- python 2.7
+- scikit-image (visit [this page](http://scikit-image.org/docs/dev/install.html) for installation)
+- matplotlib (visit [this page](http://matplotlib.org/users/installing.html) for installation)
+
+## Files ##
+./Questions
+- For v2.0, download the question files from the [VQA download page](http://www.visualqa.org/download.html), extract them and place in this folder.
+- For v1.0, both real and abstract, question files can be found on the [VQA v1 download page](http://www.visualqa.org/vqa_v1_download.html).
+- Question files from Beta v0.9 release (123,287 MSCOCO train and val images, 369,861 questions, 3,698,610 answers) can be found below
+ - [training question files](http://visualqa.org/data/mscoco/prev_rel/Beta_v0.9/Questions_Train_mscoco.zip)
+ - [validation question files](http://visualqa.org/data/mscoco/prev_rel/Beta_v0.9/Questions_Val_mscoco.zip)
+- Question files from Beta v0.1 release (10k MSCOCO images, 30k questions, 300k answers) can be found [here](http://visualqa.org/data/mscoco/prev_rel/Beta_v0.1/Questions_Train_mscoco.zip).
+
+./Annotations
+- For v2.0, download the annotations files from the [VQA download page](http://www.visualqa.org/download.html), extract them and place in this folder.
+- For v1.0, for both real and abstract, annotation files can be found on the [VQA v1 download page](http://www.visualqa.org/vqa_v1_download.html).
+- Annotation files from Beta v0.9 release (123,287 MSCOCO train and val images, 369,861 questions, 3,698,610 answers) can be found below
+ - [training annotation files](http://visualqa.org/data/mscoco/prev_rel/Beta_v0.9/Annotations_Train_mscoco.zip)
+ - [validation annotation files](http://visualqa.org/data/mscoco/prev_rel/Beta_v0.9/Annotations_Val_mscoco.zip)
+- Annotation files from Beta v0.1 release (10k MSCOCO images, 30k questions, 300k answers) can be found [here](http://visualqa.org/data/mscoco/prev_rel/Beta_v0.1/Annotations_Train_mscoco.zip).
+
+./Images
+- For real, create a directory with name mscoco inside this directory. For each of train, val and test, create directories with names train2014, val2014 and test2015 respectively inside mscoco directory, download respective images from [MS COCO website](http://mscoco.org/dataset/#download) and place them in respective folders.
+- For abstract, create a directory with name abstract_v002 inside this directory. For each of train, val and test, create directories with names train2015, val2015 and test2015 respectively inside abstract_v002 directory, download respective images from [VQA download page](http://www.visualqa.org/download.html) and place them in respective folders.
+
+./PythonHelperTools
+- This directory contains the Python API to read and visualize the VQA dataset
+- vqaDemo.py (demo script)
+- vqaTools (API to read and visualize data)
+
+./PythonEvaluationTools
+- This directory contains the Python evaluation code
+- vqaEvalDemo.py (evaluation demo script)
+- vqaEvaluation (evaluation code)
+
+./Results
+- OpenEnded_mscoco_train2014_fake_results.json (an example of a fake results file for v1.0 to run the demo)
+- Visit [VQA evaluation page] (http://visualqa.org/evaluation) for more details.
+
+./QuestionTypes
+- This directory contains the following lists of question types for both real and abstract questions (question types are unchanged from v1.0 to v2.0). In a list, if there are question types of length n+k and length n with the same first n words, then the question type of length n does not include questions that belong to the question type of length n+k.
+- mscoco_question_types.txt
+- abstract_v002_question_types.txt
+
+## References ##
+- [VQA: Visual Question Answering](http://visualqa.org/)
+- [Microsoft COCO](http://mscoco.org/)
+
+## Developers ##
+- Aishwarya Agrawal (Virginia Tech)
+- Code for API is based on [MSCOCO API code](https://github.com/pdollar/coco).
+- The format of the code for evaluation is based on [MSCOCO evaluation code](https://github.com/tylin/coco-caption).
diff --git a/minigpt4/common/vqa_tools/VQA/license.txt b/minigpt4/common/vqa_tools/VQA/license.txt
new file mode 100644
index 0000000..f87c06b
--- /dev/null
+++ b/minigpt4/common/vqa_tools/VQA/license.txt
@@ -0,0 +1,30 @@
+Copyright (c) 2014, Aishwarya Agrawal
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+1. Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+2. Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND
+ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE
+FOR
+ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
+ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+The views and conclusions contained in the software and documentation are
+those
+of the authors and should not be interpreted as representing official
+policies,
+either expressed or implied, of the FreeBSD Project.
diff --git a/minigpt4/common/vqa_tools/__init__.py b/minigpt4/common/vqa_tools/__init__.py
new file mode 100644
index 0000000..9b98da8
--- /dev/null
+++ b/minigpt4/common/vqa_tools/__init__.py
@@ -0,0 +1,8 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+__author__ = "aagrawal"
diff --git a/minigpt4/common/vqa_tools/__pycache__/__init__.cpython-39.pyc b/minigpt4/common/vqa_tools/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000..dd5fda2
Binary files /dev/null and b/minigpt4/common/vqa_tools/__pycache__/__init__.cpython-39.pyc differ
diff --git a/minigpt4/common/vqa_tools/__pycache__/vqa.cpython-39.pyc b/minigpt4/common/vqa_tools/__pycache__/vqa.cpython-39.pyc
new file mode 100644
index 0000000..ac761c5
Binary files /dev/null and b/minigpt4/common/vqa_tools/__pycache__/vqa.cpython-39.pyc differ
diff --git a/minigpt4/common/vqa_tools/__pycache__/vqa_eval.cpython-39.pyc b/minigpt4/common/vqa_tools/__pycache__/vqa_eval.cpython-39.pyc
new file mode 100644
index 0000000..d12fb20
Binary files /dev/null and b/minigpt4/common/vqa_tools/__pycache__/vqa_eval.cpython-39.pyc differ
diff --git a/minigpt4/common/vqa_tools/aokvqa/LICENSE b/minigpt4/common/vqa_tools/aokvqa/LICENSE
new file mode 100644
index 0000000..663d675
--- /dev/null
+++ b/minigpt4/common/vqa_tools/aokvqa/LICENSE
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright 2022 Allen Institute for Artificial Intelligence
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/minigpt4/common/vqa_tools/aokvqa/README.md b/minigpt4/common/vqa_tools/aokvqa/README.md
new file mode 100644
index 0000000..21caefa
--- /dev/null
+++ b/minigpt4/common/vqa_tools/aokvqa/README.md
@@ -0,0 +1,207 @@
+# A-OKVQA
+
+Official repository for **A-OKVQA: A Benchmark for Visual Question Answering using World Knowledge**.
+
+Links: [[Paper]](https://arxiv.org/abs/2206.01718) [[Website]](http://a-okvqa.allenai.org) [[Leaderboard]](https://leaderboard.allenai.org/a-okvqa/submissions/public)
+
+### Abstract
+
+The Visual Question Answering (VQA) task aspires to provide a meaningful testbed for the development of AI models that can jointly reason over visual and natural language inputs. Despite a proliferation of VQA datasets, this goal is hindered by a set of common limitations. These include a reliance on relatively simplistic questions that are repetitive in both concepts and linguistic structure, little world knowledge needed outside of the paired image, and limited reasoning required to arrive at the correct answer. We introduce A-OKVQA, a crowdsourced dataset composed of a diverse set of about 25K questions requiring a broad base of commonsense and world knowledge to answer. In contrast to the existing knowledge-based VQA datasets, the questions generally cannot be answered by simply querying a knowledge base, and instead require some form of commonsense reasoning about the scene depicted in the image. We demonstrate the potential of this new dataset through a detailed analysis of its contents and baseline performance measurements over a variety of state-of-the-art vision–language models.
+
+
+
+
+
+#### Table of Contents
+
+- [Getting started](#getting-started)
+ * [Downloading the dataset](#downloading-the-dataset)
+- [Evaluation & Leaderboard](#evaluation)
+- [Codebase](#codebase)
+ * [Preparing data](#preparing-data)
+ * [Models and Predictions](#models-and-predictions)
+
+
+
+## Getting started
+
+```bash
+git clone --single-branch --recurse-submodules https://github.com/allenai/aokvqa.git
+
+cd aokvqa
+export PYTHONPATH=.
+
+conda env create --name aokvqa
+conda activate aokvqa
+```
+
+### Downloading the dataset
+
+```bash
+export AOKVQA_DIR=./datasets/aokvqa/
+mkdir -p ${AOKVQA_DIR}
+
+curl -fsSL https://prior-datasets.s3.us-east-2.amazonaws.com/aokvqa/aokvqa_v1p0.tar.gz | tar xvz -C ${AOKVQA_DIR}
+```
+
+ Downloading COCO 2017
+
+```bash
+export COCO_DIR=./datasets/coco/
+mkdir -p ${COCO_DIR}
+
+for split in train val test; do
+ wget "http://images.cocodataset.org/zips/${split}2017.zip"
+ unzip "${split}2017.zip" -d ${COCO_DIR}; rm "${split}2017.zip"
+done
+
+wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip
+unzip annotations_trainval2017.zip -d ${COCO_DIR}; rm annotations_trainval2017.zip
+```
+
+
+
+Loading our dataset is easy! Just grab our [load_aokvqa.py](https://github.com/allenai/aokvqa/blob/main/load_aokvqa.py) file and refer to the following code.
+
+```python
+import os
+aokvqa_dir = os.getenv('AOKVQA_DIR')
+
+from load_aokvqa import load_aokvqa, get_coco_path
+train_dataset = load_aokvqa(aokvqa_dir, 'train') # also 'val' or 'test'
+```
+
+ Example dataset entry
+
+```python
+dataset_example = train_dataset[0]
+
+print(dataset_example['question_id'])
+# 22MexNkBPpdZGX6sxbxVBH
+
+coco_dir = os.getenv('COCO_DIR')
+image_path = get_coco_path('train', dataset_example['image_id'], coco_dir)
+print(image_path)
+# ./datasets/coco/train2017/000000299207.jpg
+
+print(dataset_example['question'])
+print(dataset_example['choices'])
+# What is the man by the bags awaiting?
+# ['skateboarder', 'train', 'delivery', 'cab']
+
+correct_choice = dataset_example['choices'][ dataset_example['correct_choice_idx'] ]
+# Corrrect: cab
+
+print(dataset_example['rationales'][0])
+# A train would not be on the street, he would not have luggage waiting for a delivery, and the skateboarder is there and not paying attention to him so a cab is the only possible answer.
+```
+
+
+
+## Evaluation
+
+Please prepare `predictions_{split}.json` files (for `split: {val,test}`) in the format below. You may omit either `multiple_choice` or `direct_answer` field if you only want to evaluate one setting.
+
+```python
+{
+ '' : {
+ 'multiple_choice' : '',
+ 'direct_answer' : ''
+ }
+}
+```
+
+You can run evaluation on the validation set as follows.
+
+```bash
+python evaluation/eval_predictions.py --aokvqa-dir ${AOKVQA_DIR} --split val --preds ./predictions_val.json
+```
+
+### Leaderboard
+
+You may submit `predictions_test.json` to the [leaderboard](https://leaderboard.allenai.org/a-okvqa/submissions/get-started).
+
+## Codebase
+
+We provide all code and pretrained models necessary to replicate our experiments for Large-Scale Pretrained Models (sec. 5.2) and Rationale Generation (sec. 5.3).
+
+### Preparing data
+
+```bash
+export FEATURES_DIR=./features/
+mkdir -p ${FEATURES_DIR}
+```
+
+You can compute CLIP features for our vocabulary and dataset. These are most commonly used by our other experiments.
+
+```bash
+python data_scripts/encode_vocab_clip.py --vocab ${AOKVQA_DIR}/large_vocab_train.csv --model-type ViT-B/32 --out ${FEATURES_DIR}/clip-ViT-B-32_large_vocab.pt
+
+for split in train val test; do
+ python data_scripts/extract_clip_features.py --aokvqa-dir ${AOKVQA_DIR} --coco-dir ${COCO_DIR} --split ${split} --model-type ViT-B/32 --out ${FEATURES_DIR}/clip-ViT-B-32_${split}.pt
+done
+```
+
+ For training ClipCap with a transformer mapping network
+
+If you want to train our ClipCap models with the transformer mapping network (instead of an MLP, like we do), you'll also need to run `extract_clip_features.py` with `--model-type RN50x4`.
+
+
+
+ For ResNet and BERT input features
+
+Our ResNet and BERT classification experiments require these respective features instead of CLIP. To generate these, please run the following commands:
+
+```bash
+# ResNet
+for split in train val test; do
+ python data_scripts/extract_resnet_features.py --aokvqa-dir ${AOKVQA_DIR} --coco-dir ${COCO_DIR} --split ${split} --out ${FEATURES_DIR}/resnet_${split}.pt
+done
+
+# BERT
+for split in train val test; do
+ python data_scripts/extract_bert_features.py --aokvqa-dir ${AOKVQA_DIR} --split ${split} --out ${FEATURES_DIR}/bert_${split}.pt
+done
+```
+
+
+
+### Models and Predictions
+
+```bash
+export LOG_DIR=./logs/
+export PREDS_DIR=./predictions/
+export PT_MODEL_DIR=./pretrained_models/
+mkdir -p ${LOG_DIR} ${PREDS_DIR} ${PT_MODEL_DIR}
+```
+
+ Download our pretrained model weights
+
+```bash
+# Checkpoints for transfer learning experiments
+curl -fsSL https://prior-model-weights.s3.us-east-2.amazonaws.com/aokvqa/transfer_exp_checkpoints.tar.gz | tar xvz -C ${PT_MODEL_DIR}/aokvqa_models
+
+# Checkpoints for ClipCap models (generating answers and rationales)
+curl -fsSL https://prior-model-weights.s3.us-east-2.amazonaws.com/aokvqa/clipcap_checkpoints.tar.gz | tar xvz -C ${PT_MODEL_DIR}/aokvqa_models
+```
+
+
+
+We have included instructions for replicating each of our experiments (see README.md files below).
+
+All Python scripts should be run from the root of this repository. Please be sure to first run the installation and data preparation as directed above.
+
+- [Heuristics](./heuristics/README.md)
+- [Transfer Learning Experiments](./transfer_experiments/README.md)
+- [Querying GPT-3](./gpt3/README.md)
+- [ClipCap](https://github.com/allenai/aokvqa/blob/ClipCap/README.md)
+- [Generating Captions & Rationales](https://github.com/allenai/aokvqa/blob/ClipCap/README.md)
+
+For each experiment, we follow this prediction file naming scheme: `{model-name}_{split}-{setting}.json` (e.g. `random-weighted_val-mc.json` or `random-weighted_test-da.json`). As examples in these Readme files, we produce predictions on the validation set.
+
+We unify predictions for each split before evaluation. (You can omit one of `--mc` or `--da` prediction file if you only want to evaluate one setting.)
+
+```bash
+python evaluation/prepare_predictions.py --aokvqa-dir ${AOKVQA_DIR} --split val --mc ./predictions_val-mc.json --da ./predictions_val-da.json --out ./predictions_val.json
+# repeat for test split ...
+```
diff --git a/minigpt4/common/vqa_tools/aokvqa/data_scripts/build_vocab.py b/minigpt4/common/vqa_tools/aokvqa/data_scripts/build_vocab.py
new file mode 100644
index 0000000..2c44686
--- /dev/null
+++ b/minigpt4/common/vqa_tools/aokvqa/data_scripts/build_vocab.py
@@ -0,0 +1,45 @@
+import os
+import argparse
+from collections import Counter
+import pathlib
+
+from load_aokvqa import load_aokvqa
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir')
+parser.add_argument('--out', type=pathlib.Path, required=True, dest='output_file')
+args = parser.parse_args()
+
+
+# Build vocab from train set: correct choices + (direct answers appearing in >= 3 )
+
+train_set = load_aokvqa(args.aokvqa_dir, 'train')
+
+vocab = []
+all_choices = Counter()
+direct_answers = Counter()
+
+for i in train_set:
+ vocab.append( i['choices'][i['correct_choice_idx']] )
+ all_choices.update(i['choices'])
+ direct_answers.update(set(i['direct_answers']))
+vocab += [k for k,v in all_choices.items() if v >= 3]
+vocab += [k for k,v in direct_answers.items() if v >= 3]
+
+vocab = sorted(set(vocab))
+print(f"Vocab size: {len(vocab)}")
+
+# Save vocabulary Output
+
+with open(args.output_file, 'w') as f:
+ for v in vocab:
+ print(v, file=f)
+
+## Check validation set coverage
+
+val_set = load_aokvqa(args.aokvqa_dir, 'val')
+
+val_acc = [v['choices'][v['correct_choice_idx']] in vocab for v in val_set]
+val_acc = sum(val_acc) / len(val_acc) * 100
+print(f"Val set coverage: {val_acc:.2f}" )
diff --git a/minigpt4/common/vqa_tools/aokvqa/data_scripts/encode_vocab_clip.py b/minigpt4/common/vqa_tools/aokvqa/data_scripts/encode_vocab_clip.py
new file mode 100644
index 0000000..1dce760
--- /dev/null
+++ b/minigpt4/common/vqa_tools/aokvqa/data_scripts/encode_vocab_clip.py
@@ -0,0 +1,26 @@
+import json
+from tqdm import tqdm
+import argparse
+import pathlib
+
+import torch
+import clip
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--vocab', type=pathlib.Path, required=True, dest='vocab_file')
+parser.add_argument('--model-type', type=str, choices=['RN50', 'RN50x4', 'RN50x16', 'RN50x64', 'RN101', 'ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px'], required=True, dest='model_type')
+parser.add_argument('--out', type=pathlib.Path, required=True, dest='output_file')
+args = parser.parse_args()
+
+assert args.output_file.suffix == '.pt'
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+model, preprocess = clip.load(args.model_type, device=device)
+
+with torch.no_grad():
+ a = open(args.vocab_file).read().splitlines()
+ mc_text = clip.tokenize(a).to(device)
+ mc_text_features = torch.stack([model.encode_text(mct.unsqueeze(0)).cpu() for mct in tqdm(mc_text)], dim=1)[0]
+ mc_text_features = mc_text_features.float()
+ model_name = args.model_type.replace('/', '-').replace('@', '-')
+ torch.save(mc_text_features, args.output_file)
diff --git a/minigpt4/common/vqa_tools/aokvqa/data_scripts/extract_bert_features.py b/minigpt4/common/vqa_tools/aokvqa/data_scripts/extract_bert_features.py
new file mode 100644
index 0000000..60cd40f
--- /dev/null
+++ b/minigpt4/common/vqa_tools/aokvqa/data_scripts/extract_bert_features.py
@@ -0,0 +1,50 @@
+import os
+import argparse
+import pathlib
+from tqdm import tqdm
+
+import torch
+from transformers import AutoTokenizer, AutoModel
+
+from load_aokvqa import load_aokvqa
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir')
+parser.add_argument('--split', type=str, choices=['train', 'val', 'test'], required=True)
+parser.add_argument('--out', type=pathlib.Path, required=True, dest='output_file')
+args = parser.parse_args()
+
+assert args.output_file.suffix == '.pt'
+
+## Load dataset
+
+dataset = load_aokvqa(args.aokvqa_dir, args.split)
+
+## Load model
+
+tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/bert-base-nli-mean-tokens')
+model = AutoModel.from_pretrained('sentence-transformers/bert-base-nli-mean-tokens')
+device = "cuda" if torch.cuda.is_available() else "cpu"
+model = model.to(device)
+model.eval()
+
+def mean_pooling(model_output, attention_mask):
+ token_embeddings = model_output[0] # First element of model_output contains all token embeddings
+ input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+ return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
+
+## Encoding loop
+
+with torch.no_grad():
+ embeddings = {}
+
+ for d in tqdm(dataset):
+ encoded_input = tokenizer([d['question']], padding=True, return_tensors='pt')
+ encoded_input = {k:v.to(device) for k,v in encoded_input.items()}
+ e = mean_pooling(model(**encoded_input), encoded_input['attention_mask'])
+ embeddings[d['question_id']] = {
+ 'question' : e[0].cpu()
+ }
+
+ torch.save(embeddings, args.output_file)
diff --git a/minigpt4/common/vqa_tools/aokvqa/data_scripts/extract_clip_features.py b/minigpt4/common/vqa_tools/aokvqa/data_scripts/extract_clip_features.py
new file mode 100644
index 0000000..20d0455
--- /dev/null
+++ b/minigpt4/common/vqa_tools/aokvqa/data_scripts/extract_clip_features.py
@@ -0,0 +1,51 @@
+import os
+from PIL import Image
+from tqdm import tqdm
+import argparse
+import pathlib
+
+import torch
+import clip
+
+from load_aokvqa import load_aokvqa, get_coco_path
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir')
+parser.add_argument('--coco-dir', type=pathlib.Path, required=True, dest='coco_dir')
+parser.add_argument('--split', type=str, choices=['train', 'val', 'test'], required=True)
+parser.add_argument('--model-type', type=str, choices=['RN50', 'RN50x4', 'RN50x16', 'RN50x64', 'RN101', 'ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px'], required=True, dest='model_type')
+parser.add_argument('--out', type=pathlib.Path, required=True, dest='output_file')
+args = parser.parse_args()
+
+assert args.output_file.suffix == '.pt'
+
+## Load dataset
+
+dataset = load_aokvqa(args.aokvqa_dir, args.split)
+
+## Load model
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+model, preprocess = clip.load(args.model_type, device=device)
+
+## Encoding loop
+
+with torch.no_grad():
+ embeddings = {}
+
+ for d in tqdm(dataset):
+ q = d["question"]
+ q_text = clip.tokenize(q).to(device)
+ q_text_features = model.encode_text(q_text)
+
+ img = Image.open(get_coco_path(args.split, d['image_id'], args.coco_dir))
+ img = preprocess(img).unsqueeze(0).to(device)
+ image_features = model.encode_image(img)
+
+ embeddings[d['question_id']] = {
+ 'question' : q_text_features[0].float().cpu(),
+ 'image' : image_features[0].float().cpu(),
+ }
+
+ torch.save(embeddings, args.output_file)
diff --git a/minigpt4/common/vqa_tools/aokvqa/data_scripts/extract_resnet_features.py b/minigpt4/common/vqa_tools/aokvqa/data_scripts/extract_resnet_features.py
new file mode 100644
index 0000000..0d7277b
--- /dev/null
+++ b/minigpt4/common/vqa_tools/aokvqa/data_scripts/extract_resnet_features.py
@@ -0,0 +1,62 @@
+import os
+import argparse
+import pathlib
+from tqdm import tqdm
+from PIL import Image
+
+import torch
+import torch.nn as nn
+from torchvision import models
+from torchvision import transforms as T
+
+from load_aokvqa import load_aokvqa, get_coco_path
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir')
+parser.add_argument('--coco-dir', type=pathlib.Path, required=True, dest='coco_dir')
+parser.add_argument('--split', type=str, choices=['train', 'val', 'test'], required=True)
+parser.add_argument('--out', type=pathlib.Path, required=True, dest='output_file')
+args = parser.parse_args()
+
+assert args.output_file.suffix == '.pt'
+
+## Load dataset
+
+dataset = load_aokvqa(args.aokvqa_dir, args.split)
+
+## Load model
+
+resnet_preprocess = T.Compose([
+ T.Resize(size=224, interpolation=T.InterpolationMode.BICUBIC),
+ T.CenterCrop(size=(224, 224)),
+ T.ToTensor(),
+ T.Normalize(
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225]
+ )
+])
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+resnet_model = models.resnet50(pretrained=True)
+resnet_model = torch.nn.Sequential(
+ *list(resnet_model.children())[:-1],
+ nn.Flatten()
+) # strip classification layer
+resnet_model = resnet_model.to(device)
+
+## Encoding loop
+
+with torch.no_grad():
+ embeddings = {}
+
+ for d in tqdm(dataset):
+ img = Image.open(get_coco_path(args.split, d['image_id'], args.coco_dir)).convert('RGB')
+ resnet_input = resnet_preprocess(img).unsqueeze(0).to(device)
+ resnet_features = resnet_model(resnet_input)
+ embeddings[d['question_id']] = {
+ 'image' : resnet_features[0].cpu()
+ }
+
+ torch.save(embeddings, args.output_file)
diff --git a/minigpt4/common/vqa_tools/aokvqa/environment.yml b/minigpt4/common/vqa_tools/aokvqa/environment.yml
new file mode 100644
index 0000000..58284ec
--- /dev/null
+++ b/minigpt4/common/vqa_tools/aokvqa/environment.yml
@@ -0,0 +1,36 @@
+name: aokvqa
+channels:
+ - pytorch
+ - nvidia
+ - huggingface
+ - conda-forge
+ - defaults
+dependencies:
+ - python=3.7
+ - cudatoolkit=11.3
+ - numpy=1.21.6
+ - pytorch=1.11.0
+ - torchvision=0.12.0
+ - pytorch-lightning=1.6.3
+ - torchmetrics=0.8.1
+ - gdown=4.4.0
+ - pip=22.0.4
+ - pip:
+ - argparse==1.4.0
+ - Pillow==9.0.1
+ - tensorboard==2.9.0
+ - ftfy==6.1.1
+ - regex==2022.3.15
+ - tqdm==4.64.0
+ - clip @ git+https://github.com/openai/CLIP.git@b46f5ac7587d2e1862f8b7b1573179d80dcdd620
+ - openai==0.18.1
+ - nltk==3.7
+ - sacrebleu==2.0.0
+ - sacremoses==0.0.53
+ - sentence-transformers==2.2.0
+ - datasets==2.1.0
+ - tokenizers==0.10.3
+ - transformers==4.10.3
+
+# Next: resolve conflict between sentence-transfomers and pytorch-lightning
+# pip uninstall sentencepiece
diff --git a/minigpt4/common/vqa_tools/aokvqa/evaluation/__pycache__/load_aokvqa.cpython-39.pyc b/minigpt4/common/vqa_tools/aokvqa/evaluation/__pycache__/load_aokvqa.cpython-39.pyc
new file mode 100644
index 0000000..9e2afd5
Binary files /dev/null and b/minigpt4/common/vqa_tools/aokvqa/evaluation/__pycache__/load_aokvqa.cpython-39.pyc differ
diff --git a/minigpt4/common/vqa_tools/aokvqa/evaluation/eval_predictions.py b/minigpt4/common/vqa_tools/aokvqa/evaluation/eval_predictions.py
new file mode 100644
index 0000000..a7b5dbe
--- /dev/null
+++ b/minigpt4/common/vqa_tools/aokvqa/evaluation/eval_predictions.py
@@ -0,0 +1,97 @@
+import argparse
+import pathlib
+import json
+import glob
+
+from load_aokvqa import load_aokvqa
+
+
+def eval_aokvqa(dataset, preds, multiple_choice=False, strict=True):
+
+ if isinstance(dataset, list):
+ dataset = { dataset[i]['question_id'] : dataset[i] for i in range(len(dataset)) }
+
+ if multiple_choice is False:
+ dataset = {k:v for k,v in dataset.items() if v['difficult_direct_answer'] is False}
+
+ if strict:
+ dataset_qids = set(dataset.keys())
+ preds_qids = set(preds.keys())
+ assert dataset_qids.issubset(preds_qids)
+
+ # dataset = q_id (str) : dataset element (dict)
+ # preds = q_id (str) : prediction (str)
+
+ acc = []
+
+ for q in dataset.keys():
+ if q not in preds.keys():
+ acc.append(0.0)
+ continue
+
+ pred = preds[q]
+ choices = dataset[q]['choices']
+ direct_answers = dataset[q]['direct_answers']
+
+ ## Multiple Choice setting
+ if multiple_choice:
+ if strict:
+ assert pred in choices, 'Prediction must be a valid choice'
+ correct_choice_idx = dataset[q]['correct_choice_idx']
+ acc.append( float(pred == choices[correct_choice_idx]) )
+ ## Direct Answer setting
+ else:
+ num_match = sum([pred.lower() == da.lower() for da in direct_answers])
+ vqa_acc = min(1.0, num_match / 3.0)
+ acc.append(vqa_acc)
+
+ acc = sum(acc) / len(acc) * 100
+
+ return acc
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir')
+ parser.add_argument('--split', type=str, choices=['train', 'val', 'test'], required=True)
+ parser.add_argument('--preds', type=str, required=True, dest='prediction_files')
+ args = parser.parse_args()
+
+ dataset = load_aokvqa(args.aokvqa_dir, args.split)
+
+ for prediction_file in glob.glob(args.prediction_files):
+ predictions = json.load(open(prediction_file, 'r'))
+
+ # Multiple choice
+
+ mc_predictions = {}
+
+ for q in predictions.keys():
+ if 'multiple_choice' in predictions[q].keys():
+ mc_predictions[q] = predictions[q]['multiple_choice']
+
+ if mc_predictions != {}:
+ mc_acc = eval_aokvqa(
+ dataset,
+ mc_predictions,
+ multiple_choice=True,
+ strict=False
+ )
+ print(prediction_file, 'MC', mc_acc)
+
+ # Direct Answer
+
+ da_predictions = {}
+
+ for q in predictions.keys():
+ if 'direct_answer' in predictions[q].keys():
+ da_predictions[q] = predictions[q]['direct_answer']
+
+ if da_predictions != {}:
+ da_acc = eval_aokvqa(
+ dataset,
+ da_predictions,
+ multiple_choice=False,
+ strict=False
+ )
+ print(prediction_file, 'DA', da_acc)
diff --git a/minigpt4/common/vqa_tools/aokvqa/evaluation/load_aokvqa.py b/minigpt4/common/vqa_tools/aokvqa/evaluation/load_aokvqa.py
new file mode 100644
index 0000000..3e3dd49
--- /dev/null
+++ b/minigpt4/common/vqa_tools/aokvqa/evaluation/load_aokvqa.py
@@ -0,0 +1,13 @@
+import os
+import json
+
+
+def load_aokvqa(aokvqa_dir, split, version='v1p0'):
+ assert split in ['train', 'val', 'test', 'test_w_ans']
+ dataset = json.load(open(
+ os.path.join(aokvqa_dir, f"aokvqa_{version}_{split}.json")
+ ))
+ return dataset
+
+def get_coco_path(split, image_id, coco_dir):
+ return os.path.join(coco_dir, f"{split}2017", f"{image_id:012}.jpg")
diff --git a/minigpt4/common/vqa_tools/aokvqa/evaluation/prepare_predictions.py b/minigpt4/common/vqa_tools/aokvqa/evaluation/prepare_predictions.py
new file mode 100644
index 0000000..202f00c
--- /dev/null
+++ b/minigpt4/common/vqa_tools/aokvqa/evaluation/prepare_predictions.py
@@ -0,0 +1,31 @@
+import argparse
+import pathlib
+import json
+
+from load_aokvqa import load_aokvqa
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir')
+ parser.add_argument('--split', type=str, choices=['train', 'val', 'test'], required=True)
+ parser.add_argument('--mc', type=argparse.FileType('r'), dest='mc_pred_file')
+ parser.add_argument('--da', type=argparse.FileType('r'), dest='da_pred_file')
+ parser.add_argument('--out', type=argparse.FileType('w'), dest='output_file')
+ args = parser.parse_args()
+ assert args.mc_pred_file or args.da_pred_file
+
+ dataset = load_aokvqa(args.aokvqa_dir, args.split)
+ mc_preds = json.load(args.mc_pred_file) if args.mc_pred_file else None
+ da_preds = json.load(args.da_pred_file) if args.da_pred_file else None
+ predictions = {}
+
+ for d in dataset:
+ q = d['question_id']
+ predictions[q] = {}
+ if mc_preds and q in mc_preds.keys():
+ predictions[q]['multiple_choice'] = mc_preds[q]
+ if da_preds and q in da_preds.keys():
+ predictions[q]['direct_answer'] = da_preds[q]
+
+ json.dump(predictions, args.output_file)
diff --git a/minigpt4/common/vqa_tools/aokvqa/evaluation/remap_predictions.py b/minigpt4/common/vqa_tools/aokvqa/evaluation/remap_predictions.py
new file mode 100644
index 0000000..40ba155
--- /dev/null
+++ b/minigpt4/common/vqa_tools/aokvqa/evaluation/remap_predictions.py
@@ -0,0 +1,44 @@
+import argparse
+import pathlib
+import json
+from tqdm import tqdm
+
+from sentence_transformers import SentenceTransformer
+from sentence_transformers.util import cos_sim
+
+from load_aokvqa import load_aokvqa
+
+
+def map_to_choices(dataset, predictions, device='cpu'):
+ if isinstance(dataset, list):
+ dataset = { dataset[i]['question_id'] : dataset[i] for i in range(len(dataset)) }
+
+ if all([p in dataset[q]['choices'] for q, p in predictions.items()]):
+ return predictions
+
+ model = SentenceTransformer('sentence-transformers/average_word_embeddings_glove.6B.300d')
+ model.to(device)
+ for q in tqdm(predictions.keys()):
+ choices = dataset[q]['choices']
+ if predictions[q] not in choices:
+ choice_embeddings = model.encode([predictions[q]] + choices, convert_to_tensor=True)
+ a_idx = cos_sim(choice_embeddings[0], choice_embeddings[1:]).argmax().item()
+ predictions[q] = choices[a_idx]
+
+ return predictions
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir')
+ parser.add_argument('--split', type=str, choices=['train', 'val', 'test'], required=True)
+ parser.add_argument('--pred', type=argparse.FileType('r'), required=True, dest='prediction_file')
+ parser.add_argument('--out', type=argparse.FileType('w'), required=True, dest='output_file')
+ args = parser.parse_args()
+
+
+ dataset = load_aokvqa(args.aokvqa_dir, args.split)
+ predictions = json.load(args.prediction_file)
+ predictions = map_to_choices(dataset, predictions)
+
+ json.dump(predictions, args.output_file)
diff --git a/minigpt4/common/vqa_tools/aokvqa/gpt3/README.md b/minigpt4/common/vqa_tools/aokvqa/gpt3/README.md
new file mode 100644
index 0000000..fc1fd6b
--- /dev/null
+++ b/minigpt4/common/vqa_tools/aokvqa/gpt3/README.md
@@ -0,0 +1,14 @@
+## Querying GPT-3
+
+To follow our experiments which use GPT-3, you must have access to the [OpenAI API](https://openai.com/api/) (at cost). Please retrieve your [organization](https://beta.openai.com/account/org-settings) and [API](https://beta.openai.com/account/api-keys) keys and set them in your environment variables.
+
+```bash
+export OPENAI_ORG=....
+export OPENAI_API_KEY=...
+```
+
+For producing predictions for both DA and MC settings, run:
+```bash
+python gpt3/query_gpt3.py --aokvqa-dir ${AOKVQA_DIR} --split val --out ${PREDS_DIR}/gpt3_val-da.json
+python remap_predictions.py --aokvqa-dir ${AOKVQA_DIR} --split val --pred ${PREDS_DIR}/gpt3_val-da.json --out ${PREDS_DIR}/gpt3_val-mc.json
+```
diff --git a/minigpt4/common/vqa_tools/aokvqa/gpt3/caption_inputs.py b/minigpt4/common/vqa_tools/aokvqa/gpt3/caption_inputs.py
new file mode 100644
index 0000000..2117434
--- /dev/null
+++ b/minigpt4/common/vqa_tools/aokvqa/gpt3/caption_inputs.py
@@ -0,0 +1,23 @@
+import os
+import json
+import argparse
+import pathlib
+
+from load_aokvqa import load_aokvqa
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir')
+parser.add_argument('--coco-dir', type=pathlib.Path, required=True, dest='coco_dir')
+parser.add_argument('--split', type=str, choices=['train', 'val'], required=True)
+parser.add_argument('--out', type=argparse.FileType('w'), required=True, dest='output_file')
+args = parser.parse_args()
+
+aokvqa_set = load_aokvqa(args.aokvqa_dir, args.split)
+
+coco_captions = json.load(open(os.path.join(args.coco_dir, 'annotations', f'captions_{args.split}2017.json')))['annotations']
+coco_captions = {c['image_id'] : c['caption'] for c in coco_captions}
+
+captions = { d['question_id'] : coco_captions[d['image_id']] for d in aokvqa_set }
+
+json.dump(captions, args.output_file)
diff --git a/minigpt4/common/vqa_tools/aokvqa/gpt3/query_gpt3.py b/minigpt4/common/vqa_tools/aokvqa/gpt3/query_gpt3.py
new file mode 100644
index 0000000..4a08900
--- /dev/null
+++ b/minigpt4/common/vqa_tools/aokvqa/gpt3/query_gpt3.py
@@ -0,0 +1,79 @@
+import os
+import random
+import json
+from tqdm import tqdm
+import argparse
+import pathlib
+
+import openai
+openai.organization = os.getenv('OPENAI_ORG')
+openai.api_key = os.getenv('OPENAI_API_KEY')
+
+from load_aokvqa import load_aokvqa
+
+
+random.seed(0)
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir')
+ parser.add_argument('--split', type=str, choices=['train', 'val', 'test'], required=True)
+ parser.add_argument('--n', type=int, default=10, dest='num_examples')
+ parser.add_argument('--train-context', type=argparse.FileType('r'), dest='train_context_file')
+ parser.add_argument('--prefix', type=str, default='', dest='prompt_prefix')
+ parser.add_argument('--include-choices', action='store_true', dest='include_choices')
+ parser.add_argument('--context', type=argparse.FileType('r'), dest='context_file')
+ parser.add_argument('--out', type=argparse.FileType('w'), required=True, dest='output_file')
+ args = parser.parse_args()
+
+
+ train_set = load_aokvqa(args.aokvqa_dir, 'train')
+ eval_set = load_aokvqa(args.aokvqa_dir, args.split)
+
+ train_context = {}
+ context = {}
+ if args.context_file is not None:
+ train_context = json.load(args.train_context_file)
+ context = json.load(args.context_file)
+
+ predictions = {}
+
+ for d in tqdm(eval_set):
+ q = d['question_id']
+
+ prompt = args.prompt_prefix
+ for e in random.sample(train_set, args.num_examples):
+ prompt += prompt_element(e,
+ context=train_context.get(q, None),
+ include_choices=args.include_choices,
+ answer=True
+ )
+ prompt += '\n\n'
+
+ prompt += prompt_element(d,
+ context=context.get(q, None),
+ include_choices=args.include_choices,
+ answer=False
+ )
+
+ response = openai.Completion.create(
+ engine="text-curie-001",
+ prompt=prompt,
+ temperature=0.0,
+ max_tokens=10,
+ )
+
+ predictions[q] = response.choices[0].text.strip()
+
+ json.dump(predictions, args.output_file)
+
+
+def prompt_element(d, context=None, include_choices=False, answer=False):
+ return (f"Context: {context}\n" if context is not None else '') + \
+ f"Q: {d['question']}\n" + \
+ (f"Choices: {', '.join(d['choices'])}.\n" if include_choices else '') + \
+ f"A:" + (f" {d['choices'][d['correct_choice_idx']]}" if answer else '')
+
+if __name__ == '__main__':
+ main()
diff --git a/minigpt4/common/vqa_tools/aokvqa/gpt3/rationale_inputs.py b/minigpt4/common/vqa_tools/aokvqa/gpt3/rationale_inputs.py
new file mode 100644
index 0000000..411d1ee
--- /dev/null
+++ b/minigpt4/common/vqa_tools/aokvqa/gpt3/rationale_inputs.py
@@ -0,0 +1,16 @@
+import json
+import argparse
+import pathlib
+
+from load_aokvqa import load_aokvqa
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir')
+parser.add_argument('--split', type=str, choices=['train', 'val', 'test_w_ans'], required=True)
+parser.add_argument('--out', type=argparse.FileType('w'), required=True, dest='output_file')
+args = parser.parse_args()
+
+aokvqa_set = load_aokvqa(args.aokvqa_dir, args.split)
+rationales = {d['question_id'] : d['rationales'][0] for d in aokvqa_set}
+json.dump(rationales, args.output_file)
diff --git a/minigpt4/common/vqa_tools/aokvqa/heuristics/README.md b/minigpt4/common/vqa_tools/aokvqa/heuristics/README.md
new file mode 100644
index 0000000..67c8632
--- /dev/null
+++ b/minigpt4/common/vqa_tools/aokvqa/heuristics/README.md
@@ -0,0 +1,11 @@
+## Heuristics
+
+```bash
+# These scripts accept the same arguments.
+# heuristics/random_unweighted.py
+# heuristics/random_weighted.py
+# heuristics/most_common_answer.py
+
+python heuristics/random_unweighted.py --aokvqa-dir ${AOKVQA_DIR} --split val --mc --out ${PREDS_DIR}/random-unweighted_val-mc.json
+# Exclude --mc for the direct answer setting
+```
diff --git a/minigpt4/common/vqa_tools/aokvqa/heuristics/most_common_answer.py b/minigpt4/common/vqa_tools/aokvqa/heuristics/most_common_answer.py
new file mode 100644
index 0000000..59a27bc
--- /dev/null
+++ b/minigpt4/common/vqa_tools/aokvqa/heuristics/most_common_answer.py
@@ -0,0 +1,39 @@
+import os
+import json
+import argparse
+import pathlib
+from collections import Counter
+
+from load_aokvqa import load_aokvqa
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir')
+parser.add_argument('--split', type=str, choices=['train', 'val', 'test'], required=True)
+parser.add_argument('--mc', action='store_true', dest='multiple_choice')
+parser.add_argument('--out', type=argparse.FileType('w'), required=True, dest='output_file')
+args = parser.parse_args()
+
+
+train_set = load_aokvqa(args.aokvqa_dir, 'train')
+train_freq = dict(Counter(
+ [d['choices'][d['correct_choice_idx']] for d in train_set]
+))
+most_common_answer = max(train_freq.keys(), key=train_freq.get)
+
+##
+
+eval_set = load_aokvqa(args.aokvqa_dir, args.split)
+
+predictions = {}
+
+for d in eval_set:
+ q = d['question_id']
+ predictions[q] = most_common_answer
+
+ if args.multiple_choice:
+ choices = [c for c in d['choices'] if c in train_freq.keys()]
+ if len(choices) > 0:
+ predictions[q] = max(choices, key=train_freq.get)
+
+json.dump(predictions, args.output_file)
diff --git a/minigpt4/common/vqa_tools/aokvqa/heuristics/random_unweighted.py b/minigpt4/common/vqa_tools/aokvqa/heuristics/random_unweighted.py
new file mode 100644
index 0000000..cfcf900
--- /dev/null
+++ b/minigpt4/common/vqa_tools/aokvqa/heuristics/random_unweighted.py
@@ -0,0 +1,38 @@
+import os
+import json
+from random import seed, sample
+import argparse
+import pathlib
+
+from load_aokvqa import load_aokvqa
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir')
+parser.add_argument('--split', type=str, choices=['train', 'val', 'test'], required=True)
+parser.add_argument('--mc', action='store_true', dest='multiple_choice')
+parser.add_argument('--out', type=argparse.FileType('w'), required=True, dest='output_file')
+args = parser.parse_args()
+
+seed(0)
+
+train_set = load_aokvqa(args.aokvqa_dir, 'train')
+
+if args.multiple_choice is False:
+ choices = list(set(
+ [d['choices'][d['correct_choice_idx']] for d in train_set]
+ ))
+
+##
+
+predictions = {}
+
+eval_set = load_aokvqa(args.aokvqa_dir, args.split)
+
+for d in eval_set:
+ q = d['question_id']
+ if args.multiple_choice:
+ choices = d['choices']
+ predictions[q] = sample(choices, 1)[0]
+
+json.dump(predictions, args.output_file)
diff --git a/minigpt4/common/vqa_tools/aokvqa/heuristics/random_weighted.py b/minigpt4/common/vqa_tools/aokvqa/heuristics/random_weighted.py
new file mode 100644
index 0000000..2ccfa61
--- /dev/null
+++ b/minigpt4/common/vqa_tools/aokvqa/heuristics/random_weighted.py
@@ -0,0 +1,46 @@
+import os
+import json
+import numpy as np
+import argparse
+import pathlib
+from collections import Counter
+
+from load_aokvqa import load_aokvqa
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir')
+parser.add_argument('--split', type=str, choices=['train', 'val', 'test'], required=True)
+parser.add_argument('--mc', action='store_true', dest='multiple_choice')
+parser.add_argument('--out', type=argparse.FileType('w'), required=True, dest='output_file')
+args = parser.parse_args()
+
+np.random.seed(0)
+
+train_set = load_aokvqa(args.aokvqa_dir, 'train')
+train_freq = dict(Counter(
+ [d['choices'][d['correct_choice_idx']] for d in train_set]
+))
+
+if args.multiple_choice is False:
+ choices = list(train_freq.keys())
+ probs = [f / len(train_set) for f in train_freq.values()]
+
+##
+
+predictions = {}
+
+eval_set = load_aokvqa(args.aokvqa_dir, args.split)
+
+for d in eval_set:
+ if args.multiple_choice:
+ choices = d['choices']
+ probs = [train_freq.get(c, 0) for c in choices]
+ if probs == [0, 0, 0, 0]:
+ probs = [1, 1, 1, 1]
+ probs = [p / sum(probs) for p in probs]
+
+ q = d['question_id']
+ predictions[q] = np.random.choice(choices, size=1, p=probs)[0]
+
+json.dump(predictions, args.output_file)
diff --git a/minigpt4/common/vqa_tools/aokvqa/load_aokvqa.py b/minigpt4/common/vqa_tools/aokvqa/load_aokvqa.py
new file mode 100644
index 0000000..3e3dd49
--- /dev/null
+++ b/minigpt4/common/vqa_tools/aokvqa/load_aokvqa.py
@@ -0,0 +1,13 @@
+import os
+import json
+
+
+def load_aokvqa(aokvqa_dir, split, version='v1p0'):
+ assert split in ['train', 'val', 'test', 'test_w_ans']
+ dataset = json.load(open(
+ os.path.join(aokvqa_dir, f"aokvqa_{version}_{split}.json")
+ ))
+ return dataset
+
+def get_coco_path(split, image_id, coco_dir):
+ return os.path.join(coco_dir, f"{split}2017", f"{image_id:012}.jpg")
diff --git a/minigpt4/common/vqa_tools/aokvqa/transfer_experiments/README.md b/minigpt4/common/vqa_tools/aokvqa/transfer_experiments/README.md
new file mode 100644
index 0000000..dc5138d
--- /dev/null
+++ b/minigpt4/common/vqa_tools/aokvqa/transfer_experiments/README.md
@@ -0,0 +1,41 @@
+## Transfer Learning Experiments
+
+We use the following training/prediction scripts for the classifier, zero-shot, and contrastive experiments in Table 3.
+
+```bash
+## Training
+python transfer_experiments/train.py --aokvqa-dir ${AOKVQA_DIR} --vocab ${AOKVQA_DIR}/large_vocab_train.csv --log-dir ${LOG_DIR}
+
+--backbone clip --clip-model-type ViT-B/32 --train-features ${FEATURES_DIR}/clip-ViT-B-32_train.pt --val-features ${FEATURES_DIR}/clip-ViT-B-32_val.pt
+--inputs question # OR --inputs image # OR --inputs question image
+# OR
+--backbone resnet --train-features ${FEATURES_DIR}/resnet_train.pt --val-features ${FEATURES_DIR}/resnet_val.pt --inputs image
+# OR
+--backbone bert --train-features ${FEATURES_DIR}/bert_train.pt --val-features ${FEATURES_DIR}/bert_val.pt --inputs question
+
+--objective classifier
+# OR
+--objective contrastive --vocab-features ${FEATURE_DIR}/clip-ViT-B-32_large_vocab.pt
+```
+
+You can make predictions for CLIP zero-shot or from a classifier/contrastive checkpoint trained above.
+
+```bash
+## Predicting
+python transfer_experiments/predict.py --aokvqa-dir ${AOKVQA_DIR} --out ${PREDS_DIR}/clip-classifier_val-mc.json
+
+--split val # or test
+--features ${FEATURE_DIR}/clip-ViT-B-32_val.pt # adjust for backbone and eval split
+
+--ckpt path/to/model.ckpt
+# OR
+--zero-shot --clip-model-type ViT-B/32
+--inputs question # OR --inputs image # OR --inputs question image
+
+--mc # Multiple-choice. Exclude for direct-answer.
+
+# IF classifier OR direct-answer
+--vocab ${AOKVQA_DIR}/large_vocab_train.csv
+# IF contrastive/zero-shot AND direct-answer
+--vocab-features ${FEATURES_DIR}/clip-ViT-B-32_large_vocab.pt
+```
diff --git a/minigpt4/common/vqa_tools/aokvqa/transfer_experiments/predict.py b/minigpt4/common/vqa_tools/aokvqa/transfer_experiments/predict.py
new file mode 100644
index 0000000..d2fbb42
--- /dev/null
+++ b/minigpt4/common/vqa_tools/aokvqa/transfer_experiments/predict.py
@@ -0,0 +1,126 @@
+import sys
+import os
+import argparse
+import pathlib
+from tqdm import tqdm
+import json
+
+import torch
+import torch.nn as nn
+
+# https://github.com/PyTorchLightning/pytorch-lightning/issues/11663
+import sentencepiece; import pytorch_lightning as pl; import clip
+
+from transfer_experiments.train import LinearClassifier
+from load_aokvqa import load_aokvqa
+from evaluation.remap_predictions import map_to_choices
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--split', type=str, choices=['train', 'val', 'test'], required=True)
+parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir')
+parser.add_argument('--features', type=pathlib.Path, required=True)
+parser.add_argument('--out', type=argparse.FileType('w'), dest='output_file')
+#
+parser_weights = parser.add_mutually_exclusive_group(required=True)
+
+parser_weights.add_argument('--ckpt', type=pathlib.Path, dest='checkpoint_path')
+
+parser_weights.add_argument('--zero-shot', action='store_true', dest='clip_zero_shot')
+parser.add_argument('--inputs', nargs='+', type=str, choices=['question', 'image'], required=('--zero-shot' in sys.argv))
+#
+parser.add_argument('--vocab', type=argparse.FileType('r'))
+parser.add_argument('--vocab-features', type=pathlib.Path, dest='vocab_features')
+parser.add_argument('--mc', action='store_true', dest='multiple_choice')
+
+parser.add_argument('--clip-model-type', type=str,
+ choices=['RN50', 'RN50x4', 'RN50x16', 'RN50x64', 'RN101', 'ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px'],
+ dest='clip_model_type', required=('--zero-shot' in sys.argv and '--mc' in sys.argv))
+#
+args = parser.parse_args()
+
+
+## Load dataset
+
+aokvqa_set = load_aokvqa(args.aokvqa_dir, args.split)
+
+## Load models
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+if args.checkpoint_path is not None:
+ classifier = LinearClassifier.load_from_checkpoint(args.checkpoint_path)
+ classifier.to(device)
+ hp = classifier.hparams
+elif args.clip_zero_shot:
+ classifier = nn.Identity().to(device)
+ hp = pl.utilities.AttributeDict(backbone='clip', clip_model_type=args.clip_model_type, objective='zero-shot', inputs=args.inputs)
+
+# Load input features
+
+embeddings = torch.load(args.features)
+if hp.backbone == 'clip':
+ for q in embeddings.keys():
+ embeddings[q]['question'] = embeddings[q]['question'] / embeddings[q]['question'].norm(dim=-1, keepdim=True)
+ embeddings[q]['image'] = embeddings[q]['image'] / embeddings[q]['image'].norm(dim=-1, keepdim=True)
+
+# Load vocab, vocab features, clip
+
+if (hp.objective == 'classifier') or \
+ (hp.objective in ['contrastive', 'zero-shot'] and args.multiple_choice is False):
+ vocab = args.vocab.read().splitlines()
+
+if hp.objective in ['contrastive', 'zero-shot']:
+ if args.multiple_choice is False:
+ vocab_features = torch.load(args.vocab_features).cpu()
+ vocab_features /= vocab_features.norm(dim=-1, keepdim=True)
+ else:
+ clip_model = clip.load(hp.clip_model_type, device=device)[0]
+ logit_scale = clip_model.logit_scale.exp().cpu()
+
+## Prediction loop
+
+predictions = {}
+
+with torch.no_grad():
+ for o in tqdm(aokvqa_set):
+ q = o['question_id']
+
+ # Load input embedding (from question / image)
+ if hp.objective == 'zero-shot' and ('question' in hp.inputs and 'image' in hp.inputs):
+ e = embeddings[q]['question'] + embeddings[q]['image']
+ elif 'question' in hp.inputs and 'image' in hp.inputs:
+ e = torch.cat((embeddings[q]['question'], embeddings[q]['image']))
+ elif 'question' in hp.inputs:
+ e = embeddings[q]['question']
+ elif 'image' in hp.inputs:
+ e = embeddings[q]['image']
+
+ # Pass inputs through model
+ e = e.unsqueeze(0).to(device)
+ x = classifier(e)[0].cpu()
+
+ # Predict
+ if hp.objective in ['contrastive', 'zero-shot']:
+ if args.multiple_choice:
+ vocab = o['choices']
+ # Encode choices
+ vocab_features = clip.tokenize(vocab).to(device)
+ vocab_features = torch.stack([
+ clip_model.encode_text(v.unsqueeze(0)) for v in vocab_features
+ ], dim=1)[0]
+ vocab_features /= vocab_features.norm(dim=-1, keepdim=True)
+ vocab_features = vocab_features.float().cpu()
+
+ x = logit_scale * x @ vocab_features.t()
+ x = x.softmax(dim=-1)
+
+ predictions[q] = vocab[x.argmax().item()]
+
+## Save and evaluate predictions
+
+# Map prediction to nearest neighbor choice (by word embeddings)
+if args.multiple_choice and hp.objective == 'classifier':
+ predictions = map_to_choices(aokvqa_set, predictions)
+
+json.dump(predictions, args.output_file)
diff --git a/minigpt4/common/vqa_tools/aokvqa/transfer_experiments/train.py b/minigpt4/common/vqa_tools/aokvqa/transfer_experiments/train.py
new file mode 100644
index 0000000..ac48b5a
--- /dev/null
+++ b/minigpt4/common/vqa_tools/aokvqa/transfer_experiments/train.py
@@ -0,0 +1,263 @@
+import os
+import sys
+import json
+import argparse
+import pathlib
+import random
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.utils.data import Dataset, DataLoader
+
+# https://github.com/PyTorchLightning/pytorch-lightning/issues/11663
+import sentencepiece; import pytorch_lightning as pl
+
+import torchmetrics.functional as MF
+
+from load_aokvqa import load_aokvqa
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir')
+ parser.add_argument('--vocab', type=argparse.FileType('r'), required=True)
+ parser.add_argument('--log-dir', type=pathlib.Path, dest='log_dir', required=True)
+ #
+ parser.add_argument('--backbone', type=str, choices=['clip', 'resnet', 'bert'], required=True)
+ parser.add_argument('--clip-model-type', type=str,
+ choices=['RN50', 'RN50x4', 'RN50x16', 'RN50x64', 'RN101', 'ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px'],
+ dest='clip_model_type', required=('clip' in sys.argv))
+ parser.add_argument('--train-features', type=pathlib.Path, required=True, dest='train_features')
+ parser.add_argument('--val-features', type=pathlib.Path, required=True, dest='val_features')
+ parser.add_argument('--vocab-features', type=pathlib.Path, required=('contrastive' in sys.argv), dest='vocab_features')
+ #
+ parser.add_argument('--objective', type=str, choices=['classifier', 'contrastive'], required=True)
+ parser.add_argument('--inputs', nargs='+', type=str, choices=['question', 'image'], required=True)
+ # Defaults
+ parser.add_argument('--bs', type=int, default=128, dest='batch_size')
+ parser.add_argument('--lr', type=float, default=0.01)
+ parser.add_argument('--epochs', type=int, default=500)
+ parser.add_argument('--gpus', type=int, default=1)
+ args = parser.parse_args()
+
+ pl.seed_everything(1)
+ vocab = args.vocab.read().splitlines()
+
+ ## Data loading
+
+ dm = AokvqaEmbeddingsDataModule(
+ args.aokvqa_dir,
+ args.train_features,
+ args.val_features,
+ args.objective,
+ args.backbone,
+ args.inputs,
+ vocab,
+ args.vocab_features,
+ batch_size=args.batch_size,
+ num_workers=16
+ )
+
+ ## Model definition
+
+ model = LinearClassifier(
+ args.objective,
+ args.backbone,
+ args.clip_model_type,
+ args.inputs,
+ len(vocab),
+ args.lr
+ )
+
+ ## Training and testing loops
+
+ logger = pl.loggers.TensorBoardLogger(
+ args.log_dir,
+ name=f'{args.backbone}-{args.objective}',
+ version=f"inputs:{'+'.join(args.inputs)}"
+ )
+
+ trainer = pl.Trainer(
+ logger=logger,
+ gpus=args.gpus,
+ max_epochs=args.epochs,
+ callbacks=[
+ pl.callbacks.ModelCheckpoint(
+ monitor="val_acc",
+ filename="{epoch:02d}-{val_acc:.2f}",
+ mode="max"
+ )
+ ],
+ )
+
+ trainer.fit(model, dm)
+
+
+class AokvqaEmbeddingsDataset(Dataset):
+ def __init__(self, aokvqa_dir, split, input_features, objective, backbone, inputs, vocab, vocab_features):
+
+ aokvqa_set = load_aokvqa(aokvqa_dir, split)
+
+ assert ( backbone == 'resnet' and inputs == ['image'] and objective == 'classifier' ) \
+ or ( backbone == 'bert' and inputs == ['question'] and objective == 'classifier' ) \
+ or ( backbone == 'clip' )
+
+ embeddings = torch.load(input_features)
+ if backbone == 'clip':
+ for q in embeddings.keys():
+ embeddings[q]['question'] /= embeddings[q]['question'].norm(dim=-1, keepdim=True)
+ embeddings[q]['image'] /= embeddings[q]['image'].norm(dim=-1, keepdim=True)
+ if objective == 'contrastive':
+ vocab_embeddings = torch.load(vocab_features)
+ vocab_embeddings /= vocab_embeddings.norm(dim=-1, keepdim=True)
+
+ self.objective = objective
+ self.vocab_len = len(vocab)
+
+ self.embeddings = []
+ self.answers = []
+
+ for o in aokvqa_set:
+ correct_answers = set([o['choices'][o['correct_choice_idx']]] + o['direct_answers'])
+ correct_answers = [vocab.index(a) for a in correct_answers if a in vocab]
+ if self.objective == 'contrastive':
+ correct_answers = [vocab_embeddings[a] for a in correct_answers]
+ if len(correct_answers) == 0: continue
+ self.answers.append(correct_answers)
+
+ q = o['question_id']
+ if 'question' in inputs and 'image' in inputs:
+ e = torch.cat((embeddings[q]['question'], embeddings[q]['image']))
+ elif 'question' in inputs and 'image' not in inputs:
+ e = embeddings[q]['question']
+ elif 'question' not in inputs and 'image' in inputs:
+ e = embeddings[q]['image']
+ self.embeddings.append(e)
+
+ def __getitem__(self, index):
+ e = self.embeddings[index]
+ a = self.answers[index]
+ if self.objective == 'classifier':
+ a = torch.sum(F.one_hot(torch.tensor(a), num_classes=self.vocab_len), dim=0)
+ elif self.objective == 'contrastive':
+ a = random.sample(a, 1)[0]
+ return e, a
+
+ def __len__(self):
+ return len(self.embeddings)
+
+
+class AokvqaEmbeddingsDataModule(pl.LightningDataModule):
+
+ def __init__(self, aokvqa_dir, train_features, val_features, objective, backbone, inputs, vocab, vocab_features, batch_size=1, num_workers=0):
+ super().__init__()
+ self.aokvqa_dir = aokvqa_dir
+ self.train_features = train_features
+ self.val_features = val_features
+ self.objective = objective
+ self.backbone = backbone
+ self.inputs = inputs
+ self.vocab = vocab
+ self.vocab_features = vocab_features
+ self.batch_size = batch_size
+ self.num_workers = num_workers
+
+ def setup(self, stage=None):
+ self.train_dataset = AokvqaEmbeddingsDataset(
+ self.aokvqa_dir, 'train', self.train_features, self.objective,
+ self.backbone, self.inputs, self.vocab, self.vocab_features
+ )
+ self.val_dataset = AokvqaEmbeddingsDataset(
+ self.aokvqa_dir, 'val', self.val_features, self.objective,
+ self.backbone, self.inputs, self.vocab, self.vocab_features
+ )
+
+ def train_dataloader(self):
+ return DataLoader(
+ self.train_dataset, batch_size=self.batch_size, shuffle=True,
+ num_workers=int(0.8 * self.num_workers)
+ )
+
+ def val_dataloader(self):
+ return DataLoader(
+ self.val_dataset, batch_size=self.batch_size, shuffle=False,
+ num_workers=int(0.2 * self.num_workers)
+ )
+
+
+class LinearClassifier(pl.LightningModule):
+ def __init__(self, objective, backbone, clip_model_type, inputs, vocab_len, lr=0.001):
+ super().__init__()
+ self.save_hyperparameters(ignore=['lr'])
+ self.lr = lr
+
+ if self.hparams.backbone == 'clip':
+ clip_dim = {
+ 'RN50' : 1024,
+ 'RN50x4' : 640,
+ 'RN50x16' : 768,
+ 'RN50x64' : 1024,
+ 'RN101' : 512,
+ 'ViT-B/32' : 512,
+ 'ViT-B/16' : 512,
+ 'ViT-L/14' : 768,
+ 'ViT-L/14@336px' : 768,
+ }[clip_model_type]
+ emb_dim = clip_dim * len(inputs)
+ elif self.hparams.backbone == 'resnet':
+ emb_dim = 2048
+ elif self.hparams.backbone == 'bert':
+ emb_dim = 768
+
+ if self.hparams.objective == 'classifier':
+ out_dim = vocab_len
+ elif self.hparams.objective == 'contrastive':
+ out_dim = clip_dim
+
+ self.linear = nn.Linear(emb_dim, out_dim)
+
+ def forward(self, x):
+ x = self.linear(x)
+ if self.hparams.objective == 'classifier':
+ x = torch.sigmoid(x)
+ return x
+
+ def compute_loss(self, batch):
+ x, y = batch
+
+ y_pred = self.forward(x)
+
+ if self.hparams.objective == 'classifier':
+ loss = F.binary_cross_entropy(y_pred, y.float())
+ elif self.hparams.objective == 'contrastive':
+ indices = torch.arange(0, x.shape[0], dtype=torch.int64, device=self.device)
+ sim = (y_pred @ y.T).softmax(dim=-1)
+ loss = F.cross_entropy(sim, indices)
+
+ if self.hparams.objective == 'classifier':
+ acc = MF.f1_score(y_pred, y)
+ elif self.hparams.objective == 'contrastive':
+ acc = torch.mean(sim[indices, indices])
+
+ return loss, acc
+
+ def training_step(self, batch, batch_idx):
+ loss, acc = self.compute_loss(batch)
+ self.log("train_loss", loss)
+ self.log("train_acc", acc)
+ return loss
+
+ def validation_step(self, batch, batch_idx):
+ loss, acc = self.compute_loss(batch)
+ self.log("val_loss", loss)
+ self.log("val_acc", acc)
+ return loss
+
+ def configure_optimizers(self):
+ optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
+ return optimizer
+
+
+if __name__ == '__main__':
+ main()
diff --git a/minigpt4/common/vqa_tools/vqa.py b/minigpt4/common/vqa_tools/vqa.py
new file mode 100644
index 0000000..a386b90
--- /dev/null
+++ b/minigpt4/common/vqa_tools/vqa.py
@@ -0,0 +1,211 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+__author__ = "aagrawal"
+__version__ = "0.9"
+
+# Interface for accessing the VQA dataset.
+
+# This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link:
+# (https://github.com/pdollar/coco/blob/master/PythonAPI/pycocotools/coco.py).
+
+# The following functions are defined:
+# VQA - VQA class that loads VQA annotation file and prepares data structures.
+# getQuesIds - Get question ids that satisfy given filter conditions.
+# getImgIds - Get image ids that satisfy given filter conditions.
+# loadQA - Load questions and answers with the specified question ids.
+# showQA - Display the specified questions and answers.
+# loadRes - Load result file and create result object.
+
+# Help on each function can be accessed by: "help(COCO.function)"
+
+import json
+import datetime
+import copy
+
+
+class VQA:
+ def __init__(self, annotation_file=None, question_file=None):
+ """
+ Constructor of VQA helper class for reading and visualizing questions and answers.
+ :param annotation_file (str): location of VQA annotation file
+ :return:
+ """
+ # load dataset
+ self.dataset = {}
+ self.questions = {}
+ self.qa = {}
+ self.qqa = {}
+ self.imgToQA = {}
+ if not annotation_file == None and not question_file == None:
+ print("loading VQA annotations and questions into memory...")
+ time_t = datetime.datetime.utcnow()
+ dataset = json.load(open(annotation_file, "r"))
+ questions = json.load(open(question_file, "r"))
+ self.dataset = dataset
+ self.questions = questions
+ self.createIndex()
+
+ def createIndex(self):
+ # create index
+ print("creating index...")
+ imgToQA = {ann["image_id"]: [] for ann in self.dataset["annotations"]}
+ qa = {ann["question_id"]: [] for ann in self.dataset["annotations"]}
+ qqa = {ann["question_id"]: [] for ann in self.dataset["annotations"]}
+ for ann in self.dataset["annotations"]:
+ imgToQA[ann["image_id"]] += [ann]
+ qa[ann["question_id"]] = ann
+ for ques in self.questions["questions"]:
+ qqa[ques["question_id"]] = ques
+ print("index created!")
+
+ # create class members
+ self.qa = qa
+ self.qqa = qqa
+ self.imgToQA = imgToQA
+
+ def info(self):
+ """
+ Print information about the VQA annotation file.
+ :return:
+ """
+ for key, value in self.datset["info"].items():
+ print("%s: %s" % (key, value))
+
+ def getQuesIds(self, imgIds=[], quesTypes=[], ansTypes=[]):
+ """
+ Get question ids that satisfy given filter conditions. default skips that filter
+ :param imgIds (int array) : get question ids for given imgs
+ quesTypes (str array) : get question ids for given question types
+ ansTypes (str array) : get question ids for given answer types
+ :return: ids (int array) : integer array of question ids
+ """
+ imgIds = imgIds if type(imgIds) == list else [imgIds]
+ quesTypes = quesTypes if type(quesTypes) == list else [quesTypes]
+ ansTypes = ansTypes if type(ansTypes) == list else [ansTypes]
+
+ if len(imgIds) == len(quesTypes) == len(ansTypes) == 0:
+ anns = self.dataset["annotations"]
+ else:
+ if not len(imgIds) == 0:
+ anns = sum(
+ [self.imgToQA[imgId] for imgId in imgIds if imgId in self.imgToQA],
+ [],
+ )
+ else:
+ anns = self.dataset["annotations"]
+ anns = (
+ anns
+ if len(quesTypes) == 0
+ else [ann for ann in anns if ann["question_type"] in quesTypes]
+ )
+ anns = (
+ anns
+ if len(ansTypes) == 0
+ else [ann for ann in anns if ann["answer_type"] in ansTypes]
+ )
+ ids = [ann["question_id"] for ann in anns]
+ return ids
+
+ def getImgIds(self, quesIds=[], quesTypes=[], ansTypes=[]):
+ """
+ Get image ids that satisfy given filter conditions. default skips that filter
+ :param quesIds (int array) : get image ids for given question ids
+ quesTypes (str array) : get image ids for given question types
+ ansTypes (str array) : get image ids for given answer types
+ :return: ids (int array) : integer array of image ids
+ """
+ quesIds = quesIds if type(quesIds) == list else [quesIds]
+ quesTypes = quesTypes if type(quesTypes) == list else [quesTypes]
+ ansTypes = ansTypes if type(ansTypes) == list else [ansTypes]
+
+ if len(quesIds) == len(quesTypes) == len(ansTypes) == 0:
+ anns = self.dataset["annotations"]
+ else:
+ if not len(quesIds) == 0:
+ anns = sum(
+ [self.qa[quesId] for quesId in quesIds if quesId in self.qa], []
+ )
+ else:
+ anns = self.dataset["annotations"]
+ anns = (
+ anns
+ if len(quesTypes) == 0
+ else [ann for ann in anns if ann["question_type"] in quesTypes]
+ )
+ anns = (
+ anns
+ if len(ansTypes) == 0
+ else [ann for ann in anns if ann["answer_type"] in ansTypes]
+ )
+ ids = [ann["image_id"] for ann in anns]
+ return ids
+
+ def loadQA(self, ids=[]):
+ """
+ Load questions and answers with the specified question ids.
+ :param ids (int array) : integer ids specifying question ids
+ :return: qa (object array) : loaded qa objects
+ """
+ if type(ids) == list:
+ return [self.qa[id] for id in ids]
+ elif type(ids) == int:
+ return [self.qa[ids]]
+
+ def showQA(self, anns):
+ """
+ Display the specified annotations.
+ :param anns (array of object): annotations to display
+ :return: None
+ """
+ if len(anns) == 0:
+ return 0
+ for ann in anns:
+ quesId = ann["question_id"]
+ print("Question: %s" % (self.qqa[quesId]["question"]))
+ for ans in ann["answers"]:
+ print("Answer %d: %s" % (ans["answer_id"], ans["answer"]))
+
+ def loadRes(self, resFile, quesFile):
+ """
+ Load result file and return a result object.
+ :param resFile (str) : file name of result file
+ :return: res (obj) : result api object
+ """
+ res = VQA()
+ res.questions = json.load(open(quesFile))
+ res.dataset["info"] = copy.deepcopy(self.questions["info"])
+ res.dataset["task_type"] = copy.deepcopy(self.questions["task_type"])
+ res.dataset["data_type"] = copy.deepcopy(self.questions["data_type"])
+ res.dataset["data_subtype"] = copy.deepcopy(self.questions["data_subtype"])
+ res.dataset["license"] = copy.deepcopy(self.questions["license"])
+
+ print("Loading and preparing results... ")
+ time_t = datetime.datetime.utcnow()
+ anns = json.load(open(resFile))
+ assert type(anns) == list, "results is not an array of objects"
+ annsQuesIds = [ann["question_id"] for ann in anns]
+ assert set(annsQuesIds) == set(
+ self.getQuesIds()
+ ), "Results do not correspond to current VQA set. Either the results do not have predictions for all question ids in annotation file or there is atleast one question id that does not belong to the question ids in the annotation file."
+ for ann in anns:
+ quesId = ann["question_id"]
+ if res.dataset["task_type"] == "Multiple Choice":
+ assert (
+ ann["answer"] in self.qqa[quesId]["multiple_choices"]
+ ), "predicted answer is not one of the multiple choices"
+ qaAnn = self.qa[quesId]
+ ann["image_id"] = qaAnn["image_id"]
+ ann["question_type"] = qaAnn["question_type"]
+ ann["answer_type"] = qaAnn["answer_type"]
+ print(
+ "DONE (t=%0.2fs)" % ((datetime.datetime.utcnow() - time_t).total_seconds())
+ )
+
+ res.dataset["annotations"] = anns
+ res.createIndex()
+ return res
diff --git a/minigpt4/common/vqa_tools/vqa_eval.py b/minigpt4/common/vqa_tools/vqa_eval.py
new file mode 100644
index 0000000..ee808b3
--- /dev/null
+++ b/minigpt4/common/vqa_tools/vqa_eval.py
@@ -0,0 +1,324 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+# coding=utf-8
+
+__author__ = "aagrawal"
+
+# This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link:
+# (https://github.com/tylin/coco-caption/blob/master/pycocoevalcap/eval.py).
+import sys
+import re
+
+
+class VQAEval:
+ def __init__(self, vqa=None, vqaRes=None, n=2):
+ self.n = n
+ self.accuracy = {}
+ self.evalQA = {}
+ self.evalQuesType = {}
+ self.evalAnsType = {}
+ self.vqa = vqa
+ self.vqaRes = vqaRes
+ if vqa is not None:
+ self.params = {"question_id": vqa.getQuesIds()}
+ self.contractions = {
+ "aint": "ain't",
+ "arent": "aren't",
+ "cant": "can't",
+ "couldve": "could've",
+ "couldnt": "couldn't",
+ "couldn'tve": "couldn't've",
+ "couldnt've": "couldn't've",
+ "didnt": "didn't",
+ "doesnt": "doesn't",
+ "dont": "don't",
+ "hadnt": "hadn't",
+ "hadnt've": "hadn't've",
+ "hadn'tve": "hadn't've",
+ "hasnt": "hasn't",
+ "havent": "haven't",
+ "hed": "he'd",
+ "hed've": "he'd've",
+ "he'dve": "he'd've",
+ "hes": "he's",
+ "howd": "how'd",
+ "howll": "how'll",
+ "hows": "how's",
+ "Id've": "I'd've",
+ "I'dve": "I'd've",
+ "Im": "I'm",
+ "Ive": "I've",
+ "isnt": "isn't",
+ "itd": "it'd",
+ "itd've": "it'd've",
+ "it'dve": "it'd've",
+ "itll": "it'll",
+ "let's": "let's",
+ "maam": "ma'am",
+ "mightnt": "mightn't",
+ "mightnt've": "mightn't've",
+ "mightn'tve": "mightn't've",
+ "mightve": "might've",
+ "mustnt": "mustn't",
+ "mustve": "must've",
+ "neednt": "needn't",
+ "notve": "not've",
+ "oclock": "o'clock",
+ "oughtnt": "oughtn't",
+ "ow's'at": "'ow's'at",
+ "'ows'at": "'ow's'at",
+ "'ow'sat": "'ow's'at",
+ "shant": "shan't",
+ "shed've": "she'd've",
+ "she'dve": "she'd've",
+ "she's": "she's",
+ "shouldve": "should've",
+ "shouldnt": "shouldn't",
+ "shouldnt've": "shouldn't've",
+ "shouldn'tve": "shouldn't've",
+ "somebody'd": "somebodyd",
+ "somebodyd've": "somebody'd've",
+ "somebody'dve": "somebody'd've",
+ "somebodyll": "somebody'll",
+ "somebodys": "somebody's",
+ "someoned": "someone'd",
+ "someoned've": "someone'd've",
+ "someone'dve": "someone'd've",
+ "someonell": "someone'll",
+ "someones": "someone's",
+ "somethingd": "something'd",
+ "somethingd've": "something'd've",
+ "something'dve": "something'd've",
+ "somethingll": "something'll",
+ "thats": "that's",
+ "thered": "there'd",
+ "thered've": "there'd've",
+ "there'dve": "there'd've",
+ "therere": "there're",
+ "theres": "there's",
+ "theyd": "they'd",
+ "theyd've": "they'd've",
+ "they'dve": "they'd've",
+ "theyll": "they'll",
+ "theyre": "they're",
+ "theyve": "they've",
+ "twas": "'twas",
+ "wasnt": "wasn't",
+ "wed've": "we'd've",
+ "we'dve": "we'd've",
+ "weve": "we've",
+ "werent": "weren't",
+ "whatll": "what'll",
+ "whatre": "what're",
+ "whats": "what's",
+ "whatve": "what've",
+ "whens": "when's",
+ "whered": "where'd",
+ "wheres": "where's",
+ "whereve": "where've",
+ "whod": "who'd",
+ "whod've": "who'd've",
+ "who'dve": "who'd've",
+ "wholl": "who'll",
+ "whos": "who's",
+ "whove": "who've",
+ "whyll": "why'll",
+ "whyre": "why're",
+ "whys": "why's",
+ "wont": "won't",
+ "wouldve": "would've",
+ "wouldnt": "wouldn't",
+ "wouldnt've": "wouldn't've",
+ "wouldn'tve": "wouldn't've",
+ "yall": "y'all",
+ "yall'll": "y'all'll",
+ "y'allll": "y'all'll",
+ "yall'd've": "y'all'd've",
+ "y'alld've": "y'all'd've",
+ "y'all'dve": "y'all'd've",
+ "youd": "you'd",
+ "youd've": "you'd've",
+ "you'dve": "you'd've",
+ "youll": "you'll",
+ "youre": "you're",
+ "youve": "you've",
+ }
+ self.manualMap = {
+ "none": "0",
+ "zero": "0",
+ "one": "1",
+ "two": "2",
+ "three": "3",
+ "four": "4",
+ "five": "5",
+ "six": "6",
+ "seven": "7",
+ "eight": "8",
+ "nine": "9",
+ "ten": "10",
+ }
+ self.articles = ["a", "an", "the"]
+
+ self.periodStrip = re.compile("(?!<=\d)(\.)(?!\d)")
+ self.commaStrip = re.compile("(\d)(,)(\d)")
+ self.punct = [
+ ";",
+ r"/",
+ "[",
+ "]",
+ '"',
+ "{",
+ "}",
+ "(",
+ ")",
+ "=",
+ "+",
+ "\\",
+ "_",
+ "-",
+ ">",
+ "<",
+ "@",
+ "`",
+ ",",
+ "?",
+ "!",
+ ]
+
+ def evaluate(self, quesIds=None):
+ if quesIds == None:
+ quesIds = [quesId for quesId in self.params["question_id"]]
+ gts = {}
+ res = {}
+ for quesId in quesIds:
+ gts[quesId] = self.vqa.qa[quesId]
+ res[quesId] = self.vqaRes.qa[quesId]
+
+ # =================================================
+ # Compute accuracy
+ # =================================================
+ accQA = []
+ accQuesType = {}
+ accAnsType = {}
+ print("computing accuracy")
+ step = 0
+ for quesId in quesIds:
+ resAns = res[quesId]["answer"]
+ resAns = resAns.replace("\n", " ")
+ resAns = resAns.replace("\t", " ")
+ resAns = resAns.strip()
+ resAns = self.processPunctuation(resAns)
+ resAns = self.processDigitArticle(resAns)
+ gtAcc = []
+ gtAnswers = [ans["answer"] for ans in gts[quesId]["answers"]]
+ if len(set(gtAnswers)) > 1:
+ for ansDic in gts[quesId]["answers"]:
+ ansDic["answer"] = self.processPunctuation(ansDic["answer"])
+ for gtAnsDatum in gts[quesId]["answers"]:
+ otherGTAns = [
+ item for item in gts[quesId]["answers"] if item != gtAnsDatum
+ ]
+ matchingAns = [item for item in otherGTAns if item["answer"] == resAns]
+ acc = min(1, float(len(matchingAns)) / 3)
+ gtAcc.append(acc)
+ quesType = gts[quesId]["question_type"]
+ ansType = gts[quesId]["answer_type"]
+ avgGTAcc = float(sum(gtAcc)) / len(gtAcc)
+ accQA.append(avgGTAcc)
+ if quesType not in accQuesType:
+ accQuesType[quesType] = []
+ accQuesType[quesType].append(avgGTAcc)
+ if ansType not in accAnsType:
+ accAnsType[ansType] = []
+ accAnsType[ansType].append(avgGTAcc)
+ self.setEvalQA(quesId, avgGTAcc)
+ self.setEvalQuesType(quesId, quesType, avgGTAcc)
+ self.setEvalAnsType(quesId, ansType, avgGTAcc)
+ if step % 100 == 0:
+ self.updateProgress(step / float(len(quesIds)))
+ step = step + 1
+
+ self.setAccuracy(accQA, accQuesType, accAnsType)
+ print("Done computing accuracy")
+
+ def processPunctuation(self, inText):
+ outText = inText
+ for p in self.punct:
+ if (p + " " in inText or " " + p in inText) or (
+ re.search(self.commaStrip, inText) != None
+ ):
+ outText = outText.replace(p, "")
+ else:
+ outText = outText.replace(p, " ")
+ outText = self.periodStrip.sub("", outText, re.UNICODE)
+ return outText
+
+ def processDigitArticle(self, inText):
+ outText = []
+ tempText = inText.lower().split()
+ for word in tempText:
+ word = self.manualMap.setdefault(word, word)
+ if word not in self.articles:
+ outText.append(word)
+ else:
+ pass
+ for wordId, word in enumerate(outText):
+ if word in self.contractions:
+ outText[wordId] = self.contractions[word]
+ outText = " ".join(outText)
+ return outText
+
+ def setAccuracy(self, accQA, accQuesType, accAnsType):
+ self.accuracy["overall"] = round(100 * float(sum(accQA)) / len(accQA), self.n)
+ self.accuracy["perQuestionType"] = {
+ quesType: round(
+ 100 * float(sum(accQuesType[quesType])) / len(accQuesType[quesType]),
+ self.n,
+ )
+ for quesType in accQuesType
+ }
+ self.accuracy["perAnswerType"] = {
+ ansType: round(
+ 100 * float(sum(accAnsType[ansType])) / len(accAnsType[ansType]), self.n
+ )
+ for ansType in accAnsType
+ }
+
+ def setEvalQA(self, quesId, acc):
+ self.evalQA[quesId] = round(100 * acc, self.n)
+
+ def setEvalQuesType(self, quesId, quesType, acc):
+ if quesType not in self.evalQuesType:
+ self.evalQuesType[quesType] = {}
+ self.evalQuesType[quesType][quesId] = round(100 * acc, self.n)
+
+ def setEvalAnsType(self, quesId, ansType, acc):
+ if ansType not in self.evalAnsType:
+ self.evalAnsType[ansType] = {}
+ self.evalAnsType[ansType][quesId] = round(100 * acc, self.n)
+
+ def updateProgress(self, progress):
+ barLength = 20
+ status = ""
+ if isinstance(progress, int):
+ progress = float(progress)
+ if not isinstance(progress, float):
+ progress = 0
+ status = "error: progress var must be float\r\n"
+ if progress < 0:
+ progress = 0
+ status = "Halt...\r\n"
+ if progress >= 1:
+ progress = 1
+ status = "Done...\r\n"
+ block = int(round(barLength * progress))
+ text = "\rFinshed Percent: [{0}] {1}% {2}".format(
+ "#" * block + "-" * (barLength - block), int(progress * 100), status
+ )
+ sys.stdout.write(text)
+ sys.stdout.flush()