codewithdark committed on
Commit 5ec840c · verified · 1 parent: bee08bc

Update app.py

Files changed (1): app.py (+3 -4)
app.py CHANGED

@@ -1,6 +1,6 @@
 import gradio as gr
 from huggingface_hub import InferenceClient
-from transformers import AutoModel,AutoConfig, AutoTokenizer
+from transformers import AutoModel, AutoTokenizer
 import torch
 
 # Initialize Hugging Face Inference API client
@@ -9,8 +9,7 @@ hf_client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
 # Load the second model
 local_model_name = "codewithdark/latent-recurrent-depth-lm"
 tokenizer = AutoTokenizer.from_pretrained(local_model_name)
-config = AutoConfig.from_pretrained(local_model_name)
-model = AutoModel.from_pretrained(local_model_name, config=config, trust_remote_code=True)
+model = AutoModel.from_pretrained(local_model_name, trust_remote_code=True)
 device = "cuda" if torch.cuda.is_available() else "cpu"
 model.to(device)
 
@@ -40,7 +39,7 @@ def generate_response(
             response += token
             yield response
     else:
-        input_text = tokenizer.apply_chat_template(messages, return_tensors="pt").to(device)
+        input_text = tokenizer(messages, return_tensors="pt").to(device)
         output = model.generate(input_text, max_length=max_tokens, temperature=temperature, top_p=top_p)
         response = tokenizer.decode(output[0], skip_special_tokens=True)
         yield response
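
A note on the load-path change in the first two hunks: `AutoModel.from_pretrained` fetches and parses the repo's `config.json` on its own, so the explicit `AutoConfig` round-trip this commit removes was redundant. A minimal sketch of the equivalent load, assuming nothing beyond what the diff shows:

# Sketch of the simplified load path; from_pretrained resolves config.json
# internally, so no separate AutoConfig.from_pretrained(...) call is needed.
from transformers import AutoModel, AutoTokenizer

local_model_name = "codewithdark/latent-recurrent-depth-lm"
tokenizer = AutoTokenizer.from_pretrained(local_model_name)
model = AutoModel.from_pretrained(
    local_model_name,
    trust_remote_code=True,  # the repo ships a custom model class
)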
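
One caveat on the last hunk, left as a review note: `tokenizer(messages, return_tensors="pt")` returns a `BatchEncoding` (a dict of tensors), not the bare ids tensor that `model.generate(input_text, ...)` receives positionally, and in a Gradio chat handler `messages` is typically a list of role/content dicts rather than plain text, which the tokenizer cannot encode directly. A hedged sketch of a variant that would run under those assumptions; the prompt-flattening line is hypothetical, not part of the commit, and it presumes the remote model class implements `.generate()`:

# Sketch only, not the commit's code. Assumes `messages` is a list of
# {"role": ..., "content": ...} dicts and the custom model supports .generate().
prompt = "\n".join(f"{m['role']}: {m['content']}" for m in messages)
inputs = tokenizer(prompt, return_tensors="pt").to(device)  # a BatchEncoding
output = model.generate(
    inputs["input_ids"],  # pass the ids tensor, not the whole dict
    max_length=max_tokens,
    temperature=temperature,
    top_p=top_p,
)
response = tokenizer.decode(output[0], skip_special_tokens=True)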