Kunal Pai committed
Commit · ffe6e74
1 Parent(s): 2526988

Implement model managers for Ollama, Gemini, and Mistral; update requirements.txt with new dependencies

- models/llm_models.py +137 -0
- requirements.txt +22 -1
models/llm_models.py
ADDED
@@ -0,0 +1,137 @@
+from abc import ABC, abstractmethod
+import ollama
+from pydantic import BaseModel
+from pathlib import Path
+from google import genai
+from google.genai import types
+from mistralai import Mistral
+
+
+class AbstractModelManager(ABC):
+    def __init__(self, model_name, system_prompt_file="system.prompt"):
+        self.model_name = model_name
+        # Resolve the prompt file relative to this script, not the CWD.
+        script_dir = Path(__file__).parent
+        self.system_prompt_file = script_dir / system_prompt_file
+
+    @abstractmethod
+    def is_model_loaded(self, model):
+        pass
+
+    @abstractmethod
+    def create_model(self, base_model, context_window=4096, temperature=0):
+        pass
+
+    @abstractmethod
+    def request(self, prompt):
+        pass
+
+    @abstractmethod
+    def delete(self):
+        pass
+
+
+class OllamaModelManager(AbstractModelManager):
+    def is_model_loaded(self, model):
+        loaded_models = [m.model for m in ollama.list().models]
+        return model in loaded_models or f'{model}:latest' in loaded_models
+
+    def create_model(self, base_model, context_window=4096, temperature=0):
+        with open(self.system_prompt_file, 'r') as f:
+            system = f.read()
+
+        if not self.is_model_loaded(self.model_name):
+            print(f"Creating model {self.model_name}")
+            ollama.create(
+                model=self.model_name,
+                from_=base_model,
+                system=system,
+                parameters={
+                    "num_ctx": context_window,
+                    "temperature": temperature
+                }
+            )
+
+    def request(self, prompt):
+        response = ollama.chat(
+            model=self.model_name,
+            messages=[{"role": "user", "content": prompt}],
+        )
+        return response['message']['content']
+
+    def delete(self):
+        # Delete the managed model (was hard-coded to "C2Rust:latest").
+        if self.is_model_loaded(self.model_name):
+            print(f"Deleting model {self.model_name}")
+            ollama.delete(self.model_name)
+        else:
+            print(f"Model {self.model_name} not found, skipping deletion.")
+
+
+class GeminiModelManager(AbstractModelManager):
+    def __init__(self, api_key, model_name="gemini-2.0-flash"):
+        # Pass the model name up so the base class can resolve the system
+        # prompt path (a bare super().__init__() would raise a TypeError).
+        super().__init__(model_name)
+        self.client = genai.Client(api_key=api_key)
+        self.model = model_name
+        # Read the system prompt from file.
+        with open(self.system_prompt_file, 'r') as f:
+            self.system_instruction = f.read()
+
+    def is_model_loaded(self, model):
+        # Check if the specified model is the one set in the manager.
+        return model == self.model
+
+    def create_model(self, base_model=None, context_window=4096, temperature=0):
+        # Initialize the Gemini model settings (if applicable).
+        self.model = base_model if base_model else "gemini-2.0-flash"
+
+    def request(self, prompt, temperature=0, context_window=4096):
+        # Request a response from the Gemini model.
+        response = self.client.models.generate_content(
+            model=self.model,
+            contents=prompt,
+            config=types.GenerateContentConfig(
+                temperature=temperature,
+                max_output_tokens=context_window,
+                system_instruction=self.system_instruction,
+            )
+        )
+        return response.text
+
+    def delete(self):
+        # Hosted Gemini models cannot be deleted; just drop the reference.
+        self.model = None
+
+
+class MistralModelManager(AbstractModelManager):
+    def __init__(self, api_key, model_name="mistral-small-latest", system_prompt_file="system.prompt"):
+        # Forward both arguments to the base class (the original call
+        # passed neither, so model_name and the prompt path were never set).
+        super().__init__(model_name, system_prompt_file)
+        self.client = Mistral(api_key=api_key)
+        self.model = model_name
+        # Read the system prompt from file.
+        with open(self.system_prompt_file, 'r') as f:
+            self.system_instruction = f.read()
+
+    def is_model_loaded(self, model):
+        # Check if the specified model is the one set in the manager.
+        return model == self.model
+
+    def create_model(self, base_model=None, context_window=4096, temperature=0):
+        # Initialize the Mistral model settings (if applicable).
+        self.model = base_model if base_model else "mistral-small-latest"
+
+    def request(self, prompt, temperature=0, context_window=4096):
+        # The system prompt is prepended to the user message because it is
+        # sent inline here rather than as a separate system message.
+        response = self.client.chat.complete(
+            messages=[
+                {
+                    "role": "user",
+                    "content": self.system_instruction + "\n" + prompt,
+                }
+            ],
+            model=self.model,
+            temperature=temperature,
+            max_tokens=context_window,
+        )
+        # The Mistral SDK returns a chat completion object, not .text.
+        return response.choices[0].message.content
+
+    def delete(self):
+        # Hosted Mistral models cannot be deleted; just drop the reference.
+        self.model = None
requirements.txt
CHANGED
@@ -1,19 +1,40 @@
 annotated-types==0.7.0
 anyio==4.9.0
 beautifulsoup4==4.13.3
+cachetools==5.5.2
 certifi==2025.1.31
 charset-normalizer==3.4.1
-
+google-ai-generativelanguage==0.6.15
+google-api-core==2.24.2
+google-api-python-client==2.166.0
+google-auth==2.38.0
+google-auth-httplib2==0.2.0
+google-genai==1.9.0
+googleapis-common-protos==1.69.2
+grpcio==1.71.0
+grpcio-status==1.71.0
 h11==0.14.0
 httpcore==1.0.7
+httplib2==0.22.0
 httpx==0.28.1
 idna==3.10
 ollama==0.4.7
+pathlib==1.0.1
+proto-plus==1.26.1
+protobuf==5.29.4
+pyasn1==0.6.1
+pyasn1_modules==0.4.2
 pydantic==2.11.1
 pydantic_core==2.33.0
+pyparsing==3.2.3
+python-dotenv==1.1.0
 requests==2.32.3
+rsa==4.9
 sniffio==1.3.1
 soupsieve==2.6
+tqdm==4.67.1
 typing-inspection==0.4.0
 typing_extensions==4.13.0
+uritemplate==4.1.1
 urllib3==2.3.0
+websockets==15.0.1
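
After installing from the updated pins (e.g. pip install -r requirements.txt), a hypothetical smoke test like the one below can confirm the new SDKs import cleanly. Note that mistralai itself is not pinned here even though models/llm_models.py imports it, so it presumably needs to be installed separately:

import importlib

# Modules used by models/llm_models.py; mistralai is expected to be
# missing unless installed outside requirements.txt.
for module in ("google.genai", "ollama", "mistralai", "pydantic"):
    try:
        importlib.import_module(module)
        print(f"{module}: ok")
    except ImportError as exc:
        print(f"{module}: missing ({exc})")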