
Commit 7d6d81f

Authored by JoeLeelyf, kennymckormick, and mzr1996
[Benchmark] Support SArena Benchmark (#1371)
* add-support-for-SArena
* Merge SArena_MINI into SArena class and update implementation.
* Fix lint

Co-authored-by: Haodong Duan <dhd@pku.edu.cn>
Co-authored-by: mzr1996 <mzr1996@163.com>
1 parent 655e65f commit 7d6d81f

23 files changed: 2979 additions & 692 deletions

vlmeval/dataset/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -139,7 +139,7 @@
 from .gsm8k_v import GSM8KVDataset
 from .macbench import MaCBench
 from .mmesci import MMESCIDataset
-from .sarena_mini import SArena_MINI
+from .sarena import SArena
 from .uni_svg import UniSVG
 from .vladbench import VLADBench
 from .design2code import Design2Code

@@ -282,7 +282,7 @@ def evaluate(self, eval_file, **judge_kwargs):
     MedqbenchPairedDescriptionDataset, MedqbenchCaptionDataset, ChartMuseum, ChartQAPro, ReasonMap_Plus,
     olmOCRBench, OceanOCRBench, MATBench, VLRMBench, RefCOCODataset, RefSpatialDataset,
     ERQADataset, SimpleVQA, HiPhODataset, MaCBench,
-    UniSVG, SArena_MINI, VLMsAreBiased, MMESCIDataset, CoreCognition, GroundingME,
+    UniSVG, SArena, VLMsAreBiased, MMESCIDataset, CoreCognition, GroundingME,
     FoxBench, VTCBench, Asclepius, PlotQA, ChartX, ChartBench, ChartCapDataset, WorldVQA, PuzzleVQA, VisualPuzzles,
     Design2Code, VLADBench, SSIBenchDataset, NPMM, SGI_Bench_Experimental_Reasoning
 ]
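
This registration is all VLMEvalKit needs to expose the benchmark by name. As a rough sketch of what that enables (assuming the repo's usual build_dataset helper; the calls below are illustrative, not part of this diff):

    # Hypothetical usage sketch, not code from this commit.
    from vlmeval.dataset import build_dataset

    dataset = build_dataset('SArena')    # full split
    mini = build_dataset('SArena_MINI')  # both names resolve to the SArena class
    print(len(dataset), len(mini))       # records loaded from the TSVs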
vlmeval/dataset/sarena.py

Lines changed: 5 additions & 4 deletions

@@ -1,18 +1,19 @@
-import ast
 from .image_base import ImageBaseDataset
+from .utils.sarena import evaluate_sarena
 from ..smp import *
-from .utils.sarena_mini import evaluate_sarena_mini


-class SArena_MINI(ImageBaseDataset):
+class SArena(ImageBaseDataset):

     TYPE = "VQA"

     DATASET_URL = {
+        "SArena": "https://huggingface.co/datasets/JoeLeelyf/SArena-VLMEvalKit/resolve/main/SArena.tsv",
         "SArena_MINI": "https://huggingface.co/datasets/JoeLeelyf/SArena-VLMEvalKit/resolve/main/SArena_MINI.tsv"
     }

     DATASET_MD5 = {
+        "SArena": "2a747c13c063a6c9839c66611b61526c",
         "SArena_MINI": "c87fa82819a5fce652df40f6332266ff"
     }

@@ -44,4 +45,4 @@ def build_prompt(self, line):
         return msgs

     def evaluate(self, eval_file, **judge_kwargs):
-        return evaluate_sarena_mini(eval_file)
+        return evaluate_sarena(eval_file, dataset=self.dataset)
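
Merging the two splits into one class keeps the split name on self.dataset, so the shared evaluator can branch on it. A minimal sketch of the resulting dispatch (the prediction filename below is illustrative):

    # Hypothetical usage sketch; ImageBaseDataset takes the split name.
    ds = SArena('SArena_MINI')                      # or SArena('SArena')
    scores = ds.evaluate('GPT4o_SArena_MINI.xlsx')  # delegates to evaluate_sarena(..., dataset='SArena_MINI')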

vlmeval/dataset/utils/SArena/CLIP_Score.py

Lines changed: 5 additions & 3 deletions
@@ -1,11 +1,13 @@
+from typing import Literal
+
 import torch
-from tqdm import tqdm
 from torch.utils.data import DataLoader
-from torchmetrics.multimodal.clip_score import CLIPScore
 from torchmetrics.functional.multimodal.clip_score import _clip_score_update
+from torchmetrics.multimodal.clip_score import CLIPScore
 from torchvision.transforms import ToTensor
+from tqdm import tqdm
+
 from .base_metric import BaseMetric
-from typing import Literal


 class CLIPScoreCalculator(BaseMetric):
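
For context on what CLIPScoreCalculator wraps: torchmetrics' CLIPScore embeds an image and a caption with CLIP and reports their scaled cosine similarity. A minimal sketch using the library's documented default checkpoint (the image tensor and caption are placeholders):

    import torch
    from torchmetrics.multimodal.clip_score import CLIPScore

    metric = CLIPScore(model_name_or_path="openai/clip-vit-base-patch16")
    image = torch.randint(0, 255, (3, 224, 224), dtype=torch.uint8)  # dummy image
    score = metric(image, "an SVG icon of a bicycle")                # placeholder caption
    print(float(score))  # similarity scaled to roughly [0, 100]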

vlmeval/dataset/utils/SArena/DINO_Score.py

Lines changed: 12 additions & 8 deletions
@@ -16,15 +16,19 @@ def __init__(self):
         self.metric = self.calculate_DINOv2_similarity_score

     def get_DINOv2_model(self, model_size):
-        if model_size == "small":
-            model_size = "facebook/dinov2-small"
-        elif model_size == "base":
-            model_size = "facebook/dinov2-base"
-        elif model_size == "large":
-            model_size = "facebook/dinov2-large"
-        else:
+        model_map = {
+            "small": "facebook/dinov2-small",
+            "base": "facebook/dinov2-base",
+            "large": "facebook/dinov2-large",
+        }
+        name = model_map.get(model_size)
+        if not name:
             raise ValueError(f"model_size should be either 'small', 'base' or 'large', got {model_size}")
-        return AutoModel.from_pretrained(model_size), AutoImageProcessor.from_pretrained(model_size)
+
+        model = AutoModel.from_pretrained(name)
+        processor = AutoImageProcessor.from_pretrained(name)
+
+        return model, processor

     def process_input(self, image, processor):
         if isinstance(image, str):
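
get_DINOv2_model now returns a (model, processor) pair via a lookup table instead of an if/elif chain. For reference, a common way such a pair is used for image similarity is comparing CLS-token embeddings; this is a hedged sketch, not necessarily the exact computation in calculate_DINOv2_similarity_score:

    import torch
    from PIL import Image
    from transformers import AutoImageProcessor, AutoModel

    processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")
    model = AutoModel.from_pretrained("facebook/dinov2-base")

    def embed(img):
        inputs = processor(images=img, return_tensors="pt")
        with torch.no_grad():
            out = model(**inputs)
        return out.last_hidden_state[:, 0]  # CLS token embedding

    # 'a.png' / 'b.png' are placeholder paths
    sim = torch.nn.functional.cosine_similarity(
        embed(Image.open("a.png")), embed(Image.open("b.png"))
    )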

vlmeval/dataset/utils/SArena/LPIPS.py

Lines changed: 25 additions & 1 deletion
@@ -1,18 +1,42 @@
+import os
+import shutil
 import torch
 import lpips

 from tqdm import tqdm
+from vlmeval.smp.file import LMUDataRoot
 from torch.utils.data import DataLoader
 from torchvision.transforms import ToTensor, Normalize
 from .base_metric import BaseMetric


+def get_lpips_vgg_model(device):
+    """Load LPIPS VGG model, downloading to aux_models if needed."""
+    vgg_path = os.path.join(LMUDataRoot(), 'aux_models', 'vgg.pth')
+
+    if os.path.exists(vgg_path):
+        return lpips.LPIPS(net='vgg', model_path=vgg_path).to(device)
+
+    # Download model (lpips uses torch hub cache)
+    model = lpips.LPIPS(net='vgg').to(device)
+
+    # Copy from torch hub cache to aux_models for future offline use
+    aux_models_dir = os.path.dirname(vgg_path)
+    os.makedirs(aux_models_dir, exist_ok=True)
+
+    cache_path = os.path.expanduser('~/.cache/torch/hub/checkpoints/vgg_net_g.pth')
+    if os.path.exists(cache_path):
+        shutil.copy(cache_path, vgg_path)
+
+    return model
+
+
 class LPIPSCalculator(BaseMetric):
     def __init__(self):
         super().__init__()
         self.class_name = self.__class__.__name__
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
-        self.model = lpips.LPIPS(net='vgg').to(self.device)
+        self.model = get_lpips_vgg_model(self.device)
         self.metric = self.LPIPS
         self.to_tensor = ToTensor()
         self.normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
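
The new helper only changes where the VGG weights come from; scoring is unchanged. For reference, the lpips package's documented usage expects image tensors scaled to [-1, 1]; a minimal sketch with dummy inputs:

    import torch
    import lpips

    loss_fn = lpips.LPIPS(net='vgg')
    img0 = torch.rand(1, 3, 256, 256) * 2 - 1  # dummy images in [-1, 1]
    img1 = torch.rand(1, 3, 256, 256) * 2 - 1
    d = loss_fn(img0, img1)  # lower = more perceptually similar
    print(float(d))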

vlmeval/dataset/utils/SArena/inception.py

Lines changed: 12 additions & 1 deletion
@@ -1,9 +1,11 @@
+import os
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torchvision

 from torch.hub import load_state_dict_from_url
+from vlmeval.smp.file import LMUDataRoot

 # Inception weights ported to Pytorch from
 # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz

@@ -213,7 +215,16 @@ def fid_inception_v3():
     inception.Mixed_7b = FIDInceptionE_1(1280)
     inception.Mixed_7c = FIDInceptionE_2(2048)

-    state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True)
+    local_path = os.path.join(LMUDataRoot(), 'aux_models', 'pt_inception-2015-12-05-6726825d.pth')
+    if os.path.exists(local_path):
+        state_dict = torch.load(local_path, map_location='cpu', weights_only=True)
+    else:
+        # Ensure directory exists
+        os.makedirs(os.path.dirname(local_path), exist_ok=True)
+        # Download to aux_models directory
+        state_dict = load_state_dict_from_url(
+            FID_WEIGHTS_URL, progress=True, model_dir=os.path.dirname(local_path)
+        )
     inception.load_state_dict(state_dict)
     return inception
vlmeval/dataset/utils/SArena/metrics.py

Lines changed: 3 additions & 2 deletions
@@ -1,3 +1,4 @@
+import math
 from dataclasses import dataclass
 from typing import Dict, Callable

@@ -57,11 +58,11 @@ def calculate_metrics(self, batch):
             print(f"Calculating {metric_name}...")
             if metric_name in ['FID', 'FID-C']:
                 avg_result = metric.calculate_score(batch)
-                if avg_result is not float("nan"):
+                if not math.isnan(avg_result):
                     avg_results_dict[metric_name] = avg_result
             else:
                 avg_result, values = metric.calculate_score(batch)
-                if avg_result is not float("nan"):
+                if not math.isnan(avg_result):
                     avg_results_dict[metric_name] = avg_result

         return avg_results_dict
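
The original guard was a real bug: "is not" compares object identity, and float("nan") builds a fresh object on every call, so the old check was always True and never filtered anything (NaN also compares unequal to itself, so == would not work either). math.isnan is the correct test:

    import math

    x = float("nan")
    print(x is not float("nan"))  # True  -- two distinct float objects
    print(x == float("nan"))      # False -- NaN is unequal to everything, even NaN
    print(not math.isnan(x))      # False -- the check the code actually needs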
Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
+av
+cairosvg
+cd-fvd
+evaluate==0.4.3
+ftfy
+hpsv2x
+lpips
+matplotlib
+moviepy
+nest_asyncio
+pyppeteer
+regex
+rich==13.9.4
+scikit-image
+svgpathtools==1.6.1
+timm==1.0.15
+torchmetrics
+vtracer

vlmeval/dataset/utils/SArena/token_length.py

Lines changed: 5 additions & 3 deletions
@@ -1,11 +1,13 @@
-import torch
+import os

-from tqdm import tqdm
+import torch
 from torch.utils.data import DataLoader
+from tqdm import tqdm
 from transformers import AutoTokenizer

-from .base_metric import BaseMetric
+from vlmeval.smp.file import LMUDataRoot
 from .average_meter import AverageMeter
+from .base_metric import BaseMetric


 class TokenLengthCalculator(BaseMetric):
