from abc import ABC, abstractmethod
from pathlib import Path

import ollama
from google import genai
from google.genai import types
from groq import Groq
from mistralai import Mistral
from pydantic import BaseModel

from src.manager.utils.streamlit_interface import output_assistant_response
class AbstractModelManager(ABC):
    """Common interface implemented by every model-manager backend."""

    def __init__(self, model_name, system_prompt_file="system.prompt"):
        self.model_name = model_name
        # Resolve the system prompt path relative to this file's directory
        script_dir = Path(__file__).parent
        self.system_prompt_file = script_dir / system_prompt_file

    @abstractmethod
    def is_model_loaded(self, model):
        pass

    @abstractmethod
    def create_model(self, base_model, context_window=4096, temperature=0):
        pass

    @abstractmethod
    def request(self, prompt):
        pass

    @abstractmethod
    def delete(self):
        pass


class OllamaModelManager(AbstractModelManager):
    def is_model_loaded(self, model):
        # Models may be listed with or without the ":latest" tag suffix
        loaded_models = [m.model for m in ollama.list().models]
        return model in loaded_models or f'{model}:latest' in loaded_models

    def create_model(self, base_model, context_window=4096, temperature=0):
        with open(self.system_prompt_file, 'r') as f:
            system = f.read()
        if not self.is_model_loaded(self.model_name):
            output_assistant_response(f"Creating model {self.model_name}")
            ollama.create(
                model=self.model_name,
                from_=base_model,
                system=system,
                parameters={
                    "num_ctx": context_window,
                    "temperature": temperature
                }
            )

    def request(self, prompt):
        response = ollama.chat(
            model=self.model_name,
            messages=[{"role": "user", "content": prompt}],
        )
        return response['message']['content']

    def delete(self):
        if self.is_model_loaded(self.model_name):
            output_assistant_response(f"Deleting model {self.model_name}")
            ollama.delete(self.model_name)
        else:
            output_assistant_response(f"Model {self.model_name} not found, skipping deletion.")


class GeminiModelManager(AbstractModelManager):
    def __init__(self, api_key, model_name="gemini-2.0-flash", system_prompt_file="system.prompt"):
        super().__init__(model_name, system_prompt_file)
        self.client = genai.Client(api_key=api_key)
        self.model = model_name
        # Read the system prompt from file
        with open(self.system_prompt_file, 'r') as f:
            self.system_instruction = f.read()
    def is_model_loaded(self, model):
        # Check if the specified model is the one set in the manager
        return model == self.model

    def create_model(self, base_model=None, context_window=4096, temperature=0):
        # Initialize the Gemini model settings (if applicable)
        self.model = base_model if base_model else "gemini-2.0-flash"

    def request(self, prompt, temperature=0, context_window=4096):
        # Request a response from the Gemini model
        response = self.client.models.generate_content(
            model=self.model,
            contents=prompt,
            config=types.GenerateContentConfig(
                temperature=temperature,
                max_output_tokens=context_window,
                system_instruction=self.system_instruction,
            )
        )
        return response.text

    def delete(self):
        # No remote deletion is needed; clear the local model reference
        self.model = None


class MistralModelManager(AbstractModelManager):
    def __init__(self, api_key, model_name="mistral-small-latest", system_prompt_file="system.prompt"):
        super().__init__(model_name, system_prompt_file)
        self.client = Mistral(api_key=api_key)
        self.model = model_name
        # Read the system prompt from file
        with open(self.system_prompt_file, 'r') as f:
            self.system_instruction = f.read()
    def is_model_loaded(self, model):
        # Check if the specified model is the one set in the manager
        return model == self.model

    def create_model(self, base_model=None, context_window=4096, temperature=0):
        # Initialize the Mistral model settings (if applicable)
        self.model = base_model if base_model else "mistral-small-latest"

    def request(self, prompt, temperature=0, context_window=4096):
        # Request a response from the Mistral model; the system prompt is
        # prepended to the user message because it is not passed separately here
        response = self.client.chat.complete(
            messages=[
                {
                    "role": "user",
                    "content": self.system_instruction + "\n" + prompt,
                }
            ],
            model=self.model,
            temperature=temperature,
            max_tokens=context_window,
        )
        return response.choices[0].message.content

    def delete(self):
        # No remote deletion is needed; clear the local model reference
        self.model = None


class GroqModelManager(AbstractModelManager):
    def __init__(self, api_key, model_name="llama-3.3-70b-versatile", system_prompt_file="system.prompt"):
        super().__init__(model_name, system_prompt_file)
        self.client = Groq(api_key=api_key)

    def is_model_loaded(self, model):
        # Groq models are referenced by name; assume available if the name matches
        return model == self.model_name

    def create_model(self, base_model=None, context_window=4096, temperature=0):
        # Groq does not require explicit model creation; no-op
        if not self.is_model_loaded(self.model_name):
            output_assistant_response(f"Model {self.model_name} is not available on Groq.")

    def request(self, prompt, temperature=0, context_window=4096):
        # Read the system instruction
        with open(self.system_prompt_file, 'r') as f:
            system_instruction = f.read()
        # Build the messages
        messages = [
            {"role": "system", "content": system_instruction},
            {"role": "user", "content": prompt}
        ]
        # Send the request
        response = self.client.chat.completions.create(
            messages=messages,
            model=self.model_name,
            temperature=temperature
        )
        # Extract and return the content
        return response.choices[0].message.content

    def delete(self):
        # No deletion support for Groq-managed models
        output_assistant_response(f"Deletion not supported for Groq model {self.model_name}.")