# Speckle-Scanner/04_Rectification/rectificationclasses/rectification.py
from pathlib import Path
import re
import shutil
from typing import Dict, List, Optional, Tuple
import cv2
import numpy as np
from tqdm import tqdm
# Lower-case file suffixes (dot included) accepted as scan images.
VALID_EXTS = {
    ".bmp",
    ".jpeg",
    ".jpg",
    ".png",
}
class Rectification:
    """Batch stereo rectification for one project/date tree.

    Reads source scans from the RAW data tree, copies each scan's raw images
    into the processing tree, and rectifies lc-rc / lc-rg / lc-ir pairs with
    pair-specific calibration parameters.

    Directory layout used by this class::

        <source_date_root>/session*/scan*/01_raw_images/          (input)
        <processing_date_root>/<session>/params_link/             (calibration copies)
        <processing_date_root>/<session>/<scan>/01_raw_images/    (raw copies)
        <processing_date_root>/<session>/<scan>/02_rect_images/   (rectified output)

    Calibration for pair ``P`` is read from ``<calib_params_dir>/P_parameters.npz``.
    """

    # Arrays every ``<pair>_parameters.npz`` file must contain.
    _REQUIRED_PARAM_KEYS: Tuple[str, ...] = (
        "L_Intrinsic",
        "L_Distortion",
        "R_Intrinsic",
        "R_Distortion",
        "Rotation",
        "Translation",
    )
    # Calibration file suffixes mirrored into each session's params_link folder.
    _PARAM_FILE_SUFFIXES: Tuple[str, ...] = (".npz", ".yaml", ".cvstore")

    def __init__(
        self,
        source_date_root: str,
        calib_params_dir: str,
        processing_date_root: str,
        pairs: Tuple[str, ...] = ("lc-rc", "lc-rg", "lc-ir"),
        keep_lc_from_pair: str = "lc-rc",
        session_filter: Optional[str] = None,
    ) -> None:
        """Validate the input directories and eagerly load every pair's calibration.

        Args:
            source_date_root: Directory containing ``session*`` folders of raw scans.
            calib_params_dir: Directory containing ``<pair>_parameters.npz`` files.
            processing_date_root: Output root; created (with parents) if missing.
            pairs: Stereo pair names of the form ``"lc-<camera>"``.
            keep_lc_from_pair: Pair whose rectified ``lc`` images are saved; it
                is processed first. Another pair saves ``lc`` images only if no
                ``lc`` image has been written yet for the scan.
            session_filter: If given, only the session folder with exactly this
                name is processed.

        Raises:
            FileNotFoundError: If an input directory or a pair's params file is missing.
            KeyError: If a params file lacks one of the required arrays.
        """
        self.source_date_root = Path(source_date_root)
        self.calib_params_dir = Path(calib_params_dir)
        self.processing_date_root = Path(processing_date_root)
        self.pairs = pairs
        self.keep_lc_from_pair = keep_lc_from_pair
        self.session_filter = session_filter

        if not self.source_date_root.is_dir():
            raise FileNotFoundError(f"Source date root not found: {self.source_date_root}")
        if not self.calib_params_dir.is_dir():
            raise FileNotFoundError(f"Calibration params dir not found: {self.calib_params_dir}")
        self.processing_date_root.mkdir(parents=True, exist_ok=True)

        self._params_by_pair: Dict[str, Dict[str, np.ndarray]] = {}
        # Keyed by (pair, left_w, left_h, right_w, right_h): right maps depend on
        # the right image size, so it must be part of the key.
        self._rect_maps_cache: Dict[
            Tuple[str, int, int, int, int],
            Tuple[Tuple[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray], np.ndarray],
        ] = {}
        self._load_all_pair_params()

    # ------------------------------------------------------------------
    # Filename key extraction
    # ------------------------------------------------------------------
    @staticmethod
    def _extract_ts_key(filename: str) -> Optional[str]:
        """Return the digits following ``_ts`` in the lower-cased stem, or None."""
        stem = Path(filename).stem.lower()
        m = re.search(r"_ts(\d+)", stem)
        return m.group(1) if m else None

    @staticmethod
    def _extract_scan_key(filename: str) -> Optional[str]:
        """Return a normalised ``scanNNNNNN`` key from a filename, or None.

        Tried in order on the lower-cased stem:
        1. any ``scan`` + 6 digits substring, returned as-is;
        2. an ``ir_scan_<digits>`` prefix, zero-padded to at least 6 digits;
        3. an ``ir_`` + 6 digits prefix followed by ``_`` or end of stem.
        """
        stem = Path(filename).stem.lower()
        m = re.search(r"(scan\d{6})", stem)
        if m:
            return m.group(1)
        m = re.match(r"^ir_scan_(\d+)", stem)
        if m:
            return f"scan{int(m.group(1)):06d}"
        m = re.match(r"^ir_(\d{6})(?:_|$)", stem)
        if m:
            return f"scan{m.group(1)}"
        return None

    @staticmethod
    def _extract_generic_suffix_key(filename: str, prefix: str) -> Optional[str]:
        """Return the stem remainder after ``prefix`` with leading ``_``/``-``/``.`` stripped.

        Matching is case-insensitive (stem and prefix are both lower-cased, as in
        ``_list_images``). Returns None when the stem does not start with
        ``prefix``; returns ``""`` when nothing follows the prefix.
        """
        stem = Path(filename).stem.lower()
        prefix = prefix.lower()
        if not stem.startswith(prefix):
            return None
        return stem[len(prefix):].lstrip("_-.")

    @staticmethod
    def _camera_from_pair(pair_name: str) -> str:
        """Return the part after the first ``-`` (``"lc-ir"`` -> ``"ir"``).

        Raises IndexError if ``pair_name`` contains no ``-``.
        """
        return pair_name.split("-", 1)[1]

    # ------------------------------------------------------------------
    # Calibration and file copying
    # ------------------------------------------------------------------
    def _load_pair_params(self, pair_name: str) -> Dict[str, np.ndarray]:
        """Load and validate ``<calib_params_dir>/<pair_name>_parameters.npz``.

        Raises:
            FileNotFoundError: If the file does not exist.
            KeyError: If any of ``_REQUIRED_PARAM_KEYS`` is absent.
        """
        npz_path = self.calib_params_dir / f"{pair_name}_parameters.npz"
        if not npz_path.exists():
            raise FileNotFoundError(f"Missing params file for {pair_name}: {npz_path}")
        # NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which
        # can execute arbitrary code; only load calibration files from trusted sources.
        # The context manager closes the underlying zip handle (previously leaked);
        # dict(...) reads every array into memory before the file is closed.
        with np.load(npz_path, allow_pickle=True) as data:
            params = dict(data)
        missing = [k for k in self._REQUIRED_PARAM_KEYS if k not in params]
        if missing:
            raise KeyError(f"{pair_name} params missing keys: {missing}")
        return params

    def _load_all_pair_params(self) -> None:
        """Populate ``_params_by_pair`` for every configured pair."""
        for pair_name in self.pairs:
            self._params_by_pair[pair_name] = self._load_pair_params(pair_name)
        print(f"[INFO] Loaded calibration params for pairs: {', '.join(self.pairs)}")

    def _copy_params_link_for_session(self, session_name: str) -> None:
        """Copy calibration files (.npz/.yaml/.cvstore) into ``<session>/params_link``."""
        target_params = self.processing_date_root / session_name / "params_link"
        target_params.mkdir(parents=True, exist_ok=True)
        for src in self.calib_params_dir.iterdir():
            if src.is_file() and src.suffix.lower() in self._PARAM_FILE_SUFFIXES:
                shutil.copy2(src, target_params / src.name)

    @staticmethod
    def _copy_raw_images(src_raw_dir: Path, dst_raw_dir: Path) -> None:
        """Copy every regular file (non-recursively) from src to dst, overwriting."""
        dst_raw_dir.mkdir(parents=True, exist_ok=True)
        for src in src_raw_dir.iterdir():
            if src.is_file():
                shutil.copy2(src, dst_raw_dir / src.name)

    @staticmethod
    def _list_images(raw_dir: Path, prefix: str) -> List[Path]:
        """Return sorted image files in ``raw_dir`` whose name starts with ``prefix``.

        Both the prefix and the extension check are case-insensitive.
        """
        prefix = prefix.lower()
        return sorted(
            p for p in raw_dir.iterdir()
            if p.is_file()
            and p.suffix.lower() in VALID_EXTS
            and p.name.lower().startswith(prefix)
        )

    # ------------------------------------------------------------------
    # Pairing and rectification
    # ------------------------------------------------------------------
    @staticmethod
    def _index_by_key(images: List[Path], key_fn) -> Dict[str, Path]:
        """Map ``key_fn(path.name)`` -> path, skipping falsy keys.

        ``key_fn`` takes a filename and returns an optional string key. On
        duplicate keys the later image in ``images`` wins.
        """
        index: Dict[str, Path] = {}
        for path in images:
            key = key_fn(path.name)  # computed once per image
            if key:
                index[key] = path
        return index

    def _pair_images(
        self,
        left_images: List[Path],
        right_images: List[Path],
        right_camera: str,
    ) -> List[Tuple[Path, Path]]:
        """Match left (``lc``) images to right-camera images.

        Strategies are tried in order and the first one yielding any match wins:
        ``_ts`` timestamp key, ``scanNNNNNN`` key, then the filename remainder
        after the camera prefix. Matches are returned sorted by key. If none
        match, images are paired by sorted index (with a warning), truncated to
        the shorter list.
        """
        key_strategies = (
            (self._extract_ts_key, self._extract_ts_key),
            (self._extract_scan_key, self._extract_scan_key),
            (
                lambda name: self._extract_generic_suffix_key(name, "lc"),
                lambda name: self._extract_generic_suffix_key(name, right_camera),
            ),
        )
        for left_key_fn, right_key_fn in key_strategies:
            left_by_key = self._index_by_key(left_images, left_key_fn)
            right_by_key = self._index_by_key(right_images, right_key_fn)
            common = sorted(left_by_key.keys() & right_by_key.keys())
            if common:
                return [(left_by_key[k], right_by_key[k]) for k in common]

        fallback_count = min(len(left_images), len(right_images))
        if fallback_count > 0:
            print(
                f"[WARN] No key match for lc-{right_camera}; "
                f"using index fallback with {fallback_count} pairs."
            )
            return list(zip(left_images[:fallback_count], right_images[:fallback_count]))
        return []

    def _get_rectification_maps(
        self,
        pair_name: str,
        left_size: Tuple[int, int],
        right_size: Tuple[int, int],
    ) -> Tuple[Tuple[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray], np.ndarray]:
        """Return ``(left_maps, right_maps, Q)`` for a pair, computing and caching on miss.

        Sizes are ``(width, height)``. ``left_maps``/``right_maps`` are the
        ``(map1, map2)`` float32 maps from ``cv2.initUndistortRectifyMap``; ``Q``
        is the disparity-to-depth matrix from ``cv2.stereoRectify``.
        """
        # Both sizes are in the key: right maps are generated at right_size.
        cache_key = (pair_name, left_size[0], left_size[1], right_size[0], right_size[1])
        cached = self._rect_maps_cache.get(cache_key)
        if cached is not None:
            return cached

        params = self._params_by_pair[pair_name]
        # NOTE(review): stereoRectify is given only left_size, so P2 targets the left
        # image frame even when the right camera has a different resolution — confirm
        # this is intended for the lc-ir pair.
        rect_left, rect_right, proj_left, proj_right, q_mat, _, _ = cv2.stereoRectify(
            params["L_Intrinsic"],
            params["L_Distortion"],
            params["R_Intrinsic"],
            params["R_Distortion"],
            left_size,
            params["Rotation"],
            params["Translation"],
            alpha=1,
            flags=0,
        )
        left_maps = cv2.initUndistortRectifyMap(
            params["L_Intrinsic"],
            params["L_Distortion"],
            rect_left,
            proj_left,
            left_size,
            cv2.CV_32FC1,
        )
        right_maps = cv2.initUndistortRectifyMap(
            params["R_Intrinsic"],
            params["R_Distortion"],
            rect_right,
            proj_right,
            right_size,
            cv2.CV_32FC1,
        )
        result = (left_maps, right_maps, q_mat)
        self._rect_maps_cache[cache_key] = result
        return result

    def _rectify_pair_image(
        self,
        pair_name: str,
        left_img: np.ndarray,
        right_img: np.ndarray,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Remap a left/right image pair into the pair's rectified frame."""
        left_size = (left_img.shape[1], left_img.shape[0])
        right_size = (right_img.shape[1], right_img.shape[0])
        left_maps, right_maps, _ = self._get_rectification_maps(pair_name, left_size, right_size)
        # OpenCV documents INTER_AREA as unsupported by remap (it is mapped to
        # INTER_LINEAR internally), so request bilinear interpolation explicitly.
        left_rect = cv2.remap(left_img, left_maps[0], left_maps[1], cv2.INTER_LINEAR)
        right_rect = cv2.remap(right_img, right_maps[0], right_maps[1], cv2.INTER_LINEAR)
        return left_rect, right_rect

    # ------------------------------------------------------------------
    # Per-scan and batch processing
    # ------------------------------------------------------------------
    def _ordered_pair_names(self) -> List[str]:
        """Return ``self.pairs`` with ``keep_lc_from_pair`` moved to the front, if present."""
        ordered = list(self.pairs)
        if self.keep_lc_from_pair in ordered:
            ordered.remove(self.keep_lc_from_pair)
            ordered.insert(0, self.keep_lc_from_pair)
        return ordered

    def _process_scan(self, session_name: str, scan_name: str) -> Dict[str, int]:
        """Copy one scan's raw images and write its rectified images.

        Returns a stats dict:
            pairs_total: matched image pairs across all camera pairs.
            saved: image pairs whose outputs were all written successfully.
            skipped: camera pairs with no images/matches, plus image pairs that
                could not be read or written.
        """
        src_raw_dir = self.source_date_root / session_name / scan_name / "01_raw_images"
        dst_scan_dir = self.processing_date_root / session_name / scan_name
        dst_raw_dir = dst_scan_dir / "01_raw_images"
        dst_rect_dir = dst_scan_dir / "02_rect_images"
        dst_rect_dir.mkdir(parents=True, exist_ok=True)
        self._copy_raw_images(src_raw_dir, dst_raw_dir)

        stats = {"pairs_total": 0, "saved": 0, "skipped": 0}
        lc_written = False
        # The lc listing does not change while pairs are processed (outputs go
        # to 02_rect_images), so list it once.
        left_images = self._list_images(dst_raw_dir, "lc")

        for pair_name in self._ordered_pair_names():
            right_camera = self._camera_from_pair(pair_name)
            right_images = self._list_images(dst_raw_dir, right_camera)
            if not left_images or not right_images:
                stats["skipped"] += 1
                print(
                    f"[WARN] {session_name}/{scan_name} {pair_name}: "
                    f"missing images (lc={len(left_images)}, {right_camera}={len(right_images)})."
                )
                continue

            pairs = self._pair_images(left_images, right_images, right_camera)
            if not pairs:
                stats["skipped"] += 1
                print(f"[WARN] {session_name}/{scan_name} {pair_name}: no valid pairs.")
                continue

            # Rectified lc images come from the preferred pair; any other pair
            # supplies them only if nothing has been written yet (fallback).
            save_lc_this_pair = pair_name == self.keep_lc_from_pair or not lc_written
            stats["pairs_total"] += len(pairs)

            for left_path, right_path in tqdm(
                pairs,
                desc=f"{session_name}/{scan_name} {pair_name}",
                unit="pair",
                leave=False,
            ):
                left_img = cv2.imread(str(left_path), cv2.IMREAD_COLOR)
                right_img = cv2.imread(str(right_path), cv2.IMREAD_COLOR)
                if left_img is None or right_img is None:
                    stats["skipped"] += 1
                    continue

                left_rect, right_rect = self._rectify_pair_image(pair_name, left_img, right_img)
                all_written = True
                if save_lc_this_pair:
                    left_out = dst_rect_dir / left_path.name
                    if cv2.imwrite(str(left_out), left_rect):
                        lc_written = True
                    else:
                        all_written = False
                        print(f"[WARN] Failed to write {left_out}")
                right_out = dst_rect_dir / right_path.name
                if not cv2.imwrite(str(right_out), right_rect):
                    all_written = False
                    print(f"[WARN] Failed to write {right_out}")
                stats["saved" if all_written else "skipped"] += 1
        return stats

    def _discover_session_scan_raw_dirs(self) -> List[Tuple[str, str]]:
        """List ``(session, scan)`` names that have a ``01_raw_images`` directory.

        Sessions are folders whose name starts with ``session`` and scans are
        folders starting with ``scan`` (both case-insensitive), visited in sorted
        order. ``session_filter``, if set, must match the session name exactly.
        """
        found: List[Tuple[str, str]] = []
        session_dirs = sorted(
            p for p in self.source_date_root.iterdir()
            if p.is_dir() and p.name.lower().startswith("session")
        )
        for session_dir in session_dirs:
            if self.session_filter and session_dir.name != self.session_filter:
                continue
            scan_dirs = sorted(
                p for p in session_dir.iterdir()
                if p.is_dir() and p.name.lower().startswith("scan")
            )
            for scan_dir in scan_dirs:
                if (scan_dir / "01_raw_images").is_dir():
                    found.append((session_dir.name, scan_dir.name))
        return found

    def run_batch(self) -> Dict[str, int]:
        """Rectify every discovered scan and return aggregated stats.

        Returns a dict with ``scans`` plus the summed ``pairs_total``, ``saved``
        and ``skipped`` counters from ``_process_scan``.

        Raises:
            RuntimeError: If no scan folders are found.
        """
        all_scans = self._discover_session_scan_raw_dirs()
        if not all_scans:
            raise RuntimeError(f"No scan folders found under {self.source_date_root}")
        print(f"[INFO] Found {len(all_scans)} scans under {self.source_date_root}")

        totals = {"scans": 0, "pairs_total": 0, "saved": 0, "skipped": 0}
        sessions_seen = set()
        for session_name, scan_name in all_scans:
            if session_name not in sessions_seen:
                # Mirror calibration files once per session, before its first scan.
                self._copy_params_link_for_session(session_name)
                sessions_seen.add(session_name)
            scan_stats = self._process_scan(session_name, scan_name)
            totals["scans"] += 1
            for key in ("pairs_total", "saved", "skipped"):
                totals[key] += scan_stats[key]

        print(
            "[INFO] Batch rectification finished: "
            f"scans={totals['scans']} pairs={totals['pairs_total']} "
            f"saved={totals['saved']} skipped={totals['skipped']}"
        )
        return totals