Thanush commited on
Commit
21c51e5
·
1 Parent(s): 004c8e7

Update ME_LLAMA_MODEL to use the 13b variant for improved performance

Browse files
Files changed (2) hide show
  1. medbot/config.py +1 -1
  2. medbot/model.py +2 -0
medbot/config.py CHANGED
@@ -1,2 +1,2 @@
1
- ME_LLAMA_MODEL = "clinicalnlplab/me-llama"
2
  FALLBACK_MODEL = "meta-llama/Llama-2-7b-chat-hf"
 
1
+ ME_LLAMA_MODEL = "clinicalnlplab/me-llama-13b"
2
  FALLBACK_MODEL = "meta-llama/Llama-2-7b-chat-hf"
medbot/model.py CHANGED
@@ -1,6 +1,7 @@
1
  import torch
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
  from .config import ME_LLAMA_MODEL, FALLBACK_MODEL
 
4
 
5
  class ModelManager:
6
  def __init__(self):
@@ -11,6 +12,7 @@ class ModelManager:
11
  if self.model is not None and self.tokenizer is not None:
12
  return
13
  try:
 
14
  self.tokenizer = AutoTokenizer.from_pretrained(ME_LLAMA_MODEL, trust_remote_code=True)
15
  self.model = AutoModelForCausalLM.from_pretrained(
16
  ME_LLAMA_MODEL,
 
1
  import torch
2
  from transformers import AutoModelForCausalLM, AutoTokenizer
3
  from .config import ME_LLAMA_MODEL, FALLBACK_MODEL
4
+ import logging
5
 
6
  class ModelManager:
7
  def __init__(self):
 
12
  if self.model is not None and self.tokenizer is not None:
13
  return
14
  try:
15
+ logging.info(f"ME_LLAMA_MODEL type: {type(ME_LLAMA_MODEL)}, value: {ME_LLAMA_MODEL}")
16
  self.tokenizer = AutoTokenizer.from_pretrained(ME_LLAMA_MODEL, trust_remote_code=True)
17
  self.model = AutoModelForCausalLM.from_pretrained(
18
  ME_LLAMA_MODEL,