from abc import ABC, abstractmethod
from pathlib import Path

import ollama
from google import genai
from google.genai import types
from groq import Groq
from mistralai import Mistral
from pydantic import BaseModel

from src.manager.utils.streamlit_interface import output_assistant_response


class AbstractModelManager(ABC):
    """Common interface for the provider-specific model managers below."""

    def __init__(self, model_name, system_prompt_file="system.prompt"):
        self.model_name = model_name
        # Resolve the system prompt relative to this file so it is found
        # regardless of the current working directory.
        script_dir = Path(__file__).parent
        self.system_prompt_file = script_dir / system_prompt_file

    @abstractmethod
    def is_model_loaded(self, model):
        pass

    @abstractmethod
    def create_model(self, base_model, context_window=4096, temperature=0):
        pass

    @abstractmethod
    def request(self, prompt):
        pass

    @abstractmethod
    def delete(self):
        pass


class OllamaModelManager(AbstractModelManager):
    def is_model_loaded(self, model):
        loaded_models = [m.model for m in ollama.list().models]
        return model in loaded_models or f'{model}:latest' in loaded_models

    def create_model(self, base_model, context_window=4096, temperature=0):
        with open(self.system_prompt_file, 'r') as f:
            system = f.read()

        if not self.is_model_loaded(self.model_name):
            output_assistant_response(f"Creating model {self.model_name}")
            ollama.create(
                model=self.model_name,
                from_=base_model,
                system=system,
                parameters={
                    "num_ctx": context_window,
                    "temperature": temperature
                }
            )

    def request(self, prompt):
        response = ollama.chat(
            model=self.model_name,
            messages=[{"role": "user", "content": prompt}],
        )
        return response['message']['content']

    def delete(self):
        if self.is_model_loaded(self.model_name):
            output_assistant_response(f"Deleting model {self.model_name}")
            ollama.delete(self.model_name)
        else:
            output_assistant_response(f"Model {self.model_name} not found, skipping deletion.")


class GeminiModelManager(AbstractModelManager):
    def __init__(self, api_key, model_name="gemini-2.0-flash", system_prompt_file="system.prompt"):
        super().__init__(model_name, system_prompt_file)
        self.client = genai.Client(api_key=api_key)
        self.model = model_name

        with open(self.system_prompt_file, 'r') as f:
            self.system_instruction = f.read()

    def is_model_loaded(self, model):
        return model == self.model

    def create_model(self, base_model=None, context_window=4096, temperature=0):
        # Gemini models are hosted, so "creating" one just selects which model to call.
        self.model = base_model if base_model else "gemini-2.0-flash"

    def request(self, prompt, temperature=0, context_window=4096):
        response = self.client.models.generate_content(
            model=self.model,
            contents=prompt,
            config=types.GenerateContentConfig(
                temperature=temperature,
                max_output_tokens=context_window,
                system_instruction=self.system_instruction,
            )
        )
        return response.text

    def delete(self):
        self.model = None


class MistralModelManager(AbstractModelManager):
    def __init__(self, api_key, model_name="mistral-small-latest", system_prompt_file="system.prompt"):
        super().__init__(model_name, system_prompt_file)
        self.client = Mistral(api_key=api_key)
        self.model = model_name

        with open(self.system_prompt_file, 'r') as f:
            self.system_instruction = f.read()

    def is_model_loaded(self, model):
        return model == self.model

    def create_model(self, base_model=None, context_window=4096, temperature=0):
        # Mistral models are hosted, so "creating" one just selects which model to call.
        self.model = base_model if base_model else "mistral-small-latest"

    def request(self, prompt, temperature=0, context_window=4096):
        response = self.client.chat.complete(
            messages=[
                {
                    "role": "user",
                    "content": self.system_instruction + "\n" + prompt,
                }
            ],
            model=self.model,
            temperature=temperature,
            max_tokens=context_window,
        )
        return response.choices[0].message.content

    def delete(self):
        self.model = None


class GroqModelManager(AbstractModelManager):
    def __init__(self, api_key, model_name="llama-3.3-70b-versatile", system_prompt_file="system.prompt"):
        super().__init__(model_name, system_prompt_file)
        self.client = Groq(api_key=api_key)

    def is_model_loaded(self, model):
        return model == self.model_name

    def create_model(self, base_model=None, context_window=4096, temperature=0):
        # Groq serves hosted models, so there is nothing to create; warn if the
        # requested base model does not match the one this manager is configured for.
        if base_model and not self.is_model_loaded(base_model):
            output_assistant_response(f"Model {base_model} is not available on Groq.")

    def request(self, prompt, temperature=0, context_window=4096):
        with open(self.system_prompt_file, 'r') as f:
            system_instruction = f.read()

        messages = [
            {"role": "system", "content": system_instruction},
            {"role": "user", "content": prompt}
        ]

        response = self.client.chat.completions.create(
            messages=messages,
            model=self.model_name,
            temperature=temperature
        )

        return response.choices[0].message.content

    def delete(self):
        output_assistant_response(f"Deletion not supported for Groq model {self.model_name}.")
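

# A minimal usage sketch, assuming a GROQ_API_KEY environment variable is set
# and a "system.prompt" file sits next to this module; the other managers are
# used the same way with their own API keys and model names.
if __name__ == "__main__":
    import os

    manager = GroqModelManager(api_key=os.environ["GROQ_API_KEY"])
    manager.create_model()  # no-op for hosted Groq models
    print(manager.request("Reply with a one-sentence greeting."))
    manager.delete()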