{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Build a tiny TextVQA subset\n",
    "\n",
    "Samples `N_TRAIN` training and `N_VAL` validation questions from the TextVQA 0.5.1 annotation files, keeps a single answer per question, and copies the referenced images into `./TextVQA_tiny`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "import random\n",
    "import shutil\n",
    "from collections import Counter\n",
    "from pathlib import Path\n",
    "from typing import List, Optional"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "SEED = 42  # seeds both answer tie-breaking and shuffling\n",
    "N_TRAIN = 4000  # questions kept in the tiny train split\n",
    "N_VAL = 1000  # questions kept in the tiny val split\n",
    "TARGET_FIELDS = [\"question\", \"image_id\", \"image_width\", \"image_height\", \"answers\"]\n",
    "\n",
    "BASE_DIR = Path(\"./TextVQA\")\n",
    "NEW_DIR = Path(\"./TextVQA_tiny\")\n",
    "NEW_DIR.mkdir(parents=True, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "BASE_IMG_DIR = BASE_DIR / \"train_images\"\n",
    "NEW_IMG_DIR = NEW_DIR / \"images\"\n",
    "NEW_IMG_DIR.mkdir(parents=True, exist_ok=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Read and filter the answer for each question"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def filter_answers(answers: List[str]) -> Optional[str]:\n",
    "    \"\"\"\n",
    "    Pick the single most common answer from a list of answers.\n",
    "\n",
    "    Ties between several most-common answers are broken with the global\n",
    "    `random` module (seed it beforehand for reproducible results).\n",
    "    \"unanswerable\" never wins a tie: when it is tied with other answers,\n",
    "    the first of the other tied answers is returned.\n",
    "\n",
    "    :param answers: List of answers\n",
    "    :return: Most common answer, or None if the list is empty or\n",
    "        \"unanswerable\" is the sole most common answer\n",
    "    \"\"\"\n",
    "    if not answers:\n",
    "        # max() below raises ValueError on an empty sequence\n",
    "        return None\n",
    "    frequency = Counter(answers)\n",
    "    max_freq = max(frequency.values())\n",
    "    tie_terms = [term for term, count in frequency.items() if count == max_freq]\n",
    "    if \"unanswerable\" in tie_terms:\n",
    "        tie_terms.remove(\"unanswerable\")\n",
    "        return tie_terms[0] if tie_terms else None\n",
    "    return random.choice(tie_terms)\n",
    "\n",
    "\n",
    "def load_split(path: Path) -> List[dict]:\n",
    "    \"\"\"\n",
    "    Load one annotation file and attach a single answer to every question.\n",
    "\n",
    "    Reads the list stored under the file's \"data\" key, keeps only the keys\n",
    "    listed in TARGET_FIELDS, stores filter_answers(entry[\"answers\"]) under\n",
    "    \"answer\", and drops entries for which no answer could be chosen.\n",
    "\n",
    "    :param path: Path of the annotation JSON file\n",
    "    :return: List of filtered question records, in file order\n",
    "    \"\"\"\n",
    "    with open(path, \"r\", encoding=\"utf-8\") as f:\n",
    "        entries = json.load(f)[\"data\"]\n",
    "    records = []\n",
    "    for entry in entries:\n",
    "        record = {k: v for k, v in entry.items() if k in TARGET_FIELDS}\n",
    "        record[\"answer\"] = filter_answers(record[\"answers\"])\n",
    "        # drop unanswerable questions\n",
    "        if record[\"answer\"] is not None:\n",
    "            records.append(record)\n",
    "    return records"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(34109, 4922)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Seed before loading: filter_answers breaks ties with random.choice,\n",
    "# so without this the chosen answers differ from run to run.\n",
    "random.seed(SEED)\n",
    "train_all = load_split(BASE_DIR / \"TextVQA_0.5.1_train.json\")\n",
    "val_all = load_split(BASE_DIR / \"TextVQA_0.5.1_val.json\")\n",
    "\n",
    "len(train_all), len(val_all)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Subsample the splits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-seed right before shuffling so the selected subset depends only on\n",
    "# SEED and the split sizes, not on how many tie-breaks happened while loading.\n",
    "random.seed(SEED)\n",
    "# Shuffle copies so that re-running this cell always starts from the full splits.\n",
    "train_shuffled, val_shuffled = list(train_all), list(val_all)\n",
    "random.shuffle(train_shuffled)\n",
    "random.shuffle(val_shuffled)\n",
    "\n",
    "train_set = train_shuffled[:N_TRAIN]\n",
    "val_set = val_shuffled[:N_VAL]\n",
    "\n",
    "assert len(train_set) == min(N_TRAIN, len(train_all))\n",
    "assert len(val_set) == min(N_VAL, len(val_all))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Copy the referenced images"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def copy_images(records: List[dict], src_dir: Path, dst_dir: Path) -> list:\n",
    "    \"\"\"\n",
    "    Copy `<image_id>.jpg` for every record from src_dir to dst_dir.\n",
    "\n",
    "    Each distinct image id is copied once, even if several questions share it.\n",
    "\n",
    "    :param records: Question records with an \"image_id\" key\n",
    "    :param src_dir: Directory holding the source images\n",
    "    :param dst_dir: Existing directory to copy the images into\n",
    "    :return: Image ids whose source file was not found\n",
    "    \"\"\"\n",
    "    missing = []\n",
    "    # dict.fromkeys de-duplicates while keeping first-seen order\n",
    "    for img_id in dict.fromkeys(r[\"image_id\"] for r in records):\n",
    "        src = src_dir / f\"{img_id}.jpg\"\n",
    "        if not src.is_file():\n",
    "            missing.append(img_id)\n",
    "            continue\n",
    "        shutil.copy(src, dst_dir / src.name)\n",
    "    return missing\n",
    "\n",
    "\n",
    "# Both splits read their images from BASE_IMG_DIR.\n",
    "missing_ids = copy_images(train_set + val_set, BASE_IMG_DIR, NEW_IMG_DIR)\n",
    "if missing_ids:\n",
    "    print(f\"WARNING: {len(missing_ids)} source image(s) not found, e.g. {missing_ids[:5]}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Write the annotation files"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "for file_name, records in ((\"train.json\", train_set), (\"val.json\", val_set)):\n",
    "    with open(NEW_DIR / file_name, \"w\", encoding=\"utf-8\") as f:\n",
    "        json.dump(records, f, indent=4)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "minigptv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.21"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}