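"""FastAPI service exposing three local-model endpoints:

- POST /tts: text-to-speech via Parler-TTS
- POST /stt: speech-to-text via Wav2Vec2
- POST /llm: text generation via a local GGUF model through llama-cpp-python
"""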
from fastapi import FastAPI, File, UploadFile, Response
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
from llama_cpp import Llama
import torch
import soundfile as sf
import io
import os
from pydantic import BaseModel

app = FastAPI()

# Load models: prefer a locally fine-tuned copy when present,
# otherwise fall back to the base checkpoint from the Hugging Face Hub.
if os.path.exists("./models/tts_model"):
    tts_model = ParlerTTSForConditionalGeneration.from_pretrained("./models/tts_model")
    tts_tokenizer = AutoTokenizer.from_pretrained("./models/tts_model")
else:
    tts_model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler-tts-mini-v1")
    tts_tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler-tts-mini-v1")

# STT (speech-to-text) model: Wav2Vec2 with CTC decoding
if os.path.exists("./models/stt_model"):
    stt_model = Wav2Vec2ForCTC.from_pretrained("./models/stt_model")
    stt_processor = Wav2Vec2Processor.from_pretrained("./models/stt_model")
else:
    stt_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
    stt_processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")

# Local LLM served through llama-cpp-python; the GGUF weights must be supplied manually.
if os.path.exists("./models/llama.gguf"):
    llm = Llama(model_path="./models/llama.gguf")
else:
    raise FileNotFoundError("Please upload llama.gguf to the models/ directory")

# Pydantic request bodies for the JSON endpoints
class TTSRequest(BaseModel):
    text: str

class LLMRequest(BaseModel):
    prompt: str

@app.post("/tts")
async def tts_endpoint(request: TTSRequest):
    # Parler-TTS conditions generation on a natural-language voice description
    # plus the text prompt; the description here is a generic default voice.
    description = "A clear, neutral voice speaking at a moderate pace."
    input_ids = tts_tokenizer(description, return_tensors="pt").input_ids
    prompt_input_ids = tts_tokenizer(request.text, return_tensors="pt").input_ids
    with torch.no_grad():
        audio = tts_model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
    audio = audio.squeeze().cpu().numpy()
    buffer = io.BytesIO()
    # Write at the model's native sampling rate instead of a hard-coded value.
    sf.write(buffer, audio, tts_model.config.sampling_rate, format="WAV")
    buffer.seek(0)
    return Response(content=buffer.getvalue(), media_type="audio/wav")

@app.post("/sst")
async def sst_endpoint(file: UploadFile = File(...)):
    audio_bytes = await file.read()
    audio, sr = sf.read(io.BytesIO(audio_bytes))
    inputs = sst_processor(audio, sampling_rate=sr, return_tensors="pt")
    with torch.no_grad():
        logits = sst_model(inputs.input_values).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = sst_processor.batch_decode(predicted_ids)[0]
    return {"text": transcription}

@app.post("/llm")
async def llm_endpoint(request: LLMRequest):
    prompt = request.prompt
    output = llm(prompt, max_tokens=50)
    return {"text": output["choices"][0]["text"]}