Source code for perseus.loc

"""Local Ollama client for ESML.

Pure ``httpx``-based client for a local Ollama instance running at
``localhost:11434``.  Provides model management (pull, list, remove),
chat, and streaming — no external deps beyond ``httpx``.

This module backs the ``ollama`` provider slot in :mod:`esml.llm`.

Environment Variables
---------------------
OLLAMA_BASE_URL : str
    Override the Ollama endpoint.  Default: ``http://localhost:11434``.
ESML_OLLAMA_MODEL : str
    Override the default local model.  Default: auto-detected (largest
    ``perseus:*`` model, else the first available).
"""

from __future__ import annotations

import json
import logging
import os
from collections.abc import Iterator
from dataclasses import dataclass
from typing import Any

import httpx

logger = logging.getLogger(__name__)

_DEFAULT_BASE_URL = "http://localhost:11434"
_DEFAULT_MODEL = ""  # Auto-detected: prefers perseus:*, then first available
_REQUEST_TIMEOUT = 300.0
_PULL_TIMEOUT = 600.0  # model downloads can be large


@dataclass
class ModelInfo:
    """Metadata for a locally available Ollama model."""

    name: str
    size: int = 0  # bytes
    parameter_size: str = ""
    family: str = ""
    quantization: str = ""
    modified_at: str = ""

    @property
    def size_gb(self) -> float:
        return self.size / (1024**3) if self.size else 0.0

    @property
    def label(self) -> str:
        fam = self.family.capitalize() if self.family else self.name.split(":")[0]
        sz = self.parameter_size or f"{self.size_gb:.1f}GB"
        return f"{fam}:{sz}"

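# Usage sketch (hypothetical values): ``label`` folds the model family and
# parameter size into a short display string, falling back to the on-disk
# size when the registry omits ``parameter_size``:
#
#     info = ModelInfo(name="llama3.2:3b", size=2_019_393_189,
#                      family="llama", parameter_size="3.2B")
#     info.label              # -> "Llama:3.2B"
#     round(info.size_gb, 2)  # -> 1.88
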
class LocalOllama:
    """Client for a local Ollama instance.

    Parameters
    ----------
    base_url : str, optional
        Override ``OLLAMA_BASE_URL``.
    model : str, optional
        Override ``ESML_OLLAMA_MODEL``.
    timeout : float, optional
        Request timeout in seconds (default 300).

    Examples
    --------
    >>> client = LocalOllama()
    >>> client.is_running()
    True
    >>> models = client.list_models()
    >>> response = client.chat("What is IPW?")
    """

    def __init__(
        self,
        base_url: str | None = None,
        model: str | None = None,
        timeout: float = _REQUEST_TIMEOUT,
    ):
        self.base_url = (
            base_url
            or os.environ.get("OLLAMA_BASE_URL", "").strip()
            or _DEFAULT_BASE_URL
        ).rstrip("/")
        self._model_override = (
            model or os.environ.get("ESML_OLLAMA_MODEL", "").strip() or _DEFAULT_MODEL
        )
        self.timeout = timeout
        self._model_detected: str | None = None

    @property
    def model(self) -> str:
        """Active model — auto-detected from Ollama if not explicitly set."""
        if self._model_override:
            return self._model_override
        if self._model_detected is not None:
            return self._model_detected
        # Auto-detect: prefer largest perseus:*, then first available
        try:
            models = self.list_models()
            perseus_models = [m for m in models if m.name.startswith("perseus")]
            if perseus_models:
                perseus_models.sort(key=lambda m: m.size, reverse=True)
                self._model_detected = perseus_models[0].name
                return perseus_models[0].name
            if models:
                self._model_detected = models[0].name
                return models[0].name
        except Exception:
            pass
        self._model_detected = ""
        return ""

    # -- Health ---------------------------------------------------------------

    def is_running(self, timeout: float = 2.0) -> bool:
        """Check if Ollama is reachable."""
        try:
            resp = httpx.get(f"{self.base_url}/api/tags", timeout=timeout)
            return resp.status_code == 200
        except Exception:
            return False

    # -- Model management -----------------------------------------------------

    def list_models(self) -> list[ModelInfo]:
        """List locally available models."""
        resp = httpx.get(f"{self.base_url}/api/tags", timeout=self.timeout)
        resp.raise_for_status()
        models = []
        for m in resp.json().get("models", []):
            details = m.get("details", {})
            models.append(
                ModelInfo(
                    name=m.get("name", ""),
                    size=m.get("size", 0),
                    parameter_size=details.get("parameter_size", ""),
                    family=details.get("family", ""),
                    quantization=details.get("quantization_level", ""),
                    modified_at=m.get("modified_at", ""),
                )
            )
        return models

    def model_names(self) -> list[str]:
        """Return just the model name strings."""
        return [m.name for m in self.list_models()]

    def has_model(self, name: str) -> bool:
        """Check if a specific model is available locally."""
        return any(
            m.name == name or m.name.startswith(name + ":")
            for m in self.list_models()
        )

    def pull(
        self,
        name: str,
        *,
        stream: bool = True,
        timeout: float = _PULL_TIMEOUT,
    ) -> Iterator[dict[str, Any]] | dict[str, Any]:
        """Pull (download) a model.

        Parameters
        ----------
        name : str
            Model name, e.g. ``gemma3:4b`` or ``llama3.2:3b``.
        stream : bool
            If True, yield progress dicts as they arrive.
        timeout : float
            Download timeout (default 600s).

        Yields
        ------
        dict
            Progress updates with ``status``, ``digest``, ``total``, ``completed``.
        """
        if stream:
            return self._pull_stream(name, timeout)
        resp = httpx.post(
            f"{self.base_url}/api/pull",
            json={"name": name, "stream": False},
            timeout=timeout,
        )
        resp.raise_for_status()
        return resp.json()

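    # Usage sketch (hypothetical model tag): report pull progress as a
    # percentage.  ``total``/``completed`` only appear on download steps,
    # so guard the division:
    #
    #     client = LocalOllama()
    #     for update in client.pull("llama3.2:3b"):
    #         total, done = update.get("total", 0), update.get("completed", 0)
    #         if total:
    #             print(f"{update['status']}: {100 * done / total:.0f}%")
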
    def _pull_stream(self, name: str, timeout: float) -> Iterator[dict[str, Any]]:
        """Stream pull progress."""
        with httpx.stream(
            "POST",
            f"{self.base_url}/api/pull",
            json={"name": name, "stream": True},
            timeout=timeout,
        ) as resp:
            resp.raise_for_status()
            for line in resp.iter_lines():
                if line.strip():
                    try:
                        yield json.loads(line)
                    except json.JSONDecodeError:
                        pass

    def remove(self, name: str) -> bool:
        """Delete a local model.  Returns True on success."""
        # ``httpx.delete`` does not accept a request body, so use
        # ``httpx.request`` to send the JSON payload with DELETE.
        resp = httpx.request(
            "DELETE",
            f"{self.base_url}/api/delete",
            json={"name": name},
            timeout=self.timeout,
        )
        return resp.status_code == 200

    def show(self, name: str | None = None) -> dict[str, Any]:
        """Get model details (parameters, template, license)."""
        resp = httpx.post(
            f"{self.base_url}/api/show",
            json={"name": name or self.model},
            timeout=self.timeout,
        )
        resp.raise_for_status()
        return resp.json()

    # -- Chat -----------------------------------------------------------------

    def chat(
        self,
        prompt: str,
        *,
        model: str | None = None,
        system: str | None = None,
        context: list[dict[str, str]] | None = None,
        temperature: float = 0.1,
        num_predict: int = 4096,
    ) -> str:
        """Send a chat message and return the full response.

        Parameters
        ----------
        prompt : str
            User message.
        model : str, optional
            Override the default model.
        system : str, optional
            System prompt.
        context : list, optional
            Prior messages as ``[{"role": "user", "content": "..."}, ...]``.
        temperature : float
            Sampling temperature.
        num_predict : int
            Max tokens to generate.

        Returns
        -------
        str
            The assistant's response text.
        """
        messages = self._build_messages(prompt, system, context)
        resp = httpx.post(
            f"{self.base_url}/api/chat",
            json={
                "model": model or self.model,
                "messages": messages,
                "stream": False,
                "options": {
                    "temperature": temperature,
                    "num_predict": num_predict,
                },
            },
            timeout=self.timeout,
        )
        resp.raise_for_status()
        msg = resp.json().get("message", {})
        content = msg.get("content", "")
        if not content:
            # Reasoning models may put their output under "thinking" instead.
            content = msg.get("thinking", "")
        return content

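    # Usage sketch: multi-turn chat with a system prompt.  The ``context``
    # list uses the same role/content dicts the Ollama chat API expects:
    #
    #     client = LocalOllama()
    #     history = [
    #         {"role": "user", "content": "Define IPW."},
    #         {"role": "assistant", "content": "Inverse probability weighting..."},
    #     ]
    #     answer = client.chat(
    #         "Give a one-line example.",
    #         system="You are a terse causal-inference tutor.",
    #         context=history,
    #     )
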
    def stream_chat(
        self,
        prompt: str,
        *,
        model: str | None = None,
        system: str | None = None,
        context: list[dict[str, str]] | None = None,
        temperature: float = 0.1,
        num_predict: int = 4096,
    ) -> Iterator[str]:
        """Stream chat response chunks.

        Yields
        ------
        str
            Content chunks as they arrive from the model.
        """
        messages = self._build_messages(prompt, system, context)
        with httpx.stream(
            "POST",
            f"{self.base_url}/api/chat",
            json={
                "model": model or self.model,
                "messages": messages,
                "stream": True,
                "options": {
                    "temperature": temperature,
                    "num_predict": num_predict,
                },
            },
            timeout=self.timeout,
        ) as resp:
            resp.raise_for_status()
            for line in resp.iter_lines():
                if not line.strip():
                    continue
                try:
                    data = json.loads(line)
                    msg = data.get("message", {})
                    chunk = msg.get("content", "") or msg.get("thinking", "")
                    if chunk:
                        yield chunk
                    if data.get("done", False):
                        return
                except json.JSONDecodeError:
                    pass

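    # Usage sketch: print chunks as they arrive instead of waiting for the
    # full response:
    #
    #     for chunk in LocalOllama().stream_chat("Explain IPW briefly."):
    #         print(chunk, end="", flush=True)
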
    # -- Generate (raw completion) --------------------------------------------

    def generate(
        self,
        prompt: str,
        *,
        model: str | None = None,
        system: str | None = None,
        stream: bool = False,
        temperature: float = 0.1,
        num_predict: int = 4096,
    ) -> str | Iterator[str]:
        """Raw generation endpoint (non-chat).  Returns full text or stream."""
        payload: dict[str, Any] = {
            "model": model or self.model,
            "prompt": prompt,
            "stream": stream,
            "options": {
                "temperature": temperature,
                "num_predict": num_predict,
            },
        }
        if system:
            payload["system"] = system
        if stream:
            return self._generate_stream(payload)
        resp = httpx.post(
            f"{self.base_url}/api/generate",
            json=payload,
            timeout=self.timeout,
        )
        resp.raise_for_status()
        return resp.json().get("response", "")

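    # Usage sketch: ``generate`` targets /api/generate (single-prompt
    # completion, no message history), whereas ``chat`` targets /api/chat.
    # Hypothetical call:
    #
    #     text = LocalOllama().generate(
    #         "The capital of France is",
    #         temperature=0.0,
    #     )
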
    def _generate_stream(self, payload: dict[str, Any]) -> Iterator[str]:
        with httpx.stream(
            "POST",
            f"{self.base_url}/api/generate",
            json=payload,
            timeout=self.timeout,
        ) as resp:
            resp.raise_for_status()
            for line in resp.iter_lines():
                if not line.strip():
                    continue
                try:
                    data = json.loads(line)
                    chunk = data.get("response", "")
                    if chunk:
                        yield chunk
                    if data.get("done", False):
                        return
                except json.JSONDecodeError:
                    pass

    # -- Helpers --------------------------------------------------------------

    @staticmethod
    def _build_messages(
        prompt: str,
        system: str | None,
        context: list[dict[str, str]] | None,
    ) -> list[dict[str, str]]:
        messages: list[dict[str, str]] = []
        if system:
            messages.append({"role": "system", "content": system})
        if context:
            messages.extend(context)
        messages.append({"role": "user", "content": prompt})
        return messages
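

# Minimal end-to-end smoke test, not part of the public API.  A sketch that
# assumes a local Ollama daemon with at least one model pulled; it prints a
# diagnostic otherwise.
if __name__ == "__main__":
    _client = LocalOllama()
    if _client.is_running():
        print("models:", _client.model_names())
        print(_client.generate("Say hello in five words."))
    else:
        print(f"Ollama is not reachable at {_client.base_url}")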