How to scale your model: A systems view of LLMs on TPUs (jax-ml.github.io)
3abiton 3 days ago [-]
I am really looking forward to JAX taking over PyTorch/CUDA in the next few years. The whole PTX kerfuffle with the DeepSeek team shows the value of investing in lower-level approaches to squeeze the most out of your hardware.
kadushka 3 days ago [-]
Most PyTorch users don't even bother with the simplest performance optimizations, and you are talking about PTX.
throwaway287391 3 days ago [-]
I like JAX but I'm not sure how an ML framework debate like "JAX vs PyTorch" is relevant to DeepSeek/PTX. The JAX API is at a similar level of abstraction to PyTorch [0]. Both are Python libraries and sit a few layers of abstraction above PTX/CUDA and their TPU equivalents.

[0] Although PyTorch arguably encompasses 2 levels, with both a pure functional library like the JAX API, as well as a "neural network" framework on top of it. Whereas JAX doesn't have the latter and leaves that to separate libraries like Flax.

jdeaton 3 days ago [-]
The interesting thing about this comment is that JAX is generally even higher-level than PyTorch. Since everything is compiled, you just express a logical program and let the compiler (XLA) worry about the rest.

Are you suggesting that XLA would be where this "lower level" approach would reside since it can do more automatic optimization?

Scene_Cast2 3 days ago [-]
I'm curious, what does paradigmatic JAX look like? Is there an equivalent of picoGPT [1] for JAX?

[1] https://github.com/jaymody/picoGPT/blob/main/gpt2.py

jdeaton 3 days ago [-]
yeah it looks exactly like that file but replace "import numpy as np" with "import jax.numpy as np" :)
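To give a flavor of that claim, here is a minimal NumPy sketch of two GPT-2-style building blocks in the picoGPT style (hypothetical helper names, not the actual picoGPT code). The paradigmatic JAX version would be the same code with `import jax.numpy as jnp` and `np` replaced by `jnp`.

```python
import numpy as np  # in JAX, this would be: import jax.numpy as jnp

def gelu(x):
    # tanh approximation of GELU, as used in GPT-2
    return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))

def layer_norm(x, eps=1e-5):
    # normalize each row to zero mean and unit variance
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

x = np.random.randn(4, 8)   # a toy batch of activations
y = layer_norm(gelu(x))
print(y.shape)
```

Because `jax.numpy` mirrors the NumPy API, code written this way is already "paradigmatic JAX" once wrapped in `jax.jit`.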
achierius 2 days ago [-]
What PTX kerfuffle are you referring to?
saagarjha 3 days ago [-]
You do understand that PTX is part of CUDA right?
lordswork 3 days ago [-]
This has been my bible for performance work internally at Google. Kind of surprised they released it publicly, but I guess they removed all the Gemini-specific details.
memhole 3 days ago [-]
This is awesome! Can't wait to read it. I've been very curious about why we don't hear more about LLMs on TPUs.
jdeaton 3 days ago [-]
Something nice about this guide is that it generally transfers to GPU directly thanks to JAX/XLA.
brap 3 days ago [-]
Not strictly related, but does anyone know why JAX uses tracing and not AST via reflection?
shoyer 3 days ago [-]
The short answer is that tracing is way, way easier to implement in a predictable and reliably performant way. This especially matters for distributed computation and automatic differentiation, two areas where JAX shines.

AST parsing via reflection means your ML compiler needs to re-implement all of Python, which is not a small language. This is a lot of work and hard to do well with abstractions that are not designed for these use cases. (I believe Julia's whole-language autodiff systems struggle for essentially the same reason.)
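To illustrate why tracing is the easier route: a toy operator-overloading tracer can be written in a few lines of plain Python (a hypothetical sketch, not JAX's actual internals). Calling a function with `Tracer` arguments records the primitive operations it performs, with no need to parse any Python source.

```python
# Hypothetical minimal tracer: records ops onto a "tape" via operator
# overloading, the same basic mechanism JAX's tracing relies on.
class Tracer:
    def __init__(self, name, tape):
        self.name, self.tape = name, tape

    def _record(self, op, other):
        out = Tracer(f"v{len(self.tape)}", self.tape)
        other_name = getattr(other, "name", other)  # constants pass through
        self.tape.append((op, self.name, other_name, out.name))
        return out

    def __mul__(self, other):
        return self._record("mul", other)

    def __add__(self, other):
        return self._record("add", other)

def trace(f):
    tape = []
    f(Tracer("x", tape), Tracer("y", tape))
    return tape

def f(x, y):
    return x * y + x

print(trace(f))  # the recorded "mul" and "add" primitives
```

The compiler only ever sees the recorded primitives, never Python itself, which is what makes the approach predictable.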

almostgotcaught 2 days ago [-]
> AST via reflection

I literally am a paid ML compiler engineer and I have no idea what this means. You understand that reflection, à la looking in a mirror, is about being able to identify a type's type at runtime, right? It has nothing to do with the AST.

brap 24 hours ago [-]
Congratulations, but that's not what reflection means.

Wikipedia: "reflection is the ability of a process to examine, introspect, and modify its own structure and behavior."

Would you say inspect.getsource(func) fits the definition of reflection?

Would you say ast.parse(inspect.getsource(func)) has something to do with the AST?
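For reference, the mechanism being debated looks like this (a self-contained sketch: the function source is written out as a string standing in for what `inspect.getsource(func)` would return for a live function object):

```python
import ast

# What inspect.getsource(func) would return for a live function object;
# inlined as a string here so the sketch stands alone.
src = """
def f(x, y):
    return x * y + x
"""

tree = ast.parse(src)       # source text -> abstract syntax tree
func = tree.body[0]         # the FunctionDef node
print(type(func).__name__)
print([a.arg for a in func.args.args])  # the argument names
```

Whether retrieving and parsing a function's own source counts as "reflection" is exactly the terminological disagreement above; the code itself is uncontroversial stdlib usage.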

almostgotcaught 23 hours ago [-]
> Would you say inspect.getsource(func) fits the definition of reflection?

I would say that reflection is absolutely meaningless in an interpreted runtime because you can always query the runtime.

> Would you say ast.parse(inspect.getsource(func)) has something to do with the AST?

It has something to do with the AST but it doesn't have much to do with reflection.

awongh 3 days ago [-]
Here in the thread he says: https://x.com/jacobaustin132/status/1886844724339675340 : `5 years ago, there were many ML architectures, but today, there is (mostly) only one [transformers].`

To what degree is this actually true, and what else is on the horizon that might become as popular as transformers?

swyx 2 days ago [-]
it's quite true. the convergence of all architectures to transformers is well documented by karpathy. SSMs were once touted as transformer killers, but increasingly look like just optional supplements.
eamag 3 days ago [-]
Any way to convert this Jekyll site to a PDF?
atomala 3 days ago [-]
There are plans to release a PDF version; need to fix some formatting issues + convert the animated diagrams into static images.
perfobotto 3 days ago [-]
What an amazing write up! Thank you very much!
hassleblad23 3 days ago [-]
Great writeup. Congrats.
whatever1 3 days ago [-]
How do they make these fancy animations?
alevskaya 3 days ago [-]
Nothing fancy. I made these with some pretty simple hand written scripts in javascript rendering to canvas: lots of fiddly little boxes moving around are simpler to script than to hand animate. (If I were to do much more of this I might rewrite these in blender since it has much nicer authoring tooling and export control.)
nicodjimenez 3 days ago [-]
Shameless request for help: if anybody has experience with seq2seq on TPU, and you want to do a cool project to deploy a world class Pytorch image parsing model to TPU (and do this quickly), please contact me immediately for a well paid and interesting job opportunity at nico [at] mathpix.com.
jdeaton 3 days ago [-]
if you're using tpu why are you using pytorch
hustwindmaple1 2 days ago [-]
there is limited TPU support in pytorch via torch_xla
jdeaton 24 hours ago [-]
Sounds limited