from fastapi import FastAPI, File, UploadFile, Response
# ParlerTTSForConditionalGeneration ships in the parler_tts package, not transformers
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer, Wav2Vec2ForCTC, Wav2Vec2Processor
from llama_cpp import Llama
import torch
import soundfile as sf
import io
import os
from pydantic import BaseModel
app = FastAPI()
# Load the TTS model, preferring a local fine-tuned copy when one exists
if os.path.exists("./models/tts_model"):
    tts_model = ParlerTTSForConditionalGeneration.from_pretrained("./models/tts_model")
    tts_tokenizer = AutoTokenizer.from_pretrained("./models/tts_model")
else:
    tts_model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler-tts-mini-v1")
    tts_tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler-tts-mini-v1")
# Speech-to-text (Wav2Vec2), with the same local-first fallback
if os.path.exists("./models/sst_model"):
    sst_model = Wav2Vec2ForCTC.from_pretrained("./models/sst_model")
    sst_processor = Wav2Vec2Processor.from_pretrained("./models/sst_model")
else:
    sst_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
    sst_processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
if os.path.exists("./models/llama.gguf"):
llm = Llama("./models/llama.gguf")
else:
raise FileNotFoundError("Please upload llama.gguf to models/ directory")
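# The models run on CPU by default; to use a GPU, move them explicitly, e.g.:
#   device = "cuda" if torch.cuda.is_available() else "cpu"
#   tts_model.to(device); sst_model.to(device)
# (llama.cpp offloads via Llama(..., n_gpu_layers=...) when built with GPU support)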
# Request bodies for the JSON endpoints
class TTSRequest(BaseModel):
    text: str

class LLMRequest(BaseModel):
    prompt: str
@app.post("/tts")
async def tts_endpoint(request: TTSRequest):
text = request.text
inputs = tts_tokenizer(text, return_tensors="pt")
with torch.no_grad():
audio = tts_model.generate(**inputs)
audio = audio.squeeze().cpu().numpy()
buffer = io.BytesIO()
sf.write(buffer, audio, 22050, format="WAV")
buffer.seek(0)
return Response(content=buffer.getvalue(), media_type="audio/wav")
@app.post("/sst")
async def sst_endpoint(file: UploadFile = File(...)):
audio_bytes = await file.read()
audio, sr = sf.read(io.BytesIO(audio_bytes))
inputs = sst_processor(audio, sampling_rate=sr, return_tensors="pt")
with torch.no_grad():
logits = sst_model(inputs.input_values).logits
predicted_ids = torch.argmax(logits, dim=-1)
transcription = sst_processor.batch_decode(predicted_ids)[0]
return {"text": transcription}
@app.post("/llm")
async def llm_endpoint(request: LLMRequest):
prompt = request.prompt
output = llm(prompt, max_tokens=50)
return {"text": output["choices"][0]["text"]} |