Spaces:
Runtime error
Runtime error
import argparse | |
import torch | |
from .make_model import make_model | |
# Fixed hyperparameters handed to make_model(), kept both as a plain dict and
# as an attribute-style argparse.Namespace built from that same dict.
hparams_dict = dict(
    HF_MODEL_PATH='facebook/wav2vec2-large-xlsr-53',
    DATASET='recanvo',
    MAX_DURATION=4,  # NOTE(review): unit not visible here (seconds?) — confirm
    SAMPLING_RATE=16000,
    OUTPUT_HIDDEN_STATES=True,
    CLASSIFIER_NAME='multilevel',
    CLASSIFIER_PROJ_SIZE=256,
    NUM_LABELS=3,
    # NOTE(review): a single weight while NUM_LABELS is 3 — confirm how
    # make_model consumes LABEL_WEIGHTS.
    LABEL_WEIGHTS=[1.0],
    LOSS='cross-entropy',
    GPU_ID=0,
    RETURN_RAW_ARRAY=False,
)
hparams = argparse.Namespace(**hparams_dict)
def get_behaviour_model(classifier_weights_path, device):
    """Build the behaviour model and restore its trained classifier head.

    The model is constructed with ``make_model(hparams)`` using the
    module-level ``hparams``. Only ``model.classifier`` is loaded from
    ``classifier_weights_path``. All other weights are whatever
    ``make_model`` initializes (presumably the pretrained backbone named by
    ``hparams.HF_MODEL_PATH`` — confirm in make_model).

    Args:
        classifier_weights_path: Path (or file-like object) accepted by
            ``torch.load``. It must hold a state dict whose keys and shapes
            match ``model.classifier``.
        device: Device the checkpoint tensors are mapped to on load and the
            returned model is moved to, e.g. ``"cpu"``, ``"cuda:0"`` or a
            ``torch.device``.

    Returns:
        The model in evaluation mode, placed on ``device``.

    Raises:
        RuntimeError: From ``load_state_dict`` (strict by default) if the
            checkpoint's keys or shapes do not match ``model.classifier``.
    """
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code — only load trusted checkpoints. Since only a tensor state dict is
    # expected here, consider weights_only=True if the installed torch
    # supports it.
    state_dict = torch.load(classifier_weights_path, map_location=device)
    model = make_model(hparams)
    model.classifier.load_state_dict(state_dict)
    # load_state_dict copies values into the existing parameters in place, so
    # it does not change where they live. Move the whole model explicitly so
    # that `device` actually controls placement of backbone and classifier.
    model.to(device)
    model.eval()
    return model