"""FunASR transcription service.

Exposes the same speech-recognition pipeline two ways from a single FastAPI
application:

* ``POST /asr`` -- multipart upload, returns the transcription as plain text.
* a Gradio web UI mounted at ``/``.
"""

import contextlib
import functools
import io
import json
import os
import tempfile
import threading
from typing import Any, Optional

import gradio as gr  # Gradio web UI
import torch
import uvicorn
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import Response, StreamingResponse
from funasr import AutoModel
from funasr.utils.postprocess_utils import rich_transcription_postprocess
from modelscope import snapshot_download

from config import model_config

device = "cuda:0" if torch.cuda.is_available() else "cpu"
model_dir = snapshot_download(model_config['model_dir'])

# Serialises model construction and inference: the Gradio handler and the
# /asr endpoint run on worker threads and may share a cached model instance.
_inference_lock = threading.Lock()


def _parse_vad_kwargs(raw: str) -> str:
    """Validate a JSON object string and return it in canonical form.

    The canonical (key-sorted) string is hashable, so it can be used as part
    of the model cache key.

    Raises:
        ValueError: if *raw* is not valid JSON (``json.JSONDecodeError`` is a
            ``ValueError``) or does not decode to a JSON object.
    """
    parsed = json.loads(raw)
    if not isinstance(parsed, dict):
        raise ValueError("vad_kwargs must be a JSON object")
    return json.dumps(parsed, sort_keys=True)


# Bounded cache: callers can request arbitrary configurations, and every
# AutoModel instance holds model weights, so only a couple are kept alive.
@functools.lru_cache(maxsize=2)
def _load_model(vad_model: str, vad_kwargs_json: str, ncpu: int, batch_size: int):
    """Build (or reuse) an ``AutoModel`` for the given configuration.

    Always call this positionally so equal configurations share a cache slot.
    """
    return AutoModel(
        model=model_dir,
        trust_remote_code=False,
        remote_code="./model.py",
        vad_model=vad_model,
        vad_kwargs=json.loads(vad_kwargs_json),
        ncpu=ncpu,
        batch_size=batch_size,
        hub="ms",
        device=device,
    )


# Initialise the default model at import time. This also seeds the cache, so
# requests that use the default settings never construct a second model.
model = _load_model(
    "fsmn-vad",
    _parse_vad_kwargs('{"max_single_segment_time": 30000}'),
    4,
    1,
)


def _run_transcription(
    audio_path: Optional[str],
    vad_model: str,
    vad_kwargs_json: str,
    ncpu: Any,
    batch_size: Any,
    **generate_kwargs: Any,
) -> str:
    """Transcribe the audio file at *audio_path* and return post-processed text.

    ``ncpu`` and ``batch_size`` are coerced to ``int`` because UI sliders may
    deliver floats. ``generate_kwargs`` are forwarded unchanged to
    ``model.generate``.
    """
    with _inference_lock:
        asr_model = _load_model(vad_model, vad_kwargs_json, int(ncpu), int(batch_size))
        res = asr_model.generate(
            input=audio_path,  # a file path is used as the input
            cache={},
            **generate_kwargs,
        )
    if not res:
        # No result entries: report an empty transcription instead of
        # failing with an IndexError on res[0].
        return ""
    return rich_transcription_postprocess(res[0]["text"])


def transcribe_audio(file_path, vad_model="fsmn-vad",
                     vad_kwargs='{"max_single_segment_time": 30000}',
                     ncpu=4, batch_size=1, language="auto", use_itn=True,
                     batch_size_s=60, merge_vad=True, merge_length_s=15,
                     batch_size_threshold_s=50, hotword=" ", spk_model="cam++",
                     ban_emo_unk=False):
    """Gradio handler: transcribe an audio file and return the text.

    Args:
        file_path: Path of the audio file supplied by the ``gr.Audio`` input.
        vad_model, vad_kwargs, ncpu, batch_size: Model configuration;
            ``vad_kwargs`` is a JSON object encoded as a string.
        language, use_itn, batch_size_s, merge_vad, merge_length_s,
        batch_size_threshold_s, hotword, spk_model, ban_emo_unk: Forwarded to
            ``model.generate``.

    Returns:
        The transcription, or the error message as a string if anything
        fails, so the failure is shown in the output textbox.
    """
    try:
        return _run_transcription(
            file_path,
            vad_model,
            _parse_vad_kwargs(vad_kwargs),  # convert the string into a dict spec
            ncpu,
            batch_size,
            language=language,
            use_itn=use_itn,
            batch_size_s=batch_size_s,
            merge_vad=merge_vad,
            merge_length_s=merge_length_s,
            batch_size_threshold_s=batch_size_threshold_s,
            hotword=hotword,
            spk_model=spk_model,
            ban_emo_unk=ban_emo_unk,
        )
    except Exception as e:
        # Deliberate best-effort: catch the error and return its message.
        return str(e)


