Trax — Deep Learning with Clear Code and Speed
https://github.com/google/trax
826 forks.
8,304 stars.
125 open issues.
Recent commits:
- [trax] Explicitly set `jax_pmap_shmap_merge=False`. `trainer._multi_device_update_fn` uses `jax.pmap`, and when `jax_pmap_shmap_merge=True`, `jax.pmap` requires inputs to be explicitly sharded, as the underlying `jax.jit` expects. This would need to be fixed if `jax_pmap_shmap_merge=True`. (PiperOrigin-RevId: 811810947, Copybara-Service)
- Remove more dtype canonicalization. Notable changes:
  * Remove dtype canonicalization from many lax dtype rules. Dtype rules should always talk about post-canonicalized types.
  * Use `check_and_canonicalize_user_dtype` instead of `canonicalize_dtype` in many user-callable lax functions.
  * Change some calls to lax functions to pass canonicalized types to avoid new user warnings.
  * Move `_convert_and_clip_integer` into `random.py`, which is its only user.
  * In passing, rename `_input_dtype` to `input_dtype` because it is exported from its enclosing module.
  (PiperOrigin-RevId: 800201348, Copybara-Service)
- Configure tests to explicitly use `jax_threefry_partitionable=False`. See https://github.com/jax-ml/jax/discussions/18480 (PiperOrigin-RevId: 746173050, Copybara-Service)
- Automated Code Change (PiperOrigin-RevId: 724370877, Copybara-Service)
- Avoid use of deprecated `xla_bridge.get_backend().live_buffers()`. `xla_bridge.get_backend` is deprecated, and the public API for this is `jax.live_arrays()`. This is a drop-in replacement with no change of behavior. (PiperOrigin-RevId: 723599353, Copybara-Service)
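Two of the commits above touch user-visible JAX settings: pinning `jax_threefry_partitionable=False` in tests and switching to `jax.live_arrays()`. The following is a minimal sketch of both, assuming a recent JAX release in which the `jax_threefry_partitionable` flag and `jax.live_arrays()` are both available. It is an illustration only, not code taken from Trax's test suite.

```python
import jax

# Assumption: this mirrors the commit's intent of pinning the non-partitionable
# threefry PRNG for reproducible tests. Set it before any keys are created.
jax.config.update("jax_threefry_partitionable", False)

key = jax.random.PRNGKey(0)
samples = jax.random.normal(key, (3,))

# jax.live_arrays() is the public API that replaces the deprecated
# xla_bridge.get_backend().live_buffers(); it lists arrays still alive
# on the default backend (here at least `key` and `samples`).
live = jax.live_arrays()
print(f"{len(live)} live arrays")
```

Because `jax.live_arrays()` only reports arrays that are still referenced, a test can call it before and after a step to check for leaked device buffers.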