File size: 4,125 Bytes
a616a2e
dd4c06b
4b18df1
bab5632
8607936
0bbcfe0
dd4c06b
7f92c21
 
 
bca5261
7f92c21
 
a4e4751
 
7f92c21
d1c3a70
a4e4751
a28c209
8607936
52cfee9
7f92c21
8607936
 
 
7f92c21
8607936
 
 
 
 
 
 
 
 
 
 
 
bca5261
7f92c21
 
e648c2d
 
 
a28c209
e648c2d
 
5576fae
8e2fde3
5576fae
 
7f92c21
bab5632
 
 
 
 
 
 
 
 
e648c2d
 
7f92c21
 
a28c209
7f92c21
 
 
 
 
 
 
 
 
 
6e40332
0bbcfe0
 
7f92c21
 
 
 
0bbcfe0
6e40332
 
a28c209
be20f8e
7f92c21
e828a9f
7f92c21
 
6e40332
 
 
 
 
8607936
7f92c21
 
 
 
6e40332
a616a2e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import spaces
import gradio as gr
import os
import orjson
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline, AutoModelForCausalLM, AutoTokenizer

# Module-level model slots shared by the Gradio callbacks below.
# Both start out empty; load_models() assigns them, and transcribe_audio()
# and proofread() read them.
transcribe_model = proofread_model = None

@spaces.GPU(duration=60)
def transcribe_audio(audio):
    """Transcribe an audio file with the currently loaded speech model.

    Args:
        audio: Path to the audio file (the Gradio Audio component in this
            file is configured with ``type="filepath"``), or ``None`` when
            nothing has been uploaded.

    Returns:
        The transcribed text, or a short user-facing message when no audio
        or no model is available.
    """
    if audio is None:
        return "Please upload an audio file."
    if transcribe_model is None:
        return "Please select a model."

    # load_models() stores a loaded model object in the global, but
    # AutoProcessor.from_pretrained() needs a checkpoint id/path. Accept
    # either form and recover the id from the model when necessary.
    checkpoint = (
        transcribe_model
        if isinstance(transcribe_model, str)
        else getattr(transcribe_model, "name_or_path", None)
    )
    if not checkpoint:
        return "Please select a model."

    device = "cuda:0" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
    # Half precision only on CUDA; other backends stay in float32.
    torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

    processor = AutoProcessor.from_pretrained(checkpoint)

    pipe = pipeline(
        "automatic-speech-recognition",
        model=transcribe_model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        max_new_tokens=128,
        chunk_length_s=25,
        batch_size=16,
        torch_dtype=torch_dtype,
        device=device,
    )

    result = pipe(audio)
    return result["text"]

@spaces.GPU(duration=120)
def proofread(text):
    """Rewrite transcribed text with the loaded language model.

    The model is prompted (in Traditional Chinese) to tidy up the text and
    append a summary of its key points.

    Args:
        text: The transcribed text to proofread. ``None``, empty and
            whitespace-only values are rejected.

    Returns:
        The assistant's reply, or a short user-facing message when there is
        no text or no model to work with.
    """
    # A Gradio Textbox delivers "" (not None) when empty, and this callback
    # is also wired to the textbox's change event, so reject blank input too.
    if text is None or not text.strip():
        return "Please provide the transcribed text for proofreading."
    if proofread_model is None:
        return "Please select a model."

    device = "cuda:0" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
    torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

    # pipeline() cannot infer a tokenizer from an already-instantiated model,
    # so load it explicitly from the model's checkpoint id. A plain string id
    # is accepted as well.
    checkpoint = (
        proofread_model
        if isinstance(proofread_model, str)
        else getattr(proofread_model, "name_or_path", None)
    )
    if not checkpoint:
        return "Please select a model."
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)

    messages = [
        {"role": "system", "content": "用繁體中文語體文整理這段文字,在最後加上整段文字的重點。"},
        {"role": "user", "content": text},
    ]
    pipe = pipeline(
        "text-generation",
        model=proofread_model,
        tokenizer=tokenizer,
        torch_dtype=torch_dtype,
        device=device,
    )
    llm_output = pipe(messages)

    # For chat-style input, 'generated_text' is the message list with the
    # assistant's reply appended.
    generated_text = llm_output[0]['generated_text']

    # Take the first assistant message; fall back to an empty string rather
    # than letting StopIteration escape if none is present.
    return next(
        (item['content'] for item in generated_text if item['role'] == 'assistant'),
        "",
    )

def load_models(transcribe_model_id, proofread_model_id):
    """Load both models and publish them through the module-level globals.

    Args:
        transcribe_model_id: Checkpoint id/path of the speech-to-text model.
        proofread_model_id: Checkpoint id/path of the causal language model
            used for proofreading.

    Side effects:
        Rebinds the globals ``transcribe_model`` and ``proofread_model`` to
        the loaded models, each moved to the selected device.
    """
    global transcribe_model, proofread_model
    device = "cuda:0" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
    # Half precision only on CUDA; other backends stay in float32.
    torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

    transcribe_model = AutoModelForSpeechSeq2Seq.from_pretrained(
        transcribe_model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
    )
    transcribe_model.to(device)

    # Load the language model with the same dtype policy as the speech model;
    # previously it was always loaded in the default precision, even on CUDA.
    proofread_model = AutoModelForCausalLM.from_pretrained(
        proofread_model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True
    )
    proofread_model.to(device)

# Checkpoints offered in the two model dropdowns.
_TRANSCRIPTION_CHOICES = ["openai/whisper-large-v2", "alvanlii/whisper-small-cantonese"]
_PROOFREADING_CHOICES = ["hfl/llama-3-chinese-8b-instruct-v3"]

# Build the UI: model selection row, audio upload, then the transcription and
# proofreading controls, in that top-to-bottom order.
with gr.Blocks() as demo:
    gr.Markdown("""
                # Audio Transcription and Proofreading
                1. Select models for transcription and proofreading
                2. Upload an audio file (Wait for the file to be fully loaded first)
                3. Transcribe the audio
                4. Proofread the transcribed text
                """)

    with gr.Row():
        transcribe_model_dropdown = gr.Dropdown(
            choices=_TRANSCRIPTION_CHOICES,
            value="alvanlii/whisper-small-cantonese",
            label="Select Transcription Model",
        )
        proofread_model_dropdown = gr.Dropdown(
            choices=_PROOFREADING_CHOICES,
            value="hfl/llama-3-chinese-8b-instruct-v3",
            label="Select Proofreading Model",
        )
        load_button = gr.Button("Load Models")

    audio = gr.Audio(sources="upload", type="filepath")

    transcribe_button = gr.Button("Transcribe")
    transcribed_text = gr.Textbox(label="Transcribed Text")

    proofread_button = gr.Button("Proofread")
    proofread_output = gr.Textbox(label="Proofread Text")

    # Event wiring. Proofreading runs both on the button press and whenever
    # the transcribed text changes, so both events share the same endpoints.
    proofread_io = {"inputs": transcribed_text, "outputs": proofread_output}

    load_button.click(
        load_models,
        inputs=[transcribe_model_dropdown, proofread_model_dropdown],
    )
    transcribe_button.click(transcribe_audio, inputs=audio, outputs=transcribed_text)
    proofread_button.click(proofread, **proofread_io)
    transcribed_text.change(proofread, **proofread_io)

demo.launch()