diff --git a/data/create_textvqa_dataset.ipynb b/data/create_textvqa_dataset.ipynb new file mode 100644 index 0000000..335e8f6 --- /dev/null +++ b/data/create_textvqa_dataset.ipynb @@ -0,0 +1,184 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import json\n", + "from pathlib import Path" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "BASE_DIR = Path(\"./TextVQA\")\n", + "NEW_DIR = Path(\"./TextVQA_tiny\")\n", + "NEW_DIR.mkdir(exist_ok=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "BASE_IMG_DIR = BASE_DIR / \"train_images\"\n", + "NEW_IMG_DIR = NEW_DIR / \"images\"\n", + "NEW_IMG_DIR.mkdir(exist_ok=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. READ and FILTER the answer for each question" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "import random\n", + "from typing import List\n", + "from collections import Counter\n", + "\n", + "\n", + "def filter_answers(answers: List[str]) -> str:\n", + " \"\"\"\n", + " Filter out answers that are most common in a list of answers\n", + " :param answers: List of answers\n", + " :return: Most common answer\n", + " \"\"\"\n", + " frequency = Counter(answers)\n", + " max_freq = max(frequency.values())\n", + " tie_terms = [term for term, count in frequency.items() if count == max_freq]\n", + " if len(tie_terms) == 0:\n", + " return None\n", + " if \"unanswerable\" in tie_terms:\n", + " if len(tie_terms) > 1:\n", + " tie_terms.remove(\"unanswerable\")\n", + " return tie_terms[0]\n", + " else:\n", + " return None\n", + " else:\n", + " random_term = random.choice(tie_terms)\n", + " return random_term" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(34120, 4922)" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "target_fields = [\"question\", \"image_id\", \"image_width\", \"image_height\", \"answers\"]\n", + "\n", + "with open(BASE_DIR / \"TextVQA_0.5.1_train.json\", \"r\") as f:\n", + " train_set = json.load(f)[\"data\"]\n", + " train_set = [{k: v for k, v in d.items() if k in target_fields} for d in train_set]\n", + " for d in train_set:\n", + " filtered_answer = filter_answers(d[\"answers\"])\n", + " if filtered_answer is None:\n", + " train_set.remove(d)\n", + " else:\n", + " d[\"answer\"] = filter_answers(d[\"answers\"])\n", + "\n", + "with open(BASE_DIR / \"TextVQA_0.5.1_val.json\", \"r\") as f:\n", + " val_set = json.load(f)[\"data\"]\n", + " val_set = [{k: v for k, v in d.items() if k in target_fields} for d in val_set]\n", + " for d in val_set:\n", + " filtered_answer = filter_answers(d[\"answers\"])\n", + " if filtered_answer is None:\n", + " val_set.remove(d)\n", + " else:\n", + " d[\"answer\"] = filter_answers(d[\"answers\"])\n", + " \n", + "\n", + "len(train_set), len(val_set)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "random.seed(42)\n", + "random.shuffle(train_set)\n", + "random.shuffle(val_set)\n", + "\n", + "train_set = train_set[:4000]\n", + "val_set = val_set[:1000]" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "for d in train_set:\n", + " img_id = d[\"image_id\"]\n", + " os.system(f\"cp {BASE_IMG_DIR}/{img_id}.jpg {NEW_IMG_DIR}/{img_id}.jpg\")\n", + "\n", + "for d in val_set:\n", + " img_id = d[\"image_id\"]\n", + " os.system(f\"cp {BASE_IMG_DIR}/{img_id}.jpg {NEW_IMG_DIR}/{img_id}.jpg\")" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "with open(NEW_DIR / \"train.json\", \"w\") as f:\n", + " json.dump(train_set, f, indent=4)\n", + "\n", + "with open(NEW_DIR / \"val.json\", \"w\") as f:\n", + " json.dump(val_set, f, indent=4)\n", + " " + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "minigptv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.21" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}