Install the package from PyPI:
pip install jaxcld
Attach a CLD head to a Whisper model and run inference:
# Quick-start example: build an ASR model, attach a language-detection head,
# and run a single prediction.
# NOTE(review): the install command above names the distribution `jaxcld`, while
# the import below uses the module name `cld` — confirm the import name.
import numpy as np
from cld import ASRModel, CVXNNLangDetectHead
# Language codes handed to the model through its config.
languages = ["en", "hi", "id", "ms", "zh"]
# Load the "openai/whisper-small" checkpoint, passing the language list via `config`.
asr = ASRModel.from_pretrained("openai/whisper-small", config={"languages": languages})
# Load a saved head from disk; the ASR model is passed as the second argument.
# The path is a placeholder — point it at a real file.
# NOTE(review): the `.pkl` extension suggests a pickle file — if so, only load
# files from a trusted source.
head = CVXNNLangDetectHead.load("path/to/whisper-small_trained_cvx_mlp.pkl", asr)
# Register the loaded head on the ASR model.
asr.set_lang_detect_head(head)
# Placeholder: replace `...` with a real waveform array before running.
audio_16k_mono: np.ndarray = ... # shape (T,), 16 kHz mono
# `predict` returns two values, unpacked here as predicted languages and texts.
# Both are indexed at [0] below — presumably one entry per input; confirm.
pred_langs, pred_texts = asr.predict(audio_16k_mono)
print(pred_langs[0], pred_texts[0])