mirror of
https://github.com/Vision-CAIR/MiniGPT-4.git
synced 2025-04-05 10:30:45 +00:00
185 lines
4.7 KiB
Plaintext
185 lines
4.7 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import os\n",
|
|
"import json\n",
|
|
"from pathlib import Path"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"BASE_DIR = Path(\"./TextVQA\")\n",
|
|
"NEW_DIR = Path(\"./TextVQA_tiny\")\n",
|
|
"NEW_DIR.mkdir(exist_ok=True)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"BASE_IMG_DIR = BASE_DIR / \"train_images\"\n",
|
|
"NEW_IMG_DIR = NEW_DIR / \"images\"\n",
|
|
"NEW_IMG_DIR.mkdir(exist_ok=True)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## 1. READ and FILTER the answer for each question"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import random\n",
|
|
"from typing import List\n",
|
|
"from collections import Counter\n",
|
|
"\n",
|
|
"\n",
|
|
"def filter_answers(answers: List[str]) -> str:\n",
|
|
" \"\"\"\n",
|
|
" Filter out answers that are most common in a list of answers\n",
|
|
" :param answers: List of answers\n",
|
|
" :return: Most common answer\n",
|
|
" \"\"\"\n",
|
|
" frequency = Counter(answers)\n",
|
|
" max_freq = max(frequency.values())\n",
|
|
" tie_terms = [term for term, count in frequency.items() if count == max_freq]\n",
|
|
" if len(tie_terms) == 0:\n",
|
|
" return None\n",
|
|
" if \"unanswerable\" in tie_terms:\n",
|
|
" if len(tie_terms) > 1:\n",
|
|
" tie_terms.remove(\"unanswerable\")\n",
|
|
" return tie_terms[0]\n",
|
|
" else:\n",
|
|
" return None\n",
|
|
" else:\n",
|
|
" random_term = random.choice(tie_terms)\n",
|
|
" return random_term"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"(34120, 4922)"
|
|
]
|
|
},
|
|
"execution_count": 5,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"target_fields = [\"question\", \"image_id\", \"image_width\", \"image_height\", \"answers\"]\n",
|
|
"\n",
|
|
"with open(BASE_DIR / \"TextVQA_0.5.1_train.json\", \"r\") as f:\n",
|
|
" train_set = json.load(f)[\"data\"]\n",
|
|
" train_set = [{k: v for k, v in d.items() if k in target_fields} for d in train_set]\n",
|
|
" for d in train_set:\n",
|
|
" filtered_answer = filter_answers(d[\"answers\"])\n",
|
|
" if filtered_answer is None:\n",
|
|
" train_set.remove(d)\n",
|
|
" else:\n",
|
|
" d[\"answer\"] = filter_answers(d[\"answers\"])\n",
|
|
"\n",
|
|
"with open(BASE_DIR / \"TextVQA_0.5.1_val.json\", \"r\") as f:\n",
|
|
" val_set = json.load(f)[\"data\"]\n",
|
|
" val_set = [{k: v for k, v in d.items() if k in target_fields} for d in val_set]\n",
|
|
" for d in val_set:\n",
|
|
" filtered_answer = filter_answers(d[\"answers\"])\n",
|
|
" if filtered_answer is None:\n",
|
|
" val_set.remove(d)\n",
|
|
" else:\n",
|
|
" d[\"answer\"] = filter_answers(d[\"answers\"])\n",
|
|
" \n",
|
|
"\n",
|
|
"len(train_set), len(val_set)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"random.seed(42)\n",
|
|
"random.shuffle(train_set)\n",
|
|
"random.shuffle(val_set)\n",
|
|
"\n",
|
|
"train_set = train_set[:4000]\n",
|
|
"val_set = val_set[:1000]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"for d in train_set:\n",
|
|
" img_id = d[\"image_id\"]\n",
|
|
" os.system(f\"cp {BASE_IMG_DIR}/{img_id}.jpg {NEW_IMG_DIR}/{img_id}.jpg\")\n",
|
|
"\n",
|
|
"for d in val_set:\n",
|
|
" img_id = d[\"image_id\"]\n",
|
|
" os.system(f\"cp {BASE_IMG_DIR}/{img_id}.jpg {NEW_IMG_DIR}/{img_id}.jpg\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"with open(NEW_DIR / \"train.json\", \"w\") as f:\n",
|
|
" json.dump(train_set, f, indent=4)\n",
|
|
"\n",
|
|
"with open(NEW_DIR / \"val.json\", \"w\") as f:\n",
|
|
" json.dump(val_set, f, indent=4)\n",
|
|
" "
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "minigptv",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.9.21"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 2
|
|
}
|