import warnings

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from transformers.models.mistral.modeling_mistral import MistralForCausalLM
from transformers.models.llama.tokenization_llama_fast import LlamaTokenizerFast

model_name = "mistralai/Mistral-7B-Instruct-v0.2"

# Load the weights in 4-bit so the 7B model fits in consumer GPU memory.
quantization_config = BitsAndBytesConfig(load_in_4bit=True)

model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=quantization_config)
tokenizer = AutoTokenizer.from_pretrained(model_name)
from langchain.llms.base import LLM
from langchain.callbacks.manager import CallbackManagerForLLMRun
from typing import Optional, List, Mapping, Any
class CustomLLMMistral(LLM):
    model: MistralForCausalLM
    tokenizer: LlamaTokenizerFast

    @property
    def _llm_type(self) -> str:
        return "custom"

    def _call(self, prompt: str, stop: Optional[List[str]] = None,
              run_manager: Optional[CallbackManagerForLLMRun] = None) -> str:

        messages = [
            {"role": "user", "content": prompt},
        ]

        encodeds = self.tokenizer.apply_chat_template(messages, return_tensors="pt")
        model_inputs = encodeds.to(self.model.device)

        generated_ids = self.model.generate(model_inputs, max_new_tokens=512, do_sample=True,
                                            pad_token_id=self.tokenizer.eos_token_id, top_k=4, temperature=0.7)
        decoded = self.tokenizer.batch_decode(generated_ids)

        # Keep only the model's answer (everything after the [/INST] tag) and drop the EOS marker.
        output = decoded[0].split("[/INST]")[1].replace("</s>", "").strip()

        if stop is not None:
            for word in stop:
                output = output.split(word)[0].strip()

        # Mistral sometimes stops before closing a code fence; pad the output so downstream
        # parsers always see a closed ``` block.
        while not output.endswith("```"):
            output += "`"

        return output

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        return {"model": self.model}
llm = CustomLLMMistral(model=model, tokenizer=tokenizer)
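# Optional sanity check (not part of the original listing): the wrapper exposes the standard
# LangChain LLM interface, so it can be queried directly before any tools are wired in.
print(llm.invoke("Briefly explain what LangChain is."))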
import numexpr as ne
from langchain.tools import WikipediaQueryRun, BaseTool, Tool
from langchain_community.utilities import WikipediaAPIWrapper

wikipedia = WikipediaQueryRun(api_wrapper=WikipediaAPIWrapper(top_k_results=1, doc_content_chars_max=2500))
print(wikipedia.run("Deep Learning"))
wikipedia_tool = Tool(
    name="wikipedia",
    description="Never search for more than one concept at a single step. If you need to compare two concepts, search for each one individually. Syntax: string with a simple concept",
    func=wikipedia.run
)
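# Quick check (illustrative, not in the original): the Tool simply forwards its input to
# wikipedia.run, so it should return the same kind of summary as the direct call above.
print(wikipedia_tool.run("Machine learning"))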
class Calculator(BaseTool):
    name: str = "calculator"
    description: str = "Use this tool for math operations. It requires numexpr syntax. Always use it when you need to solve any math operation. Be sure the syntax is correct."

    def _run(self, expression: str):
        try:
            return ne.evaluate(expression).item()
        except Exception:
            return "This is not valid numexpr syntax. Try a different syntax."

    def _arun(self, expression: str):
        raise NotImplementedError("This tool does not support async")
calculator_tool = Calculator()
calculator_tool.run("2+3")
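# A slightly larger illustrative expression (not in the original), to show the numexpr
# syntax the tool description asks the model to produce.
print(calculator_tool.run("sqrt(144) + 2**3"))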