# Build the Gradio interface.
inputs = [
    gr.Audio(type="filepath"),  # 'filepath' hands the handler a file path
    gr.Textbox(value="fsmn-vad", label="VAD Model"),
    gr.Textbox(value='{"max_single_segment_time": 30000}', label="VAD Kwargs"),
    gr.Slider(1, 8, value=4, step=1, label="NCPU"),
    gr.Slider(1, 10, value=1, step=1, label="Batch Size"),
    gr.Textbox(value="auto", label="Language"),
    gr.Checkbox(value=True, label="Use ITN"),
    gr.Slider(30, 120, value=60, step=1, label="Batch Size (seconds)"),
    gr.Checkbox(value=True, label="Merge VAD"),
    gr.Slider(5, 60, value=15, step=1, label="Merge Length (seconds)"),
    gr.Slider(10, 100, value=50, step=1, label="Batch Size Threshold (seconds)"),
    gr.Textbox(value=" ", label="Hotword"),
    gr.Textbox(value="cam++", label="Speaker Model"),
    gr.Checkbox(value=False, label="Ban Emotional Unknown"),
]

outputs = gr.Textbox(label="Transcription")

# NOTE: the interface is intentionally NOT launched here. ``launch()`` blocks
# the main thread and binds port 7860, which kept the FastAPI app below from
# ever being served. It is mounted on the FastAPI app instead.
demo = gr.Interface(
    fn=transcribe_audio,
    inputs=inputs,
    outputs=outputs,
    title="ASR Transcription with FunASR",
)


class SynthesizeResponse(Response):
    """Plain-text response class advertised for the ``/asr`` endpoint."""

    media_type = 'text/plain'


app = FastAPI()


@app.post('/asr', response_class=SynthesizeResponse)
async def generate(
    file: UploadFile = File(...),
    vad_model: str = Form("fsmn-vad"),
    vad_kwargs: str = Form('{"max_single_segment_time": 30000}'),
    ncpu: int = Form(4),
    batch_size: int = Form(1),
    language: str = Form("auto"),
    use_itn: bool = Form(True),
    batch_size_s: int = Form(60),
    merge_vad: bool = Form(True),
    merge_length_s: int = Form(15),
    batch_size_threshold_s: int = Form(50),
    hotword: Optional[str] = Form(" "),
    spk_model: str = Form("cam++"),
    ban_emo_unk: bool = Form(False),
) -> StreamingResponse:
    """Transcribe an uploaded audio file and return the text as ``text/plain``.

    The upload is written to a temporary ``.wav`` file, transcribed on a
    worker thread (so the event loop stays responsive), and the temporary
    file is always removed afterwards.

    Raises:
        HTTPException: 400 if ``vad_kwargs`` is not a valid JSON object;
            500 if reading the upload or transcription fails.
    """
    try:
        # Convert the string into a validated, canonical JSON object.
        vad_kwargs_json = _parse_vad_kwargs(vad_kwargs)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=f"Invalid vad_kwargs: {e}") from e

    temp_file_path = None
    try:
        # Read first so a failed upload never leaves a temp file behind.
        input_wav_bytes = await file.read()

        # Create a temporary file and save the uploaded audio into it.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
            temp_file_path = temp_file.name
            temp_file.write(input_wav_bytes)

        # Generate and post-process the result off the event loop.
        text = await run_in_threadpool(
            _run_transcription,
            temp_file_path,  # the temporary file path is used as the input
            vad_model,
            vad_kwargs_json,
            ncpu,
            batch_size,
            language=language,
            use_itn=use_itn,
            batch_size_s=batch_size_s,
            merge_vad=merge_vad,
            merge_length_s=merge_length_s,
            batch_size_threshold_s=batch_size_threshold_s,
            hotword=hotword,
            spk_model=spk_model,
            ban_emo_unk=ban_emo_unk,
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Make sure the temporary file is deleted once processing is done.
        if temp_file_path is not None:
            with contextlib.suppress(FileNotFoundError):
                os.remove(temp_file_path)

    # Return the result.
    return StreamingResponse(io.BytesIO(text.encode('utf-8')), media_type="text/plain")


@app.get("/root")
async def read_root():
    """Simple liveness endpoint."""
    return {"message": "Hello World"}


# Mount the UI last so the API routes registered above are matched first.
# NOTE(review): the UI is now reachable on whatever host uvicorn binds
# (0.0.0.0 below) rather than Gradio's own default host -- confirm that this
# exposure is intended.
app = gr.mount_gradio_app(app, demo, path="/")


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)