I am really looking forward to JAX taking over from PyTorch/CUDA in the coming years. The whole PTX kerfuffle with the DeepSeek team shows the value of investing in lower-level approaches to squeeze the most out of your hardware.
kadushka 35 days ago [-]
Most PyTorch users don't even bother with the simplest performance optimizations, and you are talking about PTX.
throwaway287391 35 days ago [-]
I like JAX but I'm not sure how an ML framework debate like "JAX vs PyTorch" is relevant to DeepSeek/PTX. The JAX API is at a similar level of abstraction to PyTorch [0]. Both are Python libraries and sit a few layers of abstraction above PTX/CUDA and their TPU equivalents.
[0] Although PyTorch arguably encompasses 2 levels, with both a pure functional library like the JAX API, as well as a "neural network" framework on top of it. Whereas JAX doesn't have the latter and leaves that to separate libraries like Flax.
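A rough sketch of those two levels (illustrative code only, not taken from either project):

    import torch
    import torch.nn.functional as F
    import jax.numpy as jnp

    # PyTorch level 1: pure functional ops, much like the JAX API.
    y_fn = F.relu(torch.ones(3) @ torch.ones(3, 3))

    # PyTorch level 2: the stateful nn.Module framework layered on top.
    layer = torch.nn.Linear(3, 3)
    y_mod = torch.relu(layer(torch.ones(1, 3)))

    # JAX stops at the functional level; libraries like Flax add the rest.
    y_jax = jnp.maximum(jnp.ones(3) @ jnp.ones((3, 3)), 0.0)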
jdeaton 35 days ago [-]
The interesting thing about this comment is that JAX is generally even higher-level than PyTorch. Since everything is compiled, you just express a logical program and let the compiler (XLA) worry about the rest.
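For example, a minimal sketch of that workflow (hypothetical shapes and function, just to illustrate):

    import jax
    import jax.numpy as jnp

    # Express the logical program; XLA decides fusion, layout, and the rest.
    @jax.jit
    def mlp(w1, w2, x):
        return jnp.tanh(x @ w1) @ w2

    x = jnp.ones((8, 128))
    w1 = jnp.ones((128, 256))
    w2 = jnp.ones((256, 10))
    print(mlp(w1, w2, x).shape)  # (8, 10)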
Are you suggesting that XLA would be where this "lower level" approach would reside since it can do more automatic optimization?
Scene_Cast2 35 days ago [-]
I'm curious, what does paradigmatic JAX look like? Is there an equivalent of picoGPT [1] for JAX?
yeah it looks exactly like that file but replace "import numpy as np" with "import jax.numpy as np" :)
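Roughly (a sketch using picoGPT's gelu as the example; the point is that the same NumPy-style code becomes jit-able and differentiable):

    import jax
    import jax.numpy as np  # was: import numpy as np

    def gelu(x):  # body unchanged from the NumPy version
        return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))

    gelu_fast = jax.jit(gelu)                  # now compiled via XLA
    dgelu = jax.grad(lambda x: gelu(x).sum())  # and differentiable for free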
achierius 34 days ago [-]
What PTX kerfuffle are you referring to?
saagarjha 35 days ago [-]
You do understand that PTX is part of CUDA right?
lordswork 35 days ago [-]
This has been my bible for performance work internally at Google. Kind of surprised they released it publicly, but I guess they removed all the Gemini-specific details.
memhole 35 days ago [-]
This is awesome! Can't wait to read it. I've been very curious about why we don't hear more about LLMs on TPUs.
jdeaton 35 days ago [-]
Something nice about this guide is that it generally transfers to GPU directly thanks to JAX/XLA.
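A small sketch of what that portability looks like in practice (assuming a local GPU or TPU for the device list):

    import jax
    import jax.numpy as jnp

    # The same program runs on whatever backend XLA finds: CPU, GPU, or TPU.
    print(jax.devices())  # e.g. [CudaDevice(id=0)] or [TpuDevice(id=0), ...]

    x = jnp.arange(1_000_000, dtype=jnp.float32)
    y = jax.jit(lambda v: (v * v).sum())(x)  # compiled for the local backend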
brap 35 days ago [-]
Not strictly related, but does anyone know why JAX uses tracing and not AST via reflection?
shoyer 35 days ago [-]
The short answer is that tracing is way, way easier to implement in a predictable and reliably performant way. This especially matters for distributed computation and automatic differentiation, two areas where JAX shines.
AST parsing via reflection means your ML compiler needs to re-implement all of Python, which is not a small language. This is a lot of work and hard to do well with abstractions that were not designed for those use cases. (I believe Julia's whole-language autodiff systems struggle for essentially the same reason.)
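You can see the tracing approach directly with jax.make_jaxpr: JAX calls the function once with abstract stand-in values and records the operations it encounters, instead of parsing the source. (Output shown roughly; formatting varies by version.)

    import jax
    import jax.numpy as jnp

    def f(x):
        return jnp.sin(x) * 2.0

    # No Python parsing involved: f simply runs once on tracer values.
    print(jax.make_jaxpr(f)(jnp.ones(3)))
    # { lambda ; a:f32[3]. let b:f32[3] = sin a; c:f32[3] = mul b 2.0 in (c,) }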
almostgotcaught 35 days ago [-]
> AST via reflection
I literally am a paid ML compiler engineer and I have no idea what this means. You understand that reflection, a la looking in a mirror, is about being able to identify a type's type at runtime? It has nothing to do with the AST.
brap 33 days ago [-]
Congratulations, but that's not what reflection means.
Wikipedia: "reflection is the ability of a process to examine, introspect, and modify its own structure and behavior."
Would you say inspect.getsource(func) fits the definition of reflection?
Would you say ast.parse(inspect.getsource(func)) has something to do with the AST?
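For reference, the two calls in question (standard library only):

    import ast
    import inspect

    def square(x):
        return x * x

    src = inspect.getsource(square)  # the process examining its own code...
    tree = ast.parse(src)            # ...and getting an AST back
    print(type(tree.body[0]))        # <class 'ast.FunctionDef'>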
almostgotcaught 33 days ago [-]
> Would you say inspect.getsource(func) fits the definition of reflection?
I would say that reflection is absolutely meaningless in an interpreted runtime, because you can always query the runtime.
> Would you say ast.parse(inspect.getsource(func)) has something to do with the AST?
It has something to do with the AST but it doesn't have much to do with reflection.
To what degree is this actually true, and what else is on the horizon that might become as popular as transformers?
swyx 35 days ago [-]
it's quite true. the convergence of all archs to transformers is well documented by karpathy. SSMs were once touted as transformer killers, but increasingly look like just optional supplements.
perfobotto 35 days ago [-]
What an amazing write up! Thank you very much!
eamag 35 days ago [-]
Any way to convert this Jekyll site to a PDF?
atomala 35 days ago [-]
There are plans to release a PDF version; need to fix some formatting issues + convert the animated diagrams into static images.
hassleblad23 35 days ago [-]
Great writeup. Congrats.
whatever1 35 days ago [-]
How do they make these fancy animations?
alevskaya 35 days ago [-]
Nothing fancy. I made these with some pretty simple hand-written JavaScript scripts rendering to canvas: lots of fiddly little boxes moving around are simpler to script than to hand-animate. (If I were to do much more of this I might rewrite them in Blender, since it has much nicer authoring tooling and export control.)
nicodjimenez 35 days ago [-]
Shameless request for help: if anybody has experience with seq2seq on TPU, and you want to do a cool project to deploy a world class Pytorch image parsing model to TPU (and do this quickly), please contact me immediately for a well paid and interesting job opportunity at nico [at] mathpix.com.
jdeaton 35 days ago [-]
if you're using tpu why are you using pytorch
hustwindmaple1 34 days ago [-]
There is limited TPU support in PyTorch via torch_xla.
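Roughly, the torch_xla path looks like this (a sketch based on the torch_xla docs; the exact API varies by version):

    import torch
    import torch_xla.core.xla_model as xm

    device = xm.xla_device()   # the TPU, exposed as a torch device
    model = torch.nn.Linear(10, 10).to(device)
    out = model(torch.ones(1, 10, device=device))
    xm.mark_step()             # cut the lazy graph and dispatch it to XLA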
jdeaton 33 days ago [-]
Sounds limited
[1] https://github.com/jaymody/picoGPT/blob/main/gpt2.py