Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 28 additions & 6 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ on:
- main

jobs:
sdk-test:
lint:
runs-on: ubuntu-latest

steps:
Expand All @@ -22,10 +22,6 @@ jobs:
cache: pip
python-version: "3.10.11"

- name: Install dependencies
run: |
pip install -r requirements.txt

- name: Install Python tools
run: |
pip install black==22.10.0 flake8==5.0.4 isort==5.11.5
Expand All @@ -37,4 +33,30 @@ jobs:
run: flake8 .

- name: Run isort
run: isort --check .
run: isort --check .

test:
runs-on: ubuntu-latest

strategy:
fail-fast: false
matrix:
python-version: ["3.9", "3.10", "3.11", "3.12"]

Comment on lines +44 to +45
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

3.13/3.14 は Pillow のバージョンが低すぎてインストールできずにエラーとなるため、このPRには含めない

steps:
- name: Checkout
uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4

- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # v5
with:
cache: pip
python-version: ${{ matrix.python-version }}

- name: Install dependencies
run: |
pip install -r requirements.txt
pip install -e ".[dev]"

- name: Run pytest
run: pytest
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@ dependencies = [
"geojson>=2.0.0,<4.0",
"xmltodict==0.12.0",
"Pillow>=10.0.0,<11.0.0",
"opencv-python>=4.0.0,<5.0.0",
"aiohttp>=3.8.5"
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

aiohttp は使わなくなっていたので削除しています。
aiohttp の非同期処理は colab で使う際にエラーがでるため、過去に使わなくなったと思われます。

"opencv-python>=4.10.0,<5.0.0"
]

dynamic = ["version"]

[project.optional-dependencies]
robotics = ["pandas>=2.0.0", "pyarrow>=14.0.0"]
dev = ["pytest>=7.0.0"]

[tool.setuptools]
include-package-data = true
Expand Down
3 changes: 1 addition & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,4 @@ numpy>=1.26.0,<2.0.0
geojson>=2.0.0,<4.0
xmltodict==0.12.0
Pillow>=10.0.0,<11.0.0
opencv-python>=4.0.0,<5.0.0
aiohttp>=3.8.5
opencv-python>=4.10.0,<5.0.0
50 changes: 50 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from pathlib import Path

import cv2
import numpy as np
import pytest


def _write_synthetic_video(
path: Path,
num_frames: int = 10,
width: int = 64,
height: int = 48,
fps: int = 10,
fourcc_code: str = "mp4v",
) -> Path:
fourcc = cv2.VideoWriter_fourcc(*fourcc_code)
writer = cv2.VideoWriter(str(path), fourcc, fps, (width, height))
if not writer.isOpened():
pytest.skip(f"cv2.VideoWriter could not open {path} with codec {fourcc_code}")
try:
for i in range(num_frames):
frame = np.full((height, width, 3), i * 20 % 255, dtype=np.uint8)
writer.write(frame)
finally:
writer.release()
if not path.exists() or path.stat().st_size == 0:
pytest.skip("Synthetic video could not be created (codec unavailable).")
return path


@pytest.fixture
def synthetic_video(tmp_path):
def _factory(
name: str = "video.mp4",
num_frames: int = 10,
width: int = 64,
height: int = 48,
fps: int = 10,
fourcc_code: str = "mp4v",
) -> Path:
return _write_synthetic_video(
tmp_path / name,
num_frames=num_frames,
width=width,
height=height,
fps=fps,
fourcc_code=fourcc_code,
)

return _factory
115 changes: 115 additions & 0 deletions tests/test_converters_video.py
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

converters.py の opencv 関連のロジックのテスト

Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
import os

import cv2
import pytest

from fastlabel import converters
from fastlabel.exceptions import FastLabelInvalidException


class TestVideoCapture:
def test_yields_open_capture_and_releases(self, synthetic_video):
video_path = synthetic_video(name="sample.mp4", num_frames=5)

with converters.VideoCapture(str(video_path)) as cap:
assert cap.isOpened()
ret, frame = cap.read()
assert ret is True
assert frame is not None

def test_releases_capture_on_exit(self, synthetic_video):
video_path = synthetic_video(name="sample.mp4")

with converters.VideoCapture(str(video_path)) as cap:
captured = cap

