MiniGPT-4/data/create_textvqa_dataset.ipynb
2025-01-20 07:23:11 +07:00

182 lines
4.6 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import json\n",
"from pathlib import Path"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Source (full) TextVQA dataset and destination directory for the tiny subset.\n",
"BASE_DIR = Path(\"./TextVQA\")\n",
"NEW_DIR = Path(\"./TextVQA_tiny\")\n",
"NEW_DIR.mkdir(exist_ok=True)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Full-set images live in <BASE_DIR>/train_images; the subset keeps its copies under <NEW_DIR>/images.\n",
"BASE_IMG_DIR = BASE_DIR / \"train_images\"\n",
"NEW_IMG_DIR = NEW_DIR / \"images\"\n",
"NEW_IMG_DIR.mkdir(exist_ok=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. READ and FILTER the answer for each question"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"import random\n",
"from typing import List, Optional\n",
"from collections import Counter\n",
"\n",
"\n",
"def filter_answers(answers: List[str]) -> Optional[str]:\n",
"    \"\"\"\n",
"    Pick the most frequent answer from a list of annotator answers.\n",
"\n",
"    On a frequency tie, \"unanswerable\" is dropped when any other answer\n",
"    is tied with it; remaining ties are broken at random.\n",
"\n",
"    :param answers: List of raw answer strings for one question\n",
"    :return: Selected answer, or None when no usable answer exists\n",
"        (empty input, or \"unanswerable\" is the unique most frequent answer)\n",
"    \"\"\"\n",
"    if not answers:\n",
"        # Guard: max() below would raise ValueError on an empty list.\n",
"        return None\n",
"    frequency = Counter(answers)\n",
"    max_freq = max(frequency.values())\n",
"    tie_terms = [term for term, count in frequency.items() if count == max_freq]\n",
"    if \"unanswerable\" in tie_terms:\n",
"        if len(tie_terms) > 1:\n",
"            # Prefer a real answer over \"unanswerable\" on a tie; keep the\n",
"            # original deterministic choice (first term in Counter order).\n",
"            tie_terms.remove(\"unanswerable\")\n",
"            return tie_terms[0]\n",
"        # \"unanswerable\" is the unique winner -> drop this question.\n",
"        return None\n",
"    # Break remaining ties at random (uses the global random module).\n",
"    return random.choice(tie_terms)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(34109, 4922)"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"target_fields = [\"question\", \"image_id\", \"image_width\", \"image_height\", \"answers\"]\n",
"\n",
"\n",
"def load_split(json_path) -> List[dict]:\n",
"    \"\"\"\n",
"    Load one TextVQA annotation file and return its usable records.\n",
"\n",
"    Keeps only target_fields per record, picks a single answer with\n",
"    filter_answers, and drops questions whose answer is None.\n",
"\n",
"    NOTE(review): filter_answers breaks ties via the global random module,\n",
"    and random.seed(42) is only set in a later cell -- tie-breaking here is\n",
"    not reproducible across runs.\n",
"    \"\"\"\n",
"    with open(json_path, \"r\") as f:\n",
"        records = json.load(f)[\"data\"]\n",
"    records = [{k: v for k, v in d.items() if k in target_fields} for d in records]\n",
"    for d in records:\n",
"        d[\"answer\"] = filter_answers(d[\"answers\"])\n",
"    # drop unanswerable questions\n",
"    return [d for d in records if d[\"answer\"] is not None]\n",
"\n",
"\n",
"# Same processing order as before (train first, then val), so RNG draws match.\n",
"train_set = load_split(BASE_DIR / \"TextVQA_0.5.1_train.json\")\n",
"val_set = load_split(BASE_DIR / \"TextVQA_0.5.1_val.json\")\n",
"\n",
"len(train_set), len(val_set)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# Fixed seed so the sampled subset is reproducible across runs.\n",
"# (Shuffle order is RNG-sensitive: train is shuffled before val.)\n",
"random.seed(42)\n",
"random.shuffle(train_set)\n",
"random.shuffle(val_set)\n",
"\n",
"# Keep a small random subset: 4000 train / 1000 val examples.\n",
"train_set = train_set[:4000]\n",
"val_set = val_set[:1000]"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"import shutil\n",
"\n",
"# Copy every referenced image into the tiny dataset folder.\n",
"# shutil.copy2 replaces os.system(\"cp ...\"): it is cross-platform, safe for\n",
"# paths containing spaces/shell metacharacters, and raises on failure instead\n",
"# of silently returning a nonzero exit status.\n",
"for d in train_set + val_set:\n",
"    img_name = f\"{d['image_id']}.jpg\"\n",
"    shutil.copy2(BASE_IMG_DIR / img_name, NEW_IMG_DIR / img_name)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# Persist the filtered subsets next to the copied images.\n",
"for split_name, split_data in ((\"train\", train_set), (\"val\", val_set)):\n",
"    with open(NEW_DIR / f\"{split_name}.json\", \"w\") as f:\n",
"        json.dump(split_data, f, indent=4)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "minigptv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.21"
}
},
"nbformat": 4,
"nbformat_minor": 2
}