File size: 3,340 Bytes
9471255
fa15dd8
5cb7e51
fa15dd8
2685e79
5cb7e51
5ca9307
105a910
fa15dd8
5cb7e51
fa15dd8
5ca9307
 
 
 
5cb7e51
105a910
ba6451b
 
105a910
 
 
 
 
 
 
 
 
 
 
ba6451b
5ca9307
ba6451b
 
 
 
 
 
2685e79
ba6451b
 
 
5cb7e51
5ca9307
fa15dd8
5cb7e51
 
 
fa15dd8
 
 
5ca9307
 
5cb7e51
 
 
5ca9307
5cb7e51
fa15dd8
5ca9307
5cb7e51
 
 
9471255
5cb7e51
5ca9307
f6f6edc
 
5cb7e51
 
 
 
f6f6edc
 
 
5cb7e51
 
f6f6edc
5cb7e51
f6f6edc
5cb7e51
 
105a910
9471255
5cb7e51
 
9471255
5cb7e51
105a910
5cb7e51
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import contextlib
import functools
import os
import tempfile
import uuid
from urllib.parse import urlparse

import gradio as gr
import requests
import torch
import torchaudio
import yt_dlp
from moviepy.editor import VideoFileClip
from speechbrain.pretrained import EncoderClassifier

# Model source handed to EncoderClassifier.from_hparams ("owner/name" hub-style id).
CLASSIFIER = (
    "Jzuluaga/accent-id-commonaccent_xlsr-en-english"
)

def get_default_device():
    """Choose the torch device used for inference.

    Returns:
        ``torch.device("cuda")`` when CUDA is available, otherwise
        ``torch.device("cpu")``.
    """
    if torch.cuda.is_available():
        device_name = "cuda"
    else:
        device_name = "cpu"
    return torch.device(device_name)

def download_video(url, timeout=30):
    """Download a video from YouTube or a direct MP4 URL.

    YouTube links (youtube.com, its subdomains, and youtu.be) are fetched
    with yt_dlp; any other URL is streamed to disk with requests.

    Args:
        url: Public video URL. Surrounding whitespace is ignored.
        timeout: Seconds requests may wait to connect, or between received
            bytes, on a direct download. This is not a limit on total
            download time.

    Returns:
        Path to a uniquely named file in the system temp directory. The
        caller is responsible for deleting it.

    Raises:
        RuntimeError: If the download fails for any reason; the underlying
            exception is chained as ``__cause__``.
    """
    try:
        url = url.strip()
        if _is_youtube_url(url):
            return _download_youtube(url)
        return _download_direct(url, timeout)
    except Exception as e:
        raise RuntimeError(f"Failed to download video: {e}") from e


def _is_youtube_url(url):
    """Return True when the host of *url* is youtube.com, a subdomain, or youtu.be."""
    # Retry with a "//" prefix so scheme-less input like "youtu.be/abc" still parses a host.
    host = urlparse(url).hostname or urlparse("//" + url).hostname or ""
    return host in ("youtube.com", "youtu.be") or host.endswith(".youtube.com")


def _download_youtube(url):
    """Download *url* with yt_dlp into a uniquely named temp file; return its path."""
    # Unique per call so concurrent requests never overwrite or delete each other's files.
    outtmpl = os.path.join(
        tempfile.gettempdir(), f"temp_video_{uuid.uuid4().hex}.%(ext)s"
    )
    ydl_opts = {
        'format': 'best[ext=mp4]/best',
        'outtmpl': outtmpl,
        'quiet': True,
        'noplaylist': True,
    }
    with yt_dlp.YoutubeDL(ydl_opts) as ydl:
        info_dict = ydl.extract_info(url, download=True)
        # Let yt_dlp expand its own template rather than substituting %(ext)s by hand.
        return ydl.prepare_filename(info_dict)


def _download_direct(url, timeout):
    """Stream *url* into a new temp ``.mp4`` file; remove it if the download fails."""
    fd, local_filename = tempfile.mkstemp(prefix="temp_video_", suffix=".mp4")
    try:
        with os.fdopen(fd, "wb") as f, requests.get(url, stream=True, timeout=timeout) as r:
            r.raise_for_status()
            for chunk in r.iter_content(chunk_size=8192):
                f.write(chunk)
    except BaseException:
        # The caller never receives this path on failure, so clean it up here.
        with contextlib.suppress(OSError):
            os.remove(local_filename)
        raise
    return local_filename