# After release, reading should fail or return falsy ret
ret, _ = captured.read()
assert ret is False

def test_releases_capture_on_exception(self, synthetic_video):
video_path = synthetic_video(name="sample.mp4")

captured = None
with pytest.raises(RuntimeError):
with converters.VideoCapture(str(video_path)) as cap:
captured = cap
raise RuntimeError("boom")

ret, _ = captured.read()
assert ret is False


class TestExportImageFilesForVideoFile:
def test_writes_one_jpg_per_frame(self, synthetic_video, tmp_path):
num_frames = 7
video_path = synthetic_video(name="sample.mp4", num_frames=num_frames)
output_dir = tmp_path / "frames"

names = converters._export_image_files_for_video_file(
file_path=str(video_path),
output_dir_path=str(output_dir),
basename="sample",
)

assert len(names) == num_frames
for name in names:
assert name.endswith(".jpg")
assert name.startswith("sample_")
assert (output_dir / name).is_file()

def test_zero_padding_matches_total_frame_digits(self, synthetic_video, tmp_path):
num_frames = 12
video_path = synthetic_video(name="sample.mp4", num_frames=num_frames)
output_dir = tmp_path / "frames"

names = converters._export_image_files_for_video_file(
file_path=str(video_path),
output_dir_path=str(output_dir),
basename="vid",
)

# 12 frames -> 2 digit zero padding ("00".."11")
assert names[0] == "vid_00.jpg"
assert names[-1] == f"vid_{num_frames - 1:02d}.jpg"

def test_written_frames_are_readable_images(self, synthetic_video, tmp_path):
video_path = synthetic_video(
name="sample.mp4", num_frames=3, width=64, height=48
)
output_dir = tmp_path / "frames"

names = converters._export_image_files_for_video_file(
file_path=str(video_path),
output_dir_path=str(output_dir),
basename="frame",
)

for name in names:
img = cv2.imread(str(output_dir / name))
assert img is not None
assert img.shape == (48, 64, 3)

def test_unopenable_file_raises(self, tmp_path):
bogus = tmp_path / "not_a_video.mp4"
bogus.write_bytes(b"not a real video")

with pytest.raises(FastLabelInvalidException):
converters._export_image_files_for_video_file(
file_path=str(bogus),
output_dir_path=str(tmp_path / "frames"),
basename="x",
)

def test_creates_output_directory_if_missing(self, synthetic_video, tmp_path):
video_path = synthetic_video(name="sample.mp4", num_frames=2)
output_dir = tmp_path / "does" / "not" / "exist"

assert not output_dir.exists()
converters._export_image_files_for_video_file(
file_path=str(video_path),
output_dir_path=str(output_dir),
basename="frame",
)
assert output_dir.is_dir()
assert len(os.listdir(output_dir)) == 2
61 changes: 61 additions & 0 deletions tests/test_lerobot_v3_video.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import cv2
import pytest

from fastlabel.exceptions import FastLabelInvalidException
from fastlabel.lerobot import v3


class TestExtractVideoSegment:
def test_extracts_requested_number_of_frames(self, synthetic_video, tmp_path):
source = synthetic_video(name="src.mp4", num_frames=20, width=64, height=48)
output = tmp_path / "segment.mp4"

v3._extract_video_segment(
video_path=source,
start_frame=5,
num_frames=8,
output_path=output,
)

assert output.is_file()
cap = cv2.VideoCapture(str(output))
try:
count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
finally:
cap.release()

assert count == 8
assert (width, height) == (64, 48)

def test_stops_when_source_ends(self, synthetic_video, tmp_path):
source = synthetic_video(name="src.mp4", num_frames=10)
output = tmp_path / "segment.mp4"

v3._extract_video_segment(
video_path=source,
start_frame=8,
num_frames=50,
output_path=output,
)

cap = cv2.VideoCapture(str(output))
try:
count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
finally:
cap.release()

assert count == 2

def test_unopenable_file_raises(self, tmp_path):
bogus = tmp_path / "not_a_video.mp4"
bogus.write_bytes(b"garbage")

with pytest.raises(FastLabelInvalidException):
v3._extract_video_segment(
video_path=bogus,
start_frame=0,
num_frames=1,
output_path=tmp_path / "out.mp4",
)
Loading
Loading