Triomphanrt committed
Commit ddadf05 · verified · 1 Parent(s): e065cc1

Delete mistral_model.py

Files changed (1)
  1. mistral_model.py +0 -85
mistral_model.py DELETED
@@ -1,85 +0,0 @@
- import warnings
- import torch
- from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
- from transformers.models.mistral.modeling_mistral import MistralForCausalLM
- from transformers.models.llama.tokenization_llama_fast import LlamaTokenizerFast
-
- model_name = "mistralai/Mistral-7B-Instruct-v0.2"
-
- # Load the model in 4-bit; the config must be passed to from_pretrained,
- # otherwise it is silently ignored.
- quantization_config = BitsAndBytesConfig(load_in_4bit=True)
- model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=quantization_config)
- tokenizer = AutoTokenizer.from_pretrained(model_name)
-
- from langchain.llms.base import LLM
- from langchain.callbacks.manager import CallbackManagerForLLMRun
- from typing import Optional, List, Mapping, Any
-
- class CustomLLMMistral(LLM):
-     model: MistralForCausalLM
-     tokenizer: LlamaTokenizerFast
-
-     @property
-     def _llm_type(self) -> str:
-         return "custom"
-
-     def _call(self, prompt: str, stop: Optional[List[str]] = None,
-               run_manager: Optional[CallbackManagerForLLMRun] = None) -> str:
-
-         messages = [
-             {"role": "user", "content": prompt},
-         ]
-
-         # Wrap the prompt in Mistral's [INST] ... [/INST] chat template.
-         encodeds = self.tokenizer.apply_chat_template(messages, return_tensors="pt")
-         model_inputs = encodeds.to(self.model.device)
-
-         generated_ids = self.model.generate(model_inputs, max_new_tokens=512, do_sample=True,
-                                             pad_token_id=self.tokenizer.eos_token_id, top_k=4, temperature=0.7)
-         decoded = self.tokenizer.batch_decode(generated_ids)
-
-         # Keep only the completion: everything after [/INST], minus the EOS token.
-         output = decoded[0].split("[/INST]")[1].replace("</s>", "").strip()
-
-         if stop is not None:
-             for word in stop:
-                 output = output.split(word)[0].strip()
-
-         # Pad with backticks until the closing ``` fence is complete, so the
-         # agent's output parser always sees a well-formed fenced block.
-         while not output.endswith("```"):
-             output += "`"
-
-         return output
-
-     @property
-     def _identifying_params(self) -> Mapping[str, Any]:
-         return {"model": self.model}
-
- llm = CustomLLMMistral(model=model, tokenizer=tokenizer)
-
- import numexpr as ne
- from langchain.tools import WikipediaQueryRun, BaseTool, Tool
- from langchain_community.utilities import WikipediaAPIWrapper
-
- wikipedia = WikipediaQueryRun(api_wrapper=WikipediaAPIWrapper(top_k_results=1, doc_content_chars_max=2500))
-
- print(wikipedia.run("Deep Learning"))
-
-
- wikipedia_tool = Tool(
-     name="wikipedia",
-     description="Never search for more than one concept in a single step. If you need to compare two concepts, search for each one individually. Syntax: a string with a single, simple concept",
-     func=wikipedia.run
- )
-
- class Calculator(BaseTool):
-     # Pydantic fields on BaseTool subclasses require type annotations.
-     name: str = "calculator"
-     description: str = "Use this tool for math operations. It requires numexpr syntax. Always use it when you need to solve a math operation. Be sure the syntax is correct."
-
-     def _run(self, expression: str):
-         try:
-             return ne.evaluate(expression).item()
-         except Exception:
-             return "This is not valid numexpr syntax. Try a different syntax."
-
-     def _arun(self, expression: str):
-         raise NotImplementedError("This tool does not support async")
-
- calculator_tool = Calculator()
-
- calculator_tool.run("2+3")
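The deleted file stops after smoke-testing the calculator tool: it defines llm, wikipedia_tool, and calculator_tool but never assembles them. For reference, a minimal sketch of how these pieces would typically be wired together, assuming LangChain's classic initialize_agent API; the agent type and the sample question are illustrative assumptions, not part of the deleted file:

from langchain.agents import initialize_agent, AgentType

# Combine the quantized Mistral wrapper with both tools in a ReAct-style agent.
agent = initialize_agent(
    tools=[wikipedia_tool, calculator_tool],
    llm=llm,
    agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
    verbose=True,
)

agent.run("What is 15% of 250?")  # illustrative query, not from the original file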