Welcome to blox’s documentation!
blox is a functional and lightweight neural network library for JAX.
The entire mental model fits in one line:
```python
outputs, params = model(params, inputs)
```
Parameters go in; outputs and updated parameters come out. Because state flows explicitly through your code, JAX transformations (jax.jit, jax.grad, jax.vmap, and so on) work out of the box. No wrappers, no decorators, no surprises.
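To make the convention concrete, here is a minimal sketch in plain JAX. It does not use blox itself: `init_params`, `model`, `loss_fn`, and `train_step` are hypothetical names, and the `"calls"` counter is an assumed stand-in for whatever state a real blox model carries. What it illustrates is why explicit state composes with `jax.jit` and `jax.value_and_grad`: the updated params are simply another return value.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1


def init_params(key, in_dim, out_dim):
    """Hypothetical initializer (not the blox API): weights plus one piece of state."""
    return {
        "w": 0.1 * jax.random.normal(key, (in_dim, out_dim)),
        "b": jnp.zeros((out_dim,)),
        # Illustrative state: how many times the model has been called.
        # Stored as a float so jax.grad accepts the whole params pytree.
        "calls": jnp.zeros(()),
    }


def model(params, inputs):
    """Pure function following `outputs, params = model(params, inputs)`."""
    outputs = inputs @ params["w"] + params["b"]
    # State is never mutated in place; an updated copy is returned instead.
    new_params = {**params, "calls": params["calls"] + 1.0}
    return outputs, new_params


def loss_fn(params, inputs, targets):
    outputs, new_params = model(params, inputs)
    loss = jnp.mean((outputs - targets) ** 2)
    # Return the updated params as auxiliary data so they survive jax.grad.
    return loss, new_params


@jax.jit
def train_step(params, inputs, targets):
    (loss, new_params), grads = jax.value_and_grad(loss_fn, has_aux=True)(
        params, inputs, targets
    )
    # Plain SGD applied on top of the updated params. The "calls" leaf has a
    # zero gradient, so it just carries the incremented counter forward.
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - LEARNING_RATE * g, new_params, grads
    )
    return loss, new_params


key = jax.random.PRNGKey(0)
x_key, p_key = jax.random.split(key)
x = jax.random.normal(x_key, (32, 3))
y = x @ jnp.array([[1.0], [-2.0], [0.5]])  # synthetic linear targets

params = init_params(p_key, 3, 1)
first_loss, _ = loss_fn(params, x, y)
for _ in range(100):
    loss, params = train_step(params, x, y)
```

Returning the new params through `has_aux=True` lets the gradient computation and the state update happen inside one traced, jit-compiled function, with nothing hidden in module attributes.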
Getting Started
Key Concepts
Advanced Topics
- Checkpointable Training on Google Colab
  - 1. Setup
  - 2. Configuration
  - 3. Utilities
  - 4. Resumable Dataset
  - 5. Model
  - 6. Training Step
  - 7. Initialize and Restore
  - 8. Training Loop
  - 9. Evaluation
  - 10. Summary
- Two Approaches
- 1. Define an MLP Module
- 2. Explore the Graph
- 3. The LoRA Pattern
- 4. Apply LoRA to Selected Layers
- 5. Initialize LoRA Parameters
- 6. Freeze Base Weights
- 7. Training with LoRA
- 8. Verify Only LoRA Weights Changed
- 9. Merging LoRA Weights (Optional)
- Create a LoRA-aware MLP
- Benefits of the LoRA-aware Approach
- Toggling LoRA
- Sharp Bits in blox
  - 1. The Params Container
  - 2. RNG Handling in Parallel Contexts
  - 3. Init vs Runtime
  - 4. Graph and Module Modifications
  - 5. Summary