mirror of
https://github.com/Vision-CAIR/MiniGPT-4.git
synced 2025-04-05 18:40:46 +00:00
add create textvqa tiny
This commit is contained in:
parent
25918644be
commit
21106c5639
184
data/create_textvqa_dataset.ipynb
Normal file
184
data/create_textvqa_dataset.ipynb
Normal file
@ -0,0 +1,184 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"import json\n",
|
||||
"from pathlib import Path"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"BASE_DIR = Path(\"./TextVQA\")\n",
|
||||
"NEW_DIR = Path(\"./TextVQA_tiny\")\n",
|
||||
"NEW_DIR.mkdir(exist_ok=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"BASE_IMG_DIR = BASE_DIR / \"train_images\"\n",
|
||||
"NEW_IMG_DIR = NEW_DIR / \"images\"\n",
|
||||
"NEW_IMG_DIR.mkdir(exist_ok=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 1. READ and FILTER the answer for each question"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import random\n",
|
||||
"from typing import List\n",
|
||||
"from collections import Counter\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def filter_answers(answers: List[str]) -> str:\n",
|
||||
" \"\"\"\n",
|
||||
" Filter out answers that are most common in a list of answers\n",
|
||||
" :param answers: List of answers\n",
|
||||
" :return: Most common answer\n",
|
||||
" \"\"\"\n",
|
||||
" frequency = Counter(answers)\n",
|
||||
" max_freq = max(frequency.values())\n",
|
||||
" tie_terms = [term for term, count in frequency.items() if count == max_freq]\n",
|
||||
" if len(tie_terms) == 0:\n",
|
||||
" return None\n",
|
||||
" if \"unanswerable\" in tie_terms:\n",
|
||||
" if len(tie_terms) > 1:\n",
|
||||
" tie_terms.remove(\"unanswerable\")\n",
|
||||
" return tie_terms[0]\n",
|
||||
" else:\n",
|
||||
" return None\n",
|
||||
" else:\n",
|
||||
" random_term = random.choice(tie_terms)\n",
|
||||
" return random_term"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"(34120, 4922)"
|
||||
]
|
||||
},
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"target_fields = [\"question\", \"image_id\", \"image_width\", \"image_height\", \"answers\"]\n",
|
||||
"\n",
|
||||
"with open(BASE_DIR / \"TextVQA_0.5.1_train.json\", \"r\") as f:\n",
|
||||
" train_set = json.load(f)[\"data\"]\n",
|
||||
" train_set = [{k: v for k, v in d.items() if k in target_fields} for d in train_set]\n",
|
||||
" for d in train_set:\n",
|
||||
" filtered_answer = filter_answers(d[\"answers\"])\n",
|
||||
" if filtered_answer is None:\n",
|
||||
" train_set.remove(d)\n",
|
||||
" else:\n",
|
||||
" d[\"answer\"] = filter_answers(d[\"answers\"])\n",
|
||||
"\n",
|
||||
"with open(BASE_DIR / \"TextVQA_0.5.1_val.json\", \"r\") as f:\n",
|
||||
" val_set = json.load(f)[\"data\"]\n",
|
||||
" val_set = [{k: v for k, v in d.items() if k in target_fields} for d in val_set]\n",
|
||||
" for d in val_set:\n",
|
||||
" filtered_answer = filter_answers(d[\"answers\"])\n",
|
||||
" if filtered_answer is None:\n",
|
||||
" val_set.remove(d)\n",
|
||||
" else:\n",
|
||||
" d[\"answer\"] = filter_answers(d[\"answers\"])\n",
|
||||
" \n",
|
||||
"\n",
|
||||
"len(train_set), len(val_set)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"random.seed(42)\n",
|
||||
"random.shuffle(train_set)\n",
|
||||
"random.shuffle(val_set)\n",
|
||||
"\n",
|
||||
"train_set = train_set[:4000]\n",
|
||||
"val_set = val_set[:1000]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"for d in train_set:\n",
|
||||
" img_id = d[\"image_id\"]\n",
|
||||
" os.system(f\"cp {BASE_IMG_DIR}/{img_id}.jpg {NEW_IMG_DIR}/{img_id}.jpg\")\n",
|
||||
"\n",
|
||||
"for d in val_set:\n",
|
||||
" img_id = d[\"image_id\"]\n",
|
||||
" os.system(f\"cp {BASE_IMG_DIR}/{img_id}.jpg {NEW_IMG_DIR}/{img_id}.jpg\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"with open(NEW_DIR / \"train.json\", \"w\") as f:\n",
|
||||
" json.dump(train_set, f, indent=4)\n",
|
||||
"\n",
|
||||
"with open(NEW_DIR / \"val.json\", \"w\") as f:\n",
|
||||
" json.dump(val_set, f, indent=4)\n",
|
||||
" "
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "minigptv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.21"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
Loading…
Reference in New Issue
Block a user