"""Batch stereo rectification of scan images for one project/date tree."""

import re
import shutil
from pathlib import Path
from typing import Callable, Dict, List, Optional, Tuple

import cv2
import numpy as np
from tqdm import tqdm

VALID_EXTS = {".bmp", ".png", ".jpg", ".jpeg"}

# Filename patterns used to derive pairing keys; compiled once at import time.
_TS_RE = re.compile(r"_ts(\d+)")
_SCAN_RE = re.compile(r"(scan\d{6})")
_IR_SCAN_RE = re.compile(r"^ir_scan_(\d+)")
_IR_NUM_RE = re.compile(r"^ir_(\d{6})(?:_|$)")

# (map_x, map_y) as returned by cv2.initUndistortRectifyMap.
_MapPair = Tuple[np.ndarray, np.ndarray]
# (left maps, right maps, Q matrix from cv2.stereoRectify).
_RectMaps = Tuple[_MapPair, _MapPair, np.ndarray]


class Rectification:
    """Batch rectification for one project/date tree.

    Reads source scans from the RAW data tree, copies them into the
    processing tree, and rectifies lc-rc/lc-rg/lc-ir pairs with
    pair-specific calibration params.

    Layout read from ``source_date_root``::

        <session*>/<scan*>/01_raw_images/<image files>

    Layout written under ``processing_date_root``::

        <session>/params_link/            copies of calibration files
        <session>/<scan>/01_raw_images/   copies of the raw files
        <session>/<scan>/02_rect_images/  rectified images
    """

    def __init__(
        self,
        source_date_root: str,
        calib_params_dir: str,
        processing_date_root: str,
        pairs: Tuple[str, ...] = ("lc-rc", "lc-rg", "lc-ir"),
        keep_lc_from_pair: str = "lc-rc",
        session_filter: Optional[str] = None,
    ) -> None:
        """Validate the input directories and load calibration for every pair.

        Args:
            source_date_root: Directory holding the ``session*`` folders.
            calib_params_dir: Directory holding ``<pair>_parameters.npz``.
            processing_date_root: Output root; created if it does not exist.
            pairs: Camera pair names of the form ``"<left>-<right>"``.
            keep_lc_from_pair: Pair that is processed first and whose
                rectified left image is always written.
            session_filter: If set, only the session directory with exactly
                this name is processed.

        Raises:
            FileNotFoundError: If a source/calibration directory or a pair's
                params file is missing.
            KeyError: If a params file lacks a required array.
        """
        self.source_date_root = Path(source_date_root)
        self.calib_params_dir = Path(calib_params_dir)
        self.processing_date_root = Path(processing_date_root)
        self.pairs = pairs
        self.keep_lc_from_pair = keep_lc_from_pair
        self.session_filter = session_filter

        if not self.source_date_root.is_dir():
            raise FileNotFoundError(f"Source date root not found: {self.source_date_root}")
        if not self.calib_params_dir.is_dir():
            raise FileNotFoundError(f"Calibration params dir not found: {self.calib_params_dir}")

        self.processing_date_root.mkdir(parents=True, exist_ok=True)

        self._params_by_pair: Dict[str, Dict[str, np.ndarray]] = {}
        # Keyed by (pair, left_w, left_h, right_w, right_h): the right-hand
        # maps depend on the right image size, so it must be part of the key.
        self._rect_maps_cache: Dict[Tuple[str, int, int, int, int], _RectMaps] = {}
        self._load_all_pair_params()

    @staticmethod
    def _extract_ts_key(filename: str) -> Optional[str]:
        """Return the digits following ``_ts`` in the file stem, or None."""
        stem = Path(filename).stem.lower()
        m = _TS_RE.search(stem)
        return m.group(1) if m else None

    @staticmethod
    def _extract_scan_key(filename: str) -> Optional[str]:
        """Return a normalised ``scanNNNNNN`` key from the file stem, or None.

        Recognised forms, tried in order: ``scan`` followed by six digits
        anywhere in the stem, a leading ``ir_scan_<digits>`` (zero-padded to
        six digits), and a leading ``ir_<six digits>``.
        """
        stem = Path(filename).stem.lower()
        m = _SCAN_RE.search(stem)
        if m:
            return m.group(1)
        m = _IR_SCAN_RE.match(stem)
        if m:
            return f"scan{int(m.group(1)):06d}"
        m = _IR_NUM_RE.match(stem)
        if m:
            return f"scan{m.group(1)}"
        return None

    @staticmethod
    def _extract_generic_suffix_key(filename: str, prefix: str) -> Optional[str]:
        """Return the stem with ``prefix`` and leading ``_-.`` stripped.

        The comparison is case-insensitive on both the stem and the prefix.
        Returns None when the stem does not start with ``prefix``; returns an
        empty string when nothing follows the prefix.
        """
        stem = Path(filename).stem.lower()
        # Lower-case the prefix too, so an upper-case camera name still matches
        # the lower-cased stem (same convention as _list_images).
        prefix = prefix.lower()
        if not stem.startswith(prefix):
            return None
        return stem[len(prefix):].lstrip("_-.")

    @staticmethod
    def _camera_from_pair(pair_name: str) -> str:
        """Return the part of ``pair_name`` after the first ``-``.

        Raises:
            IndexError: If ``pair_name`` contains no ``-``.
        """
        return pair_name.split("-", 1)[1]

    def _load_pair_params(self, pair_name: str) -> Dict[str, np.ndarray]:
        """Load ``<pair_name>_parameters.npz`` from the calibration directory.

        Raises:
            FileNotFoundError: If the params file does not exist.
            KeyError: If any required array is absent from the file.
        """
        npz_path = self.calib_params_dir / f"{pair_name}_parameters.npz"
        if not npz_path.exists():
            raise FileNotFoundError(f"Missing params file for {pair_name}: {npz_path}")

        # NOTE(security): allow_pickle=True lets NumPy unpickle object arrays,
        # which can execute arbitrary code. Only load calibration files from a
        # trusted source; drop the flag if the files hold plain numeric arrays.
        # The context manager closes the underlying archive once every array
        # has been read into memory.
        with np.load(npz_path, allow_pickle=True) as data:
            params = {key: data[key] for key in data.files}

        required = [
            "L_Intrinsic",
            "L_Distortion",
            "R_Intrinsic",
            "R_Distortion",
            "Rotation",
            "Translation",
        ]
        missing = [k for k in required if k not in params]
        if missing:
            raise KeyError(f"{pair_name} params missing keys: {missing}")
        return params

    def _load_all_pair_params(self) -> None:
        """Load calibration params for every configured pair."""
        for pair_name in self.pairs:
            self._params_by_pair[pair_name] = self._load_pair_params(pair_name)
        print(f"[INFO] Loaded calibration params for pairs: {', '.join(self.pairs)}")

    def _copy_params_link_for_session(self, session_name: str) -> None:
        """Copy calibration files into ``<session>/params_link``.

        Only regular files with a ``.npz``, ``.yaml`` or ``.cvstore`` suffix
        are copied.
        """
        target_params = self.processing_date_root / session_name / "params_link"
        target_params.mkdir(parents=True, exist_ok=True)
        for src in self.calib_params_dir.iterdir():
            if src.is_file() and src.suffix.lower() in (".npz", ".yaml", ".cvstore"):
                shutil.copy2(src, target_params / src.name)

    @staticmethod
    def _copy_raw_images(src_raw_dir: Path, dst_raw_dir: Path) -> None:
        """Copy every regular file from ``src_raw_dir`` into ``dst_raw_dir``."""
        dst_raw_dir.mkdir(parents=True, exist_ok=True)
        for src in src_raw_dir.iterdir():
            if src.is_file():
                shutil.copy2(src, dst_raw_dir / src.name)

    @staticmethod
    def _list_images(raw_dir: Path, prefix: str) -> List[Path]:
        """Return sorted image files in ``raw_dir`` whose name starts with ``prefix``.

        The prefix test is case-insensitive; only suffixes in ``VALID_EXTS``
        are accepted.
        """
        wanted = prefix.lower()
        imgs = [
            p
            for p in raw_dir.iterdir()
            if p.is_file()
            and p.suffix.lower() in VALID_EXTS
            and p.name.lower().startswith(wanted)
        ]
        imgs.sort()
        return imgs

    @staticmethod
    def _index_by_key(
        images: List[Path],
        key_func: Callable[[str], Optional[str]],
    ) -> Dict[str, Path]:
        """Map each non-empty key to its image; later files win on duplicates."""
        indexed: Dict[str, Path] = {}
        for path in images:
            key = key_func(path.name)
            if key:
                indexed[key] = path
        return indexed

    @staticmethod
    def _match_on_keys(
        left_by_key: Dict[str, Path],
        right_by_key: Dict[str, Path],
    ) -> List[Tuple[Path, Path]]:
        """Return (left, right) tuples for keys present on both sides, sorted by key."""
        shared = sorted(left_by_key.keys() & right_by_key.keys())
        return [(left_by_key[key], right_by_key[key]) for key in shared]

    def _pair_images(
        self,
        left_images: List[Path],
        right_images: List[Path],
        right_camera: str,
    ) -> List[Tuple[Path, Path]]:
        """Match left and right images into (left, right) pairs.

        Strategies are tried in order and the first one that yields at least
        one pair wins: timestamp key (``_ts<digits>``), scan key, then the
        filename suffix after the camera prefix. If none matches, images are
        paired by sorted position and a warning is printed.
        """
        strategies = (
            (self._extract_ts_key, self._extract_ts_key),
            (self._extract_scan_key, self._extract_scan_key),
            (
                lambda name: self._extract_generic_suffix_key(name, "lc"),
                lambda name: self._extract_generic_suffix_key(name, right_camera),
            ),
        )
        for left_key, right_key in strategies:
            matched = self._match_on_keys(
                self._index_by_key(left_images, left_key),
                self._index_by_key(right_images, right_key),
            )
            if matched:
                return matched

        fallback_count = min(len(left_images), len(right_images))
        if fallback_count > 0:
            print(
                f"[WARN] No key match for lc-{right_camera}; "
                f"using index fallback with {fallback_count} pairs."
            )
            return list(zip(left_images[:fallback_count], right_images[:fallback_count]))
        return []

    def _get_rectification_maps(
        self,
        pair_name: str,
        left_size: Tuple[int, int],
        right_size: Tuple[int, int],
    ) -> Tuple[Tuple[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray], np.ndarray]:
        """Return cached ``(left_maps, right_maps, q_mat)`` for a pair.

        Args:
            pair_name: Key into the loaded calibration params.
            left_size: Left image size as ``(width, height)``.
            right_size: Right image size as ``(width, height)``.

        Returns:
            The left and right remap tables from
            ``cv2.initUndistortRectifyMap`` and the Q matrix from
            ``cv2.stereoRectify``.
        """
        cache_key = (pair_name, left_size[0], left_size[1], right_size[0], right_size[1])
        cached = self._rect_maps_cache.get(cache_key)
        if cached is not None:
            return cached

        params = self._params_by_pair[pair_name]
        # stereoRectify takes a single image size; the left size is used.
        rect_left, rect_right, proj_left, proj_right, q_mat, _, _ = cv2.stereoRectify(
            params["L_Intrinsic"],
            params["L_Distortion"],
            params["R_Intrinsic"],
            params["R_Distortion"],
            left_size,
            params["Rotation"],
            params["Translation"],
            alpha=1,
            flags=0,
        )
        left_maps = cv2.initUndistortRectifyMap(
            params["L_Intrinsic"],
            params["L_Distortion"],
            rect_left,
            proj_left,
            left_size,
            cv2.CV_32FC1,
        )
        right_maps = cv2.initUndistortRectifyMap(
            params["R_Intrinsic"],
            params["R_Distortion"],
            rect_right,
            proj_right,
            right_size,
            cv2.CV_32FC1,
        )
        result = (left_maps, right_maps, q_mat)
        self._rect_maps_cache[cache_key] = result
        return result

    def _rectify_pair_image(
        self,
        pair_name: str,
        left_img: np.ndarray,
        right_img: np.ndarray,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Rectify one left/right image pair and return the remapped images."""
        left_size = (left_img.shape[1], left_img.shape[0])
        right_size = (right_img.shape[1], right_img.shape[0])
        left_maps, right_maps, _ = self._get_rectification_maps(pair_name, left_size, right_size)
        # cv2.remap does not support INTER_AREA, so bilinear interpolation is
        # requested explicitly.
        left_rect = cv2.remap(left_img, left_maps[0], left_maps[1], cv2.INTER_LINEAR)
        right_rect = cv2.remap(right_img, right_maps[0], right_maps[1], cv2.INTER_LINEAR)
        return left_rect, right_rect

    def _process_scan(self, session_name: str, scan_name: str) -> Dict[str, int]:
        """Copy one scan's raw images and write rectified images for each pair.

        Returns:
            Counters ``pairs_total`` (image pairs attempted), ``saved``
            (right images written) and ``skipped`` (camera pairs with no
            usable images plus image pairs that could not be read or written).
        """
        src_raw_dir = self.source_date_root / session_name / scan_name / "01_raw_images"
        dst_scan_dir = self.processing_date_root / session_name / scan_name
        dst_raw_dir = dst_scan_dir / "01_raw_images"
        dst_rect_dir = dst_scan_dir / "02_rect_images"
        dst_rect_dir.mkdir(parents=True, exist_ok=True)

        self._copy_raw_images(src_raw_dir, dst_raw_dir)

        stats = {"pairs_total": 0, "saved": 0, "skipped": 0}
        lc_written = False

        # Process the preferred pair first so its rectified left image is the
        # one kept when several pairs could provide it.
        ordered_pairs = list(self.pairs)
        if self.keep_lc_from_pair in ordered_pairs:
            ordered_pairs.remove(self.keep_lc_from_pair)
            ordered_pairs.insert(0, self.keep_lc_from_pair)

        # The raw directory is not modified below, so list the left images once.
        left_images = self._list_images(dst_raw_dir, "lc")

        for pair_name in ordered_pairs:
            right_camera = self._camera_from_pair(pair_name)
            right_images = self._list_images(dst_raw_dir, right_camera)

            if not left_images or not right_images:
                stats["skipped"] += 1
                print(
                    f"[WARN] {session_name}/{scan_name} {pair_name}: "
                    f"missing images (lc={len(left_images)}, {right_camera}={len(right_images)})."
                )
                continue

            pairs = self._pair_images(left_images, right_images, right_camera)
            if not pairs:
                stats["skipped"] += 1
                print(f"[WARN] {session_name}/{scan_name} {pair_name}: no valid pairs.")
                continue

            # Decided once per pair: the preferred pair always writes the left
            # image; another pair writes it only if nothing has been written yet.
            save_lc_this_pair = pair_name == self.keep_lc_from_pair or not lc_written
            stats["pairs_total"] += len(pairs)

            for left_path, right_path in tqdm(
                pairs,
                desc=f"{session_name}/{scan_name} {pair_name}",
                unit="pair",
                leave=False,
            ):
                left_img = cv2.imread(str(left_path), cv2.IMREAD_COLOR)
                right_img = cv2.imread(str(right_path), cv2.IMREAD_COLOR)
                if left_img is None or right_img is None:
                    stats["skipped"] += 1
                    continue

                left_rect, right_rect = self._rectify_pair_image(pair_name, left_img, right_img)

                if save_lc_this_pair:
                    left_out = dst_rect_dir / left_path.name
                    # imwrite reports failure through its return value.
                    if cv2.imwrite(str(left_out), left_rect):
                        lc_written = True
                    else:
                        print(
                            f"[WARN] {session_name}/{scan_name} {pair_name}: "
                            f"failed to write {left_out}."
                        )

                right_out = dst_rect_dir / right_path.name
                if cv2.imwrite(str(right_out), right_rect):
                    stats["saved"] += 1
                else:
                    stats["skipped"] += 1
                    print(
                        f"[WARN] {session_name}/{scan_name} {pair_name}: "
                        f"failed to write {right_out}."
                    )

        return stats

    def _discover_session_scan_raw_dirs(self) -> List[Tuple[str, str]]:
        """Return sorted ``(session_name, scan_name)`` tuples that have raw images.

        A scan qualifies when ``<session>/<scan>/01_raw_images`` is a
        directory. Session and scan directories are recognised by a
        case-insensitive ``session`` / ``scan`` name prefix, and
        ``session_filter`` (if set) must equal the session name exactly.
        """
        found: List[Tuple[str, str]] = []
        session_dirs = sorted(
            p
            for p in self.source_date_root.iterdir()
            if p.is_dir() and p.name.lower().startswith("session")
        )
        for session_dir in session_dirs:
            if self.session_filter and session_dir.name != self.session_filter:
                continue
            scan_dirs = sorted(
                p
                for p in session_dir.iterdir()
                if p.is_dir() and p.name.lower().startswith("scan")
            )
            for scan_dir in scan_dirs:
                if (scan_dir / "01_raw_images").is_dir():
                    found.append((session_dir.name, scan_dir.name))
        return found

    def run_batch(self) -> Dict[str, int]:
        """Rectify every discovered scan and return the aggregated counters.

        Returns:
            Totals for ``scans``, ``pairs_total``, ``saved`` and ``skipped``.

        Raises:
            RuntimeError: If no scan folder with raw images is found.
        """
        all_scans = self._discover_session_scan_raw_dirs()
        if not all_scans:
            raise RuntimeError(f"No scan folders found under {self.source_date_root}")

        print(f"[INFO] Found {len(all_scans)} scans under {self.source_date_root}")
        totals = {"scans": 0, "pairs_total": 0, "saved": 0, "skipped": 0}
        sessions_seen = set()

        for session_name, scan_name in all_scans:
            # Calibration files are copied once per session, not once per scan.
            if session_name not in sessions_seen:
                self._copy_params_link_for_session(session_name)
                sessions_seen.add(session_name)

            scan_stats = self._process_scan(session_name, scan_name)
            totals["scans"] += 1
            totals["pairs_total"] += scan_stats["pairs_total"]
            totals["saved"] += scan_stats["saved"]
            totals["skipped"] += scan_stats["skipped"]

        print(
            "[INFO] Batch rectification finished: "
            f"scans={totals['scans']} pairs={totals['pairs_total']} "
            f"saved={totals['saved']} skipped={totals['skipped']}"
        )
        return totals