def extract_audio(video_path):
    """Extract the audio track of a video into a WAV file.

    Args:
        video_path: Path to a video file readable by moviepy.

    Returns:
        Path to a newly written, uniquely named ``.wav`` file in the system
        temp directory. The caller is responsible for deleting it.

    Raises:
        ValueError: If the video has no audio track.
    """
    # Unique name so concurrent requests cannot overwrite or delete each other's audio.
    fd, audio_path = tempfile.mkstemp(prefix="temp_audio_", suffix=".wav")
    os.close(fd)  # moviepy reopens the path itself when writing
    try:
        clip = VideoFileClip(video_path)
        try:
            if clip.audio is None:
                raise ValueError(f"Video has no audio track: {video_path}")
            clip.audio.write_audiofile(audio_path, logger=None)
        finally:
            # Release the clip's reader even when writing fails.
            clip.close()
    except BaseException:
        # The caller never receives this path on failure, so clean it up here.
        with contextlib.suppress(OSError):
            os.remove(audio_path)
        raise
    return audio_path

@functools.lru_cache(maxsize=None)
def _load_classifier(device):
    """Build the SpeechBrain classifier once per device string and reuse it."""
    # NOTE(review): this repo's model card may load it through a custom interface
    # (foreign_class) rather than EncoderClassifier — confirm from_hparams yields
    # the intended pipeline.
    return EncoderClassifier.from_hparams(
        source=CLASSIFIER,
        savedir="pretrained_models/accent_classifier",
        run_opts={"device": device},
    )


def classify_accent(audio_path, target_sample_rate=16000):
    """Classify the English accent spoken in an audio file.

    The model is loaded on first use per device and cached for later calls.
    The audio is down-mixed to mono and resampled to ``target_sample_rate``
    before inference, because ``classify_batch`` (unlike SpeechBrain's
    ``classify_file``) performs no such normalization itself.

    Args:
        audio_path: Path to an audio file readable by ``torchaudio.load``.
        target_sample_rate: Sample rate (Hz) the model expects.

    Returns:
        ``(predicted_accent, confidence)`` where confidence is formatted as a
        percentage string such as ``"87.31%"``.
    """
    device = get_default_device()
    classifier = _load_classifier(str(device))

    waveform, sample_rate = torchaudio.load(audio_path)  # (channels, frames)
    # Average channels: a multi-channel tensor would otherwise be treated by
    # classify_batch as a batch of independent clips.
    waveform = waveform.mean(dim=0, keepdim=True)
    if sample_rate != target_sample_rate:
        waveform = torchaudio.functional.resample(
            waveform, sample_rate, target_sample_rate
        )

    with torch.no_grad():
        prediction = classifier.classify_batch(waveform.to(device))
    # prediction = (out_prob, score, index, text_lab); batch size is 1 here.
    predicted_accent = prediction[3][0]
    confidence = prediction[1].exp().max().item() * 100
    return predicted_accent, f"{confidence:.2f}%"

def process_video(url):
    """Run the full pipeline for one URL: download, extract audio, classify.

    Args:
        url: Public video URL (YouTube or direct MP4 link) from the UI.

    Returns:
        ``(accent, confidence)`` on success, or ``("Error: <message>", "")``
        on any failure, so the UI always receives two text outputs.
    """
    video_path = None
    audio_path = None
    try:
        video_path = download_video(url)
        audio_path = extract_audio(video_path)
        accent, confidence = classify_accent(audio_path)
        return accent, confidence
    except Exception as e:
        # Top-level UI boundary: report the message instead of crashing the app.
        return f"Error: {e}", ""
    finally:
        for path in (video_path, audio_path):
            if path:
                # Best-effort cleanup: a failed delete (e.g. a still-locked file)
                # must not replace the result already being returned.
                with contextlib.suppress(OSError):
                    os.remove(path)

# Build the input/output components first (same order as they are passed),
# then wire them into a single-function Gradio Interface.
_url_input = gr.Textbox(label="Enter Public Video URL (YouTube or direct MP4 link)")
_accent_output = gr.Textbox(label="Detected Accent")
_confidence_output = gr.Textbox(label="Confidence Score")

iface = gr.Interface(
    fn=process_video,
    inputs=_url_input,
    outputs=[_accent_output, _confidence_output],
    title="English Accent Classifier",
    description=(
        "Paste a public video URL (YouTube or MP4) to detect the English accent "
        "and confidence score."
    ),
)

if __name__ == "__main__":
    # Start the Gradio server only when executed as a script, not on import.
    iface.launch()