This is the fine-tuned version of the EAT-base (Epoch 30, Pre-trained Checkpoint), further trained on the AS-2M dataset. Compared to the pre-trained model, this version provides enhanced audio representations and typically yields better performance in downstream audio understanding tasks such as classification and captioning. Typical applications include audio feature extraction, acoustic scene analysis, and improving speech recognition models. The underlying EAT architecture is designed to be both efficient and self-supervised, which makes the model straightforward to integrate into a variety of machine learning pipelines.
For more details on the EAT framework, please refer to the GitHub repository and our paper EAT: Self-Supervised Pre-Training with Efficient Audio Transformer.
You can load and use the model for feature extraction directly via Hugging Face Transformers:
import torchaudio
import torch
import soundfile as sf
import numpy as np
from transformers import AutoModel
model_id = "worstchan/EAT-base_epoch30_finetune_AS2M"
model = AutoModel.from_pretrained(model_id, trust_remote_code=True).eval().cuda()
source_file = "/path/to/input.wav"
target_file = "/path/to/output.npy"
target_length = 1024 # Recommended: 1024 for 10s audio
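# Global log-Mel normalization statistics (dataset-level mean / std), applied below as (mel - mean) / (2 * std)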
norm_mean = -4.268
norm_std = 4.569
# Load and resample audio
wav, sr = sf.read(source_file)
waveform = torch.tensor(wav).float().cuda()
if sr != 16000:
    waveform = torchaudio.functional.resample(waveform, sr, 16000)
# Normalize and convert to mel-spectrogram
waveform = waveform - waveform.mean()
mel = torchaudio.compliance.kaldi.fbank(
    waveform.unsqueeze(0),
    htk_compat=True,
    sample_frequency=16000,
    use_energy=False,
    window_type='hanning',
    num_mel_bins=128,
    dither=0.0,
    frame_shift=10
).unsqueeze(0)
# Pad or truncate
n_frames = mel.shape[1]
if n_frames < target_length:
    mel = torch.nn.ZeroPad2d((0, 0, 0, target_length - n_frames))(mel)
else:
    mel = mel[:, :target_length, :]
# Normalize
mel = (mel - norm_mean) / (norm_std * 2)
mel = mel.unsqueeze(0).cuda() # shape: [1, 1, T, F]
# Extract features
with torch.no_grad():
    feat = model.extract_features(mel)
feat = feat.squeeze(0).cpu().numpy()
np.save(target_file, feat)
print(f"Feature shape: {feat.shape}")
print(f"Saved to: {target_file}")
The model supports both frame-level (~50Hz) and utterance-level (CLS token) representations. See the feature extraction guide for detailed instructions.
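As a minimal sketch, here is one way the two granularities could be separated from the saved array, assuming the sequence returned by extract_features keeps the CLS token in the first position (this layout is an assumption here; the feature extraction guide documents the exact output):

import numpy as np

feat = np.load("/path/to/output.npy")  # assumed shape: [T + 1, D], CLS token first
utt_feat = feat[0]                     # utterance-level representation (CLS token), shape [D]
frame_feats = feat[1:]                 # frame-level representations (~50 Hz), shape [T, D]
print(utt_feat.shape, frame_feats.shape)

If only frame-level features are available, mean-pooling frame_feats along the time axis is a common way to obtain a single utterance-level embedding.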
If you find this model useful, please consider citing our paper:
@article{chen2024eat,
  title={EAT: Self-supervised pre-training with efficient audio transformer},
  author={Chen, Wenxi and Liang, Yuzhe and Ma, Ziyang and Zheng, Zhisheng and Chen, Xie},
  journal={arXiv preprint arXiv:2401.03497},
  year={2024}
}