avi292423 committed on
Commit
abf9bcf
·
verified ·
1 Parent(s): 8b22f56

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1 -6
app.py CHANGED
@@ -1,16 +1,13 @@
1
  from fastapi import FastAPI, File, UploadFile
2
- from fastapi.responses import JSONResponse
3
  from fastapi.middleware.cors import CORSMiddleware
4
  from models.model_wav2vec import Wav2VecIntent
5
  from huggingface_hub import hf_hub_download
6
  import torch
7
  import soundfile as sf
8
- import numpy as np
9
  import librosa
10
 
11
  app = FastAPI()
12
 
13
- # Enable CORS for all origins (so your frontend can call the API)
14
  app.add_middleware(
15
  CORSMiddleware,
16
  allow_origins=["*"],
@@ -19,7 +16,6 @@ app.add_middleware(
19
  allow_headers=["*"],
20
  )
21
 
22
- # Download model from Hugging Face
23
  MODEL_PATH = hf_hub_download(repo_id="avi292423/speech-intent-recognition-project", filename="wav2vec_best_model.pt")
24
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
25
 
@@ -36,7 +32,7 @@ label_map = {
36
  index_to_label = {v: k for k, v in label_map.items()}
37
 
38
  num_classes = 31
39
- pretrained_model = "facebook/wav2vec2-large" # Use large model
40
  model = Wav2VecIntent(num_classes=num_classes, pretrained_model=pretrained_model).to(device)
41
  state_dict = torch.load(MODEL_PATH, map_location=device)
42
  model.load_state_dict(state_dict)
@@ -49,7 +45,6 @@ async def predict(file: UploadFile = File(...)):
49
  f.write(audio_bytes)
50
  audio, sample_rate = sf.read("temp.wav")
51
  if sample_rate != 16000:
52
- # Resample to 16kHz
53
  audio = librosa.resample(audio.astype(float), orig_sr=sample_rate, target_sr=16000)
54
  waveform = torch.tensor(audio, dtype=torch.float32).unsqueeze(0).to(device)
55
  with torch.no_grad():
 
1
  from fastapi import FastAPI, File, UploadFile
 
2
  from fastapi.middleware.cors import CORSMiddleware
3
  from models.model_wav2vec import Wav2VecIntent
4
  from huggingface_hub import hf_hub_download
5
  import torch
6
  import soundfile as sf
 
7
  import librosa
8
 
9
  app = FastAPI()
10
 
 
11
  app.add_middleware(
12
  CORSMiddleware,
13
  allow_origins=["*"],
 
16
  allow_headers=["*"],
17
  )
18
 
 
19
  MODEL_PATH = hf_hub_download(repo_id="avi292423/speech-intent-recognition-project", filename="wav2vec_best_model.pt")
20
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21
 
 
32
  index_to_label = {v: k for k, v in label_map.items()}
33
 
34
  num_classes = 31
35
+ pretrained_model = "facebook/wav2vec2-large"
36
  model = Wav2VecIntent(num_classes=num_classes, pretrained_model=pretrained_model).to(device)
37
  state_dict = torch.load(MODEL_PATH, map_location=device)
38
  model.load_state_dict(state_dict)
 
45
  f.write(audio_bytes)
46
  audio, sample_rate = sf.read("temp.wav")
47
  if sample_rate != 16000:
 
48
  audio = librosa.resample(audio.astype(float), orig_sr=sample_rate, target_sr=16000)
49
  waveform = torch.tensor(audio, dtype=torch.float32).unsqueeze(0).to(device)
50
  with torch.no_grad():