Kunal Pai commited on
Commit
ffe6e74
·
1 Parent(s): 2526988

Implement model managers for Ollama, Gemini, and Mistral; update requirements.txt with new dependencies

Browse files
Files changed (2) hide show
  1. models/llm_models.py +137 -0
  2. requirements.txt +22 -1
models/llm_models.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+ import ollama
3
+ from pydantic import BaseModel
4
+ from pathlib import Path
5
+ from google import genai
6
+ from google.genai import types
7
+ from mistralai import Mistral
8
+
9
+
10
+ class AbstractModelManager(ABC):
11
+ def __init__(self, model_name, system_prompt_file="system.prompt"):
12
+ self.model_name = model_name
13
+ script_dir = Path(__file__).parent
14
+ self.system_prompt_file = script_dir / system_prompt_file
15
+
16
+ @abstractmethod
17
+ def is_model_loaded(self, model):
18
+ pass
19
+
20
+ @abstractmethod
21
+ def create_model(self, base_model, context_window=4096, temperature=0):
22
+ pass
23
+
24
+ @abstractmethod
25
+ def request(self, prompt):
26
+ pass
27
+
28
+ @abstractmethod
29
+ def delete(self):
30
+ pass
31
+
32
+ class OllamaModelManager(AbstractModelManager):
33
+ def is_model_loaded(self, model):
34
+ loaded_models = [m.model for m in ollama.list().models]
35
+ return model in loaded_models or f'{model}:latest' in loaded_models
36
+
37
+ def create_model(self, base_model, context_window=4096, temperature=0):
38
+ with open(self.system_prompt_file, 'r') as f:
39
+ system = f.read()
40
+
41
+ if not self.is_model_loaded(self.model_name):
42
+ print(f"Creating model {self.model_name}")
43
+ ollama.create(
44
+ model=self.model_name,
45
+ from_=base_model,
46
+ system=system,
47
+ parameters={
48
+ "num_ctx": context_window,
49
+ "temperature": temperature
50
+ }
51
+ )
52
+
53
+ def request(self, prompt):
54
+ response = ollama.chat(
55
+ model=self.model_name,
56
+ messages=[{"role": "user", "content": prompt}],
57
+ )
58
+ response = response['message']['content']
59
+ return response
60
+
61
+ def delete(self):
62
+ if self.is_model_loaded("C2Rust:latest"):
63
+ print(f"Deleting model {self.model_name}")
64
+ ollama.delete("C2Rust:latest")
65
+ else:
66
+ print(f"Model {self.model_name} not found, skipping deletion.")
67
+
68
+ class GeminiModelManager(AbstractModelManager):
69
+ def __init__(self, api_key):
70
+ super().__init__()
71
+ self.client = genai.Client(api_key=api_key)
72
+ self.model = "gemini-2.0-flash"
73
+ # read system prompt from file
74
+ with open(self.system_prompt_file, 'r') as f:
75
+ self.system_instruction = f.read()
76
+
77
+
78
+ def is_model_loaded(self, model):
79
+ # Check if the specified model is the one set in the manager
80
+ return model == self.model
81
+
82
+ def create_model(self, base_model=None, context_window=4096, temperature=0):
83
+ # Initialize the Gemini model settings (if applicable)
84
+ self.model = base_model if base_model else "gemini-2.0-flash"
85
+
86
+ def request(self, prompt, temperature=0, context_window=4096):
87
+ # Request response from the Gemini model
88
+ response = self.client.models.generate_content(
89
+ model=self.model,
90
+ contents=prompt,
91
+ config=types.GenerateContentConfig(
92
+ temperature=temperature,
93
+ max_output_tokens=context_window,
94
+ system_instruction=self.system_instruction,
95
+ )
96
+ )
97
+ return response.text
98
+
99
+ def delete(self):
100
+ # Implement model deletion logic (if applicable)
101
+ self.model = None
102
+
103
+ class MistralModelManager(AbstractModelManager):
104
+ def __init__(self, api_key, model_name="mistral-small-latest", system_prompt_file="system.prompt"):
105
+ super().__init__()
106
+ self.client = Mistral(api_key=api_key)
107
+ self.model = model_name
108
+ # read system prompt from file
109
+ with open(self.system_prompt_file, 'r') as f:
110
+ self.system_instruction = f.read()
111
+
112
+ def is_model_loaded(self, model):
113
+ # Check if the specified model is the one set in the manager
114
+ return model == self.model
115
+
116
+ def create_model(self, base_model=None, context_window=4096, temperature=0):
117
+ # Initialize the Mistral model settings (if applicable)
118
+ self.model = base_model if base_model else "mistral-small-latest"
119
+
120
+ def request(self, prompt, temperature=0, context_window=4096):
121
+ # Request response from the Mistral model
122
+ response = self.client.chat.complete(
123
+ messages=[
124
+ {
125
+ "role":"user",
126
+ "content": self.system_instruction + "\n" + prompt,
127
+ }
128
+ ],
129
+ model=self.model,
130
+ temperature=temperature,
131
+ max_tokens=context_window,
132
+ )
133
+ return response.text
134
+
135
+ def delete(self):
136
+ # Implement model deletion logic (if applicable)
137
+ self.model = None
requirements.txt CHANGED
@@ -1,19 +1,40 @@
1
  annotated-types==0.7.0
2
  anyio==4.9.0
3
  beautifulsoup4==4.13.3
 
4
  certifi==2025.1.31
5
  charset-normalizer==3.4.1
6
- googlesearch-python==1.3.0
 
 
 
 
 
 
 
 
7
  h11==0.14.0
8
  httpcore==1.0.7
 
9
  httpx==0.28.1
10
  idna==3.10
11
  ollama==0.4.7
 
 
 
 
 
12
  pydantic==2.11.1
13
  pydantic_core==2.33.0
 
 
14
  requests==2.32.3
 
15
  sniffio==1.3.1
16
  soupsieve==2.6
 
17
  typing-inspection==0.4.0
18
  typing_extensions==4.13.0
 
19
  urllib3==2.3.0
 
 
1
  annotated-types==0.7.0
2
  anyio==4.9.0
3
  beautifulsoup4==4.13.3
4
+ cachetools==5.5.2
5
  certifi==2025.1.31
6
  charset-normalizer==3.4.1
7
+ google-ai-generativelanguage==0.6.15
8
+ google-api-core==2.24.2
9
+ google-api-python-client==2.166.0
10
+ google-auth==2.38.0
11
+ google-auth-httplib2==0.2.0
12
+ google-genai==1.9.0
13
+ googleapis-common-protos==1.69.2
14
+ grpcio==1.71.0
15
+ grpcio-status==1.71.0
16
  h11==0.14.0
17
  httpcore==1.0.7
18
+ httplib2==0.22.0
19
  httpx==0.28.1
20
  idna==3.10
21
  ollama==0.4.7
22
+ pathlib==1.0.1
23
+ proto-plus==1.26.1
24
+ protobuf==5.29.4
25
+ pyasn1==0.6.1
26
+ pyasn1_modules==0.4.2
27
  pydantic==2.11.1
28
  pydantic_core==2.33.0
29
+ pyparsing==3.2.3
30
+ python-dotenv==1.1.0
31
  requests==2.32.3
32
+ rsa==4.9
33
  sniffio==1.3.1
34
  soupsieve==2.6
35
+ tqdm==4.67.1
36
  typing-inspection==0.4.0
37
  typing_extensions==4.13.0
38
+ uritemplate==4.1.1
39
  urllib3==2.3.0
40
+ websockets==15.0.1