Hacker News

What about CPU-only workloads? What if you want to write code that will eventually run on both CPU and GPU, but in the short-to-mid term will only be used on CPU? Since JAX natively supports CPU (via XLA, behind a NumPy-like API) but CuPy doesn't, this seems like a potential problem for some.



Isn't there a way to dynamically select between numpy and cupy, depending on whether you want cpu or gpu code?


NumPy has a mechanism to dispatch execution to CuPy: https://numpy.org/neps/nep-0018-array-function-protocol.html

Just prepare the input as a NumPy or CuPy array, then feed it to NumPy APIs. NumPy functions compute the result themselves if the input is a NumPy ndarray, or dispatch execution to CuPy if the input is a CuPy ndarray.
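To see the NEP 18 hook in action without a GPU, here is a minimal sketch. `TaggedArray` is a hypothetical toy class standing in for `cupy.ndarray`: it implements `__array_function__`, so `np.sum` and `np.max` hand the call to it instead of computing directly. Real CuPy arrays implement this hook themselves; the class only mimics the dispatch. Ufuncs such as `np.exp` go through a separate hook, `__array_ufunc__` (NEP 13), which this toy class does not implement.

```python
import numpy as np

class TaggedArray:
    """Hypothetical stand-in for cupy.ndarray (wraps a NumPy array)."""

    def __init__(self, data):
        self.data = np.asarray(data)

    def __array_function__(self, func, types, args, kwargs):
        # NumPy calls this hook instead of its own implementation whenever
        # an argument of a public NumPy function defines __array_function__.
        if not all(issubclass(t, (TaggedArray, np.ndarray)) for t in types):
            return NotImplemented
        # Unwrap top-level TaggedArray arguments, run the real NumPy
        # function on plain ndarrays, then wrap the result again.
        plain = [a.data if isinstance(a, TaggedArray) else a for a in args]
        return TaggedArray(func(*plain, **kwargs))

def summarize(x):
    # Generic code written only against the NumPy API.
    return np.sum(x), np.max(x)
```

Calling `summarize(np.array([1, 2, 3]))` stays on the plain NumPy path, while `summarize(TaggedArray([1, 2, 3]))` returns `TaggedArray` results, because NumPy routed both calls through the hook.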


> Isn't there a way to dynamically select between numpy and cupy, depending on whether you want cpu or gpu code?

CuPy is an (almost) drop-in replacement for NumPy, so the following works surprisingly often:

    use_cpu = True  # set to False to run on the GPU via CuPy
    if use_cpu:
        import numpy as np
    else:
        import cupy as np  # CuPy mirrors most of the NumPy API
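A slightly more defensive variant, assuming CuPy may or may not be installed: rather than a global flag, pick the module from the array itself with `cupy.get_array_module`, CuPy's helper for writing CPU/GPU-agnostic code. The `get_xp` and `softmax` names are illustrative only, and treating a failed import as "CPU-only machine" is an assumption.

```python
import numpy

try:
    import cupy
except ImportError:  # assumption: no CuPy installed means a CPU-only machine
    cupy = None

def get_xp(x):
    """Return the array module (cupy or numpy) that matches array x."""
    if cupy is not None:
        # cupy.get_array_module returns cupy for CuPy arrays, numpy otherwise
        return cupy.get_array_module(x)
    return numpy

def softmax(x):
    xp = get_xp(x)
    e = xp.exp(x - xp.max(x))  # subtract the max for numerical stability
    return e / xp.sum(e)
```

The same `softmax` then runs on a NumPy array today and on a CuPy array later, without editing imports.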


> surprisingly

This is the problem with these kinds of methods: it works, until it fails in an unknown way.


There is, but then you're using two separate libraries, which seems like a fragile point of failure compared to just using JAX. Then again, since JAX uses different backends anyway, it's arguably not any worse; it just ends up being your responsibility to ensure correctness rather than the JAX team's.
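If you do take on that responsibility, one option is a parity check that runs the same function through both libraries and compares results. A minimal sketch, assuming CuPy may be absent (then only the NumPy path runs): `to_host`, `check_backends`, and `normalize` are hypothetical helpers, while `cupy.asnumpy` and `numpy.testing.assert_allclose` are real APIs.

```python
import numpy as np

try:
    import cupy as cp
except ImportError:  # assumption: CuPy missing means a CPU-only machine
    cp = None

def to_host(x):
    """Copy a CuPy array back to host memory; pass NumPy data through."""
    if cp is not None and isinstance(x, cp.ndarray):
        return cp.asnumpy(x)
    return np.asarray(x)

def check_backends(fn, data, rtol=1e-6):
    """Run fn(xp, x) with NumPy and, when available, with CuPy,
    and assert that both results agree within rtol."""
    expected = fn(np, np.asarray(data))
    if cp is not None:
        actual = to_host(fn(cp, cp.asarray(data)))
        np.testing.assert_allclose(actual, expected, rtol=rtol)
    return expected

def normalize(xp, x):
    # Hypothetical function under test, written against the shared API.
    return x / xp.linalg.norm(x)
```

For example, `check_backends(normalize, [3.0, 4.0])` returns the NumPy result and, on a machine with CuPy and a GPU, also fails loudly if the two backends disagree.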



