# from langchain import HuggingFaceHub, LLMChain
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.chat_models import ChatOpenAI
from langchain.llms import HuggingFacePipeline, HuggingFaceTextGenInference
from langchain_groq import ChatGroq
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    LlamaForCausalLM,
    LlamaTokenizer,
    pipeline,
)


def get_llm_hf_online(inference_api_url=""):
    """Get an LLM served through the Hugging Face inference API."""
    if not inference_api_url:  # default API URL
        inference_api_url = (
            "https://api-inference.huggingface.co/models/HuggingFaceH4/zephyr-7b-beta"
        )
    llm = HuggingFaceTextGenInference(
        verbose=True,  # provide detailed logs of operations
        max_new_tokens=1024,  # maximum number of tokens that can be generated
        top_p=0.95,  # nucleus-sampling threshold controlling randomness
        temperature=0.1,  # low temperature for near-deterministic output
        inference_server_url=inference_api_url,
        timeout=10,  # timeout (seconds) for the connection to the server
    )
    return llm
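
# Usage sketch (illustrative, not part of the original module). Depending on
# the endpoint, an HF API token may need to be set in the environment:
#
#   llm = get_llm_hf_online()
#   print(llm.invoke("Explain retrieval-augmented generation in one sentence."))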


def get_llm_hf_local(model_path):
    """Get a local LLM loaded from model weights on disk."""
    model = LlamaForCausalLM.from_pretrained(
        model_path, device_map="auto"  # spread layers across available devices
    )
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    # Note: max_length is effectively deprecated in favor of max_new_tokens.
    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=1024,  # better setting?
        model_kwargs={"temperature": 0.1},  # better setting?
    )
    llm = HuggingFacePipeline(pipeline=pipe)
    return llm
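
# Usage sketch (the model path below is a hypothetical example):
#
#   llm = get_llm_hf_local("meta-llama/Llama-2-7b-chat-hf")
#   print(llm.invoke("What is a vector database?"))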


def get_llm_openai_chat(model_name, inference_server_url):
    """Get an LLM from an OpenAI-compatible chat server."""
    llm = ChatOpenAI(
        model=model_name,
        openai_api_key="EMPTY",  # placeholder; local servers usually ignore the key
        openai_api_base=inference_server_url,
        max_tokens=1024,  # better setting?
        temperature=0,  # deterministic output
    )
    return llm
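
# Usage sketch (the URL is a hypothetical example): this works against any
# OpenAI-compatible endpoint, e.g. a local vLLM or TGI server:
#
#   llm = get_llm_openai_chat("my-model", "http://localhost:8000/v1")
#   print(llm.invoke("Hello!").content)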


def get_groq_chat(model_name="llama-3.1-70b-versatile"):
    """Get a chat LLM served by the Groq API."""
    llm = ChatGroq(temperature=0, model_name=model_name)
    return llm
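
# Minimal smoke test (an illustrative addition; assumes GROQ_API_KEY is set in
# the environment, which ChatGroq reads automatically):
if __name__ == "__main__":
    llm = get_groq_chat()
    print(llm.invoke("Say hello in five words.").content)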