# Source code for rigid_body_motion.ros.io

from pathlib import Path

import numpy as np


class RosbagReader:
    """Reader for motion topics from rosbag files.

    The bag file is opened when entering the context manager and closed
    when leaving it. Methods that read from the bag raise a RuntimeError
    when called outside of the context manager.
    """

    def __init__(self, bag_file):
        """Constructor.

        Parameters
        ----------
        bag_file: str
            Path to rosbag file.
        """
        self.bag_file = Path(bag_file)
        self._bag = None

    def __enter__(self):
        """Open the bag file for reading and return the reader."""
        # function-scope import: rosbag is only required once a bag is
        # actually opened
        import rosbag

        self._bag = rosbag.Bag(self.bag_file, "r")

        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Close the bag file and reset the handle."""
        self._bag.close()
        self._bag = None

    @staticmethod
    def _get_msg_type(bag, topic):
        """Get type of message."""
        return bag.get_type_and_topic_info(topic).topics[topic].msg_type

    @staticmethod
    def _get_timestamp(msg, ts):
        """Get message time in seconds.

        The header stamp is used when the message has a header, otherwise
        the time at which the message was recorded in the bag.
        """
        return (msg.header.stamp if msg._has_header else ts).to_sec()

    def _get_filename(self, output_file, extension):
        """Get export filename and create folder.

        Parameters
        ----------
        output_file: str or pathlib.Path, optional
            Path to the output file. If None, the path of the bag file
            with its extension replaced by `extension` is used.

        extension: str
            File extension (without leading dot) used when `output_file`
            is None.

        Returns
        -------
        filename: pathlib.Path
            Path to the output file. Its parent folder exists after the
            call.
        """
        if output_file is None:
            filename = (
                self.bag_file.parent / f"{self.bag_file.stem}.{extension}"
            )
        else:
            # callers may pass plain strings, so convert before using
            # Path-only attributes such as ``parent``
            filename = Path(output_file)

        filename.parent.mkdir(parents=True, exist_ok=True)

        return filename

    @staticmethod
    def _write_netcdf(ds, filename, dtype="int32"):
        """Write dataset to netCDF file.

        Every data variable is written zlib-compressed as scaled integers
        of type `dtype` with a scale factor of 0.0001. The smallest value
        of `dtype` is used as fill value.
        """
        comp = {
            "zlib": True,
            "dtype": dtype,
            "scale_factor": 0.0001,
            "_FillValue": np.iinfo(dtype).min,
        }

        # one independent dict per variable so that the per-variable
        # encodings do not alias each other
        encoding = {v: dict(comp) for v in ds.data_vars}

        ds.to_netcdf(filename, encoding=encoding)

    def get_topics_and_types(self):
        """Get topics and corresponding message types included in rosbag.

        Returns
        -------
        topics: dict
            Names of topics and corresponding message types included in the
            rosbag.

        Raises
        ------
        RuntimeError
            If called outside of the RosbagReader context manager.
        """
        if self._bag is None:
            raise RuntimeError(
                "get_topics_and_types must be called from within the "
                "RosbagReader context manager"
            )

        info = self._bag.get_type_and_topic_info()

        return {k: v.msg_type for k, v in info.topics.items()}

    def load_messages(self, topic):
        """Load messages from topic as dict.

        Only nav_msgs/Odometry and geometry_msgs/TransformStamped topics are
        supported so far.

        Parameters
        ----------
        topic: str
            Name of the topic to load.

        Returns
        -------
        messages: dict
            Dict containing arrays of timestamps and other message contents.

        Raises
        ------
        RuntimeError
            If called outside of the RosbagReader context manager.

        ValueError
            If the message type of the topic is not supported.
        """
        # check the precondition first so that a missing context manager
        # is reported before any message helpers are imported
        if self._bag is None:
            raise RuntimeError(
                "load_messages must be called from within the RosbagReader "
                "context manager"
            )

        from .msg import (
            unpack_point_msg,
            unpack_quaternion_msg,
            unpack_vector_msg,
        )

        msg_type = self._get_msg_type(self._bag, topic)

        if msg_type == "nav_msgs/Odometry":
            # one row per message: time, position, orientation,
            # linear velocity, angular velocity
            arr = np.array(
                [
                    (
                        self._get_timestamp(msg, ts),
                        *unpack_point_msg(msg.pose.pose.position),
                        *unpack_quaternion_msg(msg.pose.pose.orientation),
                        *unpack_vector_msg(msg.twist.twist.linear),
                        *unpack_vector_msg(msg.twist.twist.angular),
                    )
                    for _, msg, ts in self._bag.read_messages(topics=topic)
                ]
            )
            return_vals = {
                "timestamps": arr[:, 0],
                "position": arr[:, 1:4],
                "orientation": arr[:, 4:8],
                "linear_velocity": arr[:, 8:11],
                "angular_velocity": arr[:, 11:],
            }

        elif msg_type == "geometry_msgs/TransformStamped":
            # one row per message: time, translation, rotation
            arr = np.array(
                [
                    (
                        self._get_timestamp(msg, ts),
                        *unpack_point_msg(msg.transform.translation),
                        *unpack_quaternion_msg(msg.transform.rotation),
                    )
                    for _, msg, ts in self._bag.read_messages(topics=topic)
                ]
            )
            return_vals = {
                "timestamps": arr[:, 0],
                "position": arr[:, 1:4],
                "orientation": arr[:, 4:8],
            }

        else:
            raise ValueError(f"Unsupported message type {msg_type}")

        return return_vals

    def load_dataset(self, topic, cache=False):
        """Load messages from topic as xarray.Dataset.

        Only nav_msgs/Odometry and geometry_msgs/TransformStamped topics are
        supported so far.

        Parameters
        ----------
        topic: str
            Name of the topic to load.

        cache: bool, default False
            If True, cache the dataset in ``cache/<topic>.nc`` in the same
            folder as the rosbag.

        Returns
        -------
        ds: xarray.Dataset
            Messages as dataset.
        """
        # TODO attrs
        import pandas as pd
        import xarray as xr

        if cache:
            # slashes in the topic name are replaced to get a flat filename
            filepath = (
                self.bag_file.parent
                / "cache"
                / f"{topic.replace('/', '_')}.nc"
            )
            if not filepath.exists():
                self.export(topic, filepath)
            return xr.open_dataset(filepath)

        motion = self.load_messages(topic)

        coords = {
            "cartesian_axis": ["x", "y", "z"],
            "quaternion_axis": ["w", "x", "y", "z"],
            "time": pd.to_datetime(motion["timestamps"], unit="s"),
        }

        data_vars = {
            "position": (["time", "cartesian_axis"], motion["position"]),
            "orientation": (
                ["time", "quaternion_axis"],
                motion["orientation"],
            ),
        }

        # velocities are only available for nav_msgs/Odometry topics
        if "linear_velocity" in motion:
            data_vars["linear_velocity"] = (
                ["time", "cartesian_axis"],
                motion["linear_velocity"],
            )
            data_vars["angular_velocity"] = (
                ["time", "cartesian_axis"],
                motion["angular_velocity"],
            )

        ds = xr.Dataset(data_vars, coords)

        return ds

    def export(self, topic, output_file=None):
        """Export messages from topic as netCDF4 file.

        Parameters
        ----------
        topic: str
            Topic to read.

        output_file: str, optional
            Path to output file. By default, the path to the bag file, but
            with a different extension depending on the export format.
        """
        ds = self.load_dataset(topic, cache=False)
        self._write_netcdf(ds, self._get_filename(output_file, "nc"))


class RosbagWriter:
    """Writer for motion topics to rosbag files.

    The bag file is opened for writing when entering the context manager
    and closed when leaving it. Messages can only be written from within
    the context manager.
    """

    def __init__(self, bag_file):
        """Constructor.

        Parameters
        ----------
        bag_file: str
            Path to rosbag file.
        """
        self.bag_file = Path(bag_file)
        self._bag = None

    def __enter__(self):
        """Open the bag file for writing and return the writer."""
        # function-scope import: rosbag is only required once a bag is
        # actually opened
        import rosbag

        self._bag = rosbag.Bag(self.bag_file, "w")

        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Close the bag file and reset the handle."""
        self._bag.close()
        self._bag = None

    def write_transform_stamped(
        self, timestamps, translation, rotation, topic, frame, child_frame
    ):
        """Write multiple geometry_msgs/TransformStamped messages.

        Parameters
        ----------
        timestamps: array_like, shape (n_timestamps,)
            Array of timestamps.

        translation: array_like, shape (n_timestamps, 3)
            Array of translations.

        rotation: array_like, shape (n_timestamps, 4)
            Array of rotations.

        topic: str
            Topic of the messages.

        frame: str
            Parent frame of the transform.

        child_frame: str
            Child frame of the transform.

        Raises
        ------
        ValueError
            If timestamps, translation or rotation have the wrong shape.

        RuntimeError
            If called outside of the RosbagWriter context manager.
        """
        # check timestamps
        timestamps = np.asarray(timestamps)
        if timestamps.ndim != 1:
            raise ValueError("timestamps must be one-dimensional")

        # check translation
        translation = np.asarray(translation)
        if translation.shape != (len(timestamps), 3):
            raise ValueError(
                f"Translation must have shape ({len(timestamps)}, 3), "
                f"got {translation.shape}"
            )

        # check rotation
        rotation = np.asarray(rotation)
        if rotation.shape != (len(timestamps), 4):
            raise ValueError(
                f"Rotation must have shape ({len(timestamps)}, 4), "
                f"got {rotation.shape}"
            )

        # fail with a clear message instead of an AttributeError on None
        if self._bag is None:
            raise RuntimeError(
                "write_transform_stamped must be called from within the "
                "RosbagWriter context manager"
            )

        # imported only once the inputs are known to be valid
        from .msg import make_transform_msg

        # write messages to bag
        for ts, t, r in zip(timestamps, translation, rotation):
            msg = make_transform_msg(t, r, frame, child_frame, ts)
            self._bag.write(topic, msg)

    def write_transform_stamped_dataset(
        self,
        ds,
        topic,
        frame,
        child_frame,
        timestamps="time",
        translation="position",
        rotation="orientation",
    ):
        """Write a dataset as geometry_msgs/TransformStamped messages.

        Parameters
        ----------
        ds: xarray.Dataset
            Dataset containing timestamps, translation and rotation

        topic: str
            Topic of the messages.

        frame: str
            Parent frame of the transform.

        child_frame: str
            Child frame of the transform.

        timestamps: str, default 'time'
            Name of the dimension containing the timestamps.

        translation: str, default 'position'
            Name of the variable containing the translation.

        rotation: str, default 'orientation'
            Name of the variable containing the rotation.
        """
        if np.issubdtype(ds[timestamps].dtype, np.datetime64):
            # normalize to nanosecond resolution first so that the division
            # by 1e9 yields seconds regardless of the stored time unit
            timestamps = (
                ds[timestamps].astype("datetime64[ns]").astype(float) / 1e9
            )
        else:
            timestamps = ds[timestamps]

        self.write_transform_stamped(
            timestamps,
            ds[translation],
            ds[rotation],
            topic,
            frame,
            child_frame,
        )