pentoo-overlay/dev-python/nvidia-apex/nvidia-apex-22.08.ebuild
2022-11-21 12:23:59 +08:00

51 lines
1.3 KiB
Bash

# Copyright 1999-2022 Gentoo Authors
# Distributed under the terms of the GNU General Public License v2
EAPI=8

PYTHON_COMPAT=( python3_{10..11} )

inherit distutils-r1

# Upstream ref to fetch: the "<version>-dev" branch/tag of NVIDIA/apex
# (derived from the ebuild version, so it follows version bumps).
HASH_COMMIT="${PV}-dev"

DESCRIPTION="NVIDIA-maintained utilities to streamline mixed precision and distributed training in Pytorch"
HOMEPAGE="https://github.com/NVIDIA/apex"
SRC_URI="https://github.com/NVIDIA/apex/archive/${HASH_COMMIT}.tar.gz -> ${P}-gh.tar.gz"
# GitHub archives unpack to <repo>-<ref>, not ${P}.
S="${WORKDIR}/apex-${HASH_COMMIT}"

# Upstream ships a BSD 3-Clause LICENSE file; an empty LICENSE is a QA error.
LICENSE="BSD"
SLOT="0"
KEYWORDS="amd64 ~arm64 x86"
IUSE="cuda"
REQUIRED_USE="${PYTHON_REQUIRED_USE}"

RDEPEND="
	>=dev-python/tqdm-4.28.1[${PYTHON_USEDEP}]
	>=dev-python/numpy-1.15.3[${PYTHON_USEDEP}]
	>=dev-python/pyyaml-5.1[${PYTHON_USEDEP}]
"
DEPEND="${RDEPEND}"
# To cross-compile for a single specific GPU architecture, export
# TORCH_CUDA_ARCH_LIST="<compute capability>" (e.g. "7.5") before running setup.py.
# Build the package for the current Python implementation.
#
# The build always runs. With USE=cuda the --cpp_ext and --cuda_ext switches
# are forwarded to setup.py (presumably enabling upstream's C++/CUDA
# extensions -- confirm against upstream setup.py). Skipping the build when
# the flag is off would leave nothing to install.
python_compile() {
	local -a setup_args=()

	if use cuda; then
		setup_args+=( --cpp_ext --cuda_ext )
	fi

	distutils-r1_python_compile "${setup_args[@]}"
}
# Install the package for the current Python implementation.
#
# The install always runs. With USE=cuda the --cpp_ext and --cuda_ext
# switches are forwarded to setup.py as well.
#
# TORCH_CUDA_ARCH_LIST is exported only when the user supplied a value.
# It must hold real compute capabilities (e.g. "7.5" or "7.0;8.0"); forcing
# a placeholder string would override the user's choice with an invalid
# architecture list. Setting it is what lets the build target an
# architecture without probing a GPU, which matters for cross-compiling and
# for building inside the sandbox.
python_install() {
	local -a setup_args=()

	if use cuda; then
		setup_args+=( --cpp_ext --cuda_ext )
		if [[ -n ${TORCH_CUDA_ARCH_LIST:-} ]]; then
			export TORCH_CUDA_ARCH_LIST
		fi
	fi

	distutils-r1_python_install "${setup_args[@]}"
}
# FIXME: "No module named 'fused_layer_norm_cuda'"
# See https://github.com/NVIDIA/apex/issues/161