mirror of
https://github.com/Vision-CAIR/MiniGPT-4.git
synced 2025-04-04 18:10:47 +00:00
80 lines
2.5 KiB
Python
80 lines
2.5 KiB
Python
import os
|
|
import random
|
|
import json
|
|
from tqdm import tqdm
|
|
import argparse
|
|
import pathlib
|
|
|
|
import openai
|
|
# OpenAI credentials come from the environment; an unset variable yields
# None, which is assigned unchanged.
openai.organization = os.environ.get('OPENAI_ORG')
openai.api_key = os.environ.get('OPENAI_API_KEY')

from load_aokvqa import load_aokvqa

# Fixed global seed so the few-shot examples drawn with random.sample in
# main() are the same on every run.
random.seed(0)
|
|
|
|
|
|
def _parse_args(argv=None):
    """Parse the command line for :func:`main`.

    ``argv=None`` makes argparse read ``sys.argv``. ``--engine`` and
    ``--max-tokens`` are optional and default to the values the script
    originally hard-coded.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--aokvqa-dir', type=pathlib.Path, required=True, dest='aokvqa_dir')
    parser.add_argument('--split', type=str, choices=['train', 'val', 'test'], required=True)
    parser.add_argument('--n', type=int, default=10, dest='num_examples')
    parser.add_argument('--train-context', type=argparse.FileType('r'), dest='train_context_file')
    parser.add_argument('--prefix', type=str, default='', dest='prompt_prefix')
    parser.add_argument('--include-choices', action='store_true', dest='include_choices')
    parser.add_argument('--context', type=argparse.FileType('r'), dest='context_file')
    parser.add_argument('--out', type=argparse.FileType('w'), required=True, dest='output_file')
    parser.add_argument('--engine', type=str, default='text-curie-001', dest='engine',
                        help='value passed as engine= to openai.Completion.create')
    parser.add_argument('--max-tokens', type=int, default=10, dest='max_tokens',
                        help='value passed as max_tokens= to openai.Completion.create')
    return parser.parse_args(argv)


def _load_context(file):
    """Return the JSON object read from the open *file*, or ``{}`` if *file* is None."""
    if file is None:
        return {}
    return json.load(file)


def _build_prompt(d, train_set, train_context, context, args):
    """Assemble the few-shot prompt for the evaluation record *d*.

    The prompt is the prefix, then ``args.num_examples`` randomly sampled
    training records rendered with their answers (each followed by a blank
    line), then *d* rendered without its answer.
    """
    parts = [args.prompt_prefix]
    for e in random.sample(train_set, args.num_examples):
        # A demonstration's context must be looked up by the demonstration's
        # own question_id, not by the id of the question being evaluated.
        parts.append(prompt_element(
            e,
            context=train_context.get(e['question_id']),
            include_choices=args.include_choices,
            answer=True,
        ))
        parts.append('\n\n')
    parts.append(prompt_element(
        d,
        context=context.get(d['question_id']),
        include_choices=args.include_choices,
        answer=False,
    ))
    return ''.join(parts)


def main():
    """Query an OpenAI completion model with few-shot A-OKVQA prompts.

    For every record of ``--split``, a prompt is built by ``_build_prompt``
    and sent to ``openai.Completion.create`` with temperature 0. The
    stripped completion text is stored under the record's ``question_id``.
    All predictions are written to ``--out`` as one JSON object.
    """
    args = _parse_args()

    train_set = load_aokvqa(args.aokvqa_dir, 'train')
    eval_set = load_aokvqa(args.aokvqa_dir, args.split)

    # Each context file is optional and loaded independently. Previously
    # both were read only when --context was given: --context alone crashed
    # on json.load(None), and --train-context alone was silently ignored.
    train_context = _load_context(args.train_context_file)
    context = _load_context(args.context_file)

    predictions = {}

    for d in tqdm(eval_set):
        q = d['question_id']
        prompt = _build_prompt(d, train_set, train_context, context, args)

        response = openai.Completion.create(
            engine=args.engine,
            prompt=prompt,
            temperature=0.0,
            max_tokens=args.max_tokens,
        )

        predictions[q] = response.choices[0].text.strip()

    json.dump(predictions, args.output_file)
|
|
|
|
|
|
def prompt_element(d, context=None, include_choices=False, answer=False):
    """Render one question record as a text prompt segment.

    Output lines, in order:
      - "Context: <context>" only when *context* is not None;
      - "Q: <d['question']>";
      - "Choices: <comma-separated d['choices']>." only when *include_choices*;
      - "A:", followed by " <correct choice>" when *answer* is true.
    No trailing newline follows the "A:" segment.
    """
    pieces = []
    if context is not None:
        pieces.append(f"Context: {context}\n")
    pieces.append(f"Q: {d['question']}\n")
    if include_choices:
        joined_choices = ', '.join(d['choices'])
        pieces.append(f"Choices: {joined_choices}.\n")
    pieces.append("A:")
    if answer:
        correct = d['choices'][d['correct_choice_idx']]
        pieces.append(f" {correct}")
    return ''.join(pieces)
|
|
|
|
# Run the query loop only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|