Tomtom84 commited on
Commit
10540d6
Β·
verified Β·
1 Parent(s): 36afd5c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -1
app.py CHANGED
@@ -3,6 +3,7 @@ import os, json, asyncio, torch
3
  from fastapi import FastAPI, WebSocket, WebSocketDisconnect
4
  from huggingface_hub import login
5
  from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessor
 
6
  from snac import SNAC
7
 
8
  # ── 0.Β HF‑Auth & Device ──────────────────────────────────────────────
@@ -114,7 +115,10 @@ async def tts(ws: WebSocket):
114
  use_cache=True,
115
  return_dict_in_generate=True,
116
  )
117
- past = out.past_key_values
 
 
 
118
  newtok = out.sequences[0,-out.num_generated_tokens:].tolist()
119
 
120
  for t in newtok:
 
3
  from fastapi import FastAPI, WebSocket, WebSocketDisconnect
4
  from huggingface_hub import login
5
  from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessor
6
+ from transformers.generation.utils import Cache
7
  from snac import SNAC
8
 
9
  # ── 0.Β HF‑Auth & Device ──────────────────────────────────────────────
 
115
  use_cache=True,
116
  return_dict_in_generate=True,
117
  )
118
+ pkv = out.past_key_values
119
+ if isinstance(pkv, Cache):
120
+ pkv = pkv.to_legacy()
121
+ past_kvs = pkv
122
  newtok = out.sequences[0,-out.num_generated_tokens:].tolist()
123
 
124
  for t in newtok: