# Source code for perseus.llm

"""Ollama-first LLM integration layer for the ESML package.

Provides a provider chain that attempts local Ollama inference first, then
OllamaFreeAPI (free remote models, no API key), then Gemini (Google), then
a generic OpenAI-compatible endpoint (e.g. Qwen via OpenRouter, GPT-OSS
models via Together/Groq), then the official OpenAI API, and finally a
local help-text fallback that requires no network access.

HTTP-based providers use ``httpx`` against OpenAI-compatible endpoints.
OllamaFreeAPI uses its own Python SDK (``ollamafreeapi``) for free remote
model access without any API key.

Environment Variables
---------------------
OLLAMA_BASE_URL : str
    Base URL for a running Ollama instance.  Default: ``http://localhost:11434``
ESML_OLLAMA_MODEL : str
    Explicit Ollama model override; when unset the model is auto-detected
    from the running instance.  Optional.
esmlfam : str
    Override the OllamaFreeAPI model (esml free api model).  Default: ``mistral-nemo:custom``.
GEMINI_API_KEY : str
    Google AI Studio API key.  Free-tier keys work for development.
    Model defaults to ``gemini-2.0-flash``.
GEMINI_MODEL : str
    Override the Gemini model (e.g. ``gemini-1.5-pro``).  Optional.
LLM_API_BASE_URL : str
    Base URL for any OpenAI-compatible API (e.g., OpenRouter, Together, Groq).
    Use this to point at Qwen, Mistral, GPT-OSS, or any hosted model.
LLM_API_KEY : str
    API key for the endpoint at ``LLM_API_BASE_URL``.
OPENAI_API_KEY : str
    API key for the official OpenAI API at ``https://api.openai.com``.

Provider priority (auto-detected at runtime):
    1. Ollama    — local, private, no API key needed
    2. FreeAPI   — OllamaFreeAPI, free remote models, no API key
    3. Gemini    — Google AI, generous free tier
    4. API       — generic OpenAI-compatible (Qwen, GPT-OSS, Groq, etc.)
    5. OpenAI    — official OpenAI API
    6. local     — static help text, no network required

References
----------
* Ollama API docs: https://github.com/ollama/ollama/blob/main/docs/api.md
* OllamaFreeAPI: https://pypi.org/project/ollamafreeapi/
* Gemini OpenAI-compatible API: https://ai.google.dev/gemini-api/docs/openai
* OpenAI Chat Completions API: https://platform.openai.com/docs/api-reference/chat
"""

from __future__ import annotations

import json
import logging
import os
from collections.abc import Iterator
from pathlib import Path
from typing import Any

import httpx

# cpads + modules are esml-specific; perseus is standalone and falls back
# to empty context if the consuming package doesn't provide them.
# Optional esml integration: when the host package does not provide the
# CPADS contract helper, fall back to an inert stand-in returning {}.
try:
    from esml.cpads import cpads_contract  # type: ignore[import-not-found]
except ImportError:
    def cpads_contract() -> dict:  # type: ignore[misc]
        return {}

# Same for the module registry: an empty mapping simply disables module
# context in the system prompt.
try:
    from esml.modules import MODULE_SPECS  # type: ignore[import-not-found]
except ImportError:
    MODULE_SPECS: dict = {}  # type: ignore[misc,no-redef]

# Module-level logger, per stdlib logging convention.
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Default configuration
# ---------------------------------------------------------------------------

# Default endpoints / model names per provider; env vars (see module
# docstring) override each of these at runtime.
DEFAULT_OLLAMA_BASE_URL = "http://localhost:11434"
DEFAULT_OLLAMA_MODEL = ""  # Auto-detected from running Ollama instance
DEFAULT_FREEAPI_MODEL = "mistral-nemo:custom"
DEFAULT_GEMINI_MODEL = "gemini-2.0-flash"
DEFAULT_API_MODEL = "google/gemma-3-27b-it"  # generic OpenAI-compatible endpoint default
DEFAULT_OPENAI_MODEL = "gpt-4o-mini"

OPENAI_BASE_URL = "https://api.openai.com"
# Gemini exposes an OpenAI-compatible endpoint; no extra SDK required.
GEMINI_BASE_URL = "https://generativelanguage.googleapis.com/v1beta/openai"

# Canonical provider keys returned by detect_available_provider().
_PROVIDER_OLLAMA = "ollama"
_PROVIDER_FREEAPI = "freeapi"
_PROVIDER_GEMINI = "gemini"
_PROVIDER_API = "api"
_PROVIDER_OPENAI = "openai"
_PROVIDER_LOCAL = "local"

# Timeout for the quick health-check probe (seconds).
_PROBE_TIMEOUT = 2.0

# Timeout for actual generation requests (seconds).
_REQUEST_TIMEOUT = 120.0

# ---------------------------------------------------------------------------
# System prompt template
# ---------------------------------------------------------------------------

# System prompt for every chat request.  ``{context_block}`` is a
# str.format placeholder -- presumably filled with the rendered output of
# build_esml_context(); confirm against the caller.
_ESML_SYSTEM_PROMPT_TEMPLATE = """\
You are the ESML agent for epidemiological semiparametric machine learning.

ESML is a Python+R terminal IDE for Canadian public health data analysis, \
causal inference, and reproducible research. Install: pip install esml

TUI keys: c=Chat p=Pipeline d=Doctor i=Datasets h=Help s=Stats e=REPL q=Quit
Chat commands: /run /list /doctor /profile /inspect /verify /agent /help /clear
REPL: ?question=AI !cmd=shell R>code=R. Helpers: load() head() describe() cols()
CLI: esml list-modules, esml run-module <name>, esml pipeline --all -y
CLI: esml list-datasets, esml doctor, esml selftest, esml ask "question"

32 built-in datasets: CPADS, CCS, CSADS, CSUS, HealthInfobase, CIHI.
Load: load('cpads') in REPL or from esml.data import load_dataset

48 stats commands: ttest anova chi2 corr regression pscore ipw aipw ate \
cohend kaplanmeier coxph did rddesign ivreg vif and more (press s).

21 modules: data-wrangling descriptive-statistics frequentist-inference \
bayesian-inference propensity-scores causal-estimators treatment-effects \
ebac-core figures tables final-report and more.

Debug: press d for Doctor, b for logs, esml selftest for smoke tests.

Give practical answers with specific ESML commands. Be explicit about \
assumptions and limitations.

{context_block}
"""

# ---------------------------------------------------------------------------
# Provider detection
# ---------------------------------------------------------------------------


def _ollama_base_url() -> str:
    """Resolve the Ollama base URL from the environment (trailing ``/`` stripped)."""
    url = os.environ.get("OLLAMA_BASE_URL", DEFAULT_OLLAMA_BASE_URL)
    return url.rstrip("/")


_ollama_model_cached: str | None = None


def _ollama_model() -> str:
    """Return the Ollama model to use — auto-detected from the running instance.

    Priority:
    1. ESML_OLLAMA_MODEL env var (explicit override)
    2. First model from ``ollama list`` (auto-detect)
    3. Empty string (no model available)

    Cached for the process lifetime after first detection.
    """
    global _ollama_model_cached
    if _ollama_model_cached is not None:
        return _ollama_model_cached

    # 1. Env var override
    env = os.environ.get("ESML_OLLAMA_MODEL", "").strip()
    if env:
        _ollama_model_cached = env
        return env

    # 2. Auto-detect from running Ollama — prefer largest perseus:* model
    try:
        from .loc import LocalOllama

        client = LocalOllama()
        models = client.list_models()
        if models:
            perseus_models = [m for m in models if m.name.startswith("perseus")]
            if perseus_models:
                perseus_models.sort(key=lambda m: m.size, reverse=True)
                _ollama_model_cached = perseus_models[0].name
                return perseus_models[0].name
            _ollama_model_cached = models[0].name
            return models[0].name
    except Exception:
        pass

    _ollama_model_cached = DEFAULT_OLLAMA_MODEL
    return DEFAULT_OLLAMA_MODEL


def _api_base_url() -> str | None:
    """Return the configured generic OpenAI-compatible base URL, or None."""
    url = os.environ.get("LLM_API_BASE_URL", "").strip()
    return url.rstrip("/") if url else None


def _api_key() -> str | None:
    """Return the LLM_API_KEY for the generic endpoint."""
    return os.environ.get("LLM_API_KEY", "").strip() or None


def _openai_key() -> str | None:
    """Return the OPENAI_API_KEY."""
    return os.environ.get("OPENAI_API_KEY", "").strip() or None


def _gemini_key() -> str | None:
    """Return the GEMINI_API_KEY for Google AI Studio."""
    return os.environ.get("GEMINI_API_KEY", "").strip() or None


def _gemini_model() -> str:
    """Return the Gemini model name (``GEMINI_MODEL`` env override or the default)."""
    configured = os.environ.get("GEMINI_MODEL", DEFAULT_GEMINI_MODEL)
    return configured.strip()


_ollama_cached: bool | None = None


def _probe_ollama(timeout: float = _PROBE_TIMEOUT) -> bool:
    """Check — once per process — whether a local Ollama instance is reachable.

    Uses :class:`esml.loc.LocalOllama` for the probe; the outcome is cached
    for the process lifetime so repeated provider detection never pays the
    probe timeout more than once.

    Parameters
    ----------
    timeout : float
        Maximum seconds to wait for the Ollama ``/api/tags`` endpoint.

    Returns
    -------
    bool
        ``True`` when Ollama is reachable, ``False`` otherwise.
    """
    global _ollama_cached
    if _ollama_cached is None:
        try:
            from .loc import LocalOllama

            probe = LocalOllama(base_url=_ollama_base_url())
            _ollama_cached = probe.is_running(timeout=timeout)
        except Exception:
            # Any failure (import, connection) means "not available".
            _ollama_cached = False
    return _ollama_cached


_freeapi_cached: bool | None = None


def _probe_freeapi() -> bool:
    """Return True if ``ollamafreeapi`` is importable and has available models.

    The result is cached for the process lifetime to avoid repeated network
    probes on every call to :func:`detect_available_provider`.
    """
    global _freeapi_cached
    if _freeapi_cached is not None:
        return _freeapi_cached
    try:
        from .fam import OllamaFreeAPI

        client = OllamaFreeAPI()
        models = client.list_models()
        _freeapi_cached = bool(models)
    except Exception:
        # Retry once — community servers can be slow to respond
        try:
            import time

            time.sleep(1)
            from .fam import OllamaFreeAPI

            client = OllamaFreeAPI()
            models = client.list_models()
            _freeapi_cached = bool(models)
        except Exception:
            _freeapi_cached = False
    return _freeapi_cached


def _freeapi_model() -> str:
    """Return the OllamaFreeAPI model (``esmlfam`` env override or the default)."""
    configured = os.environ.get("esmlfam", DEFAULT_FREEAPI_MODEL)
    return configured.strip()


def detect_available_provider() -> str:
    """Detect which LLM provider is currently available.

    The detection order mirrors the provider chain priority:

    1. **ollama** -- a local Ollama instance is reachable (probed via HTTP).
    2. **freeapi** -- ``ollamafreeapi`` package is installed and servers respond.
    3. **gemini** -- ``GEMINI_API_KEY`` is set.
    4. **api** -- ``LLM_API_BASE_URL`` and ``LLM_API_KEY`` are set.
    5. **openai** -- ``OPENAI_API_KEY`` is set.
    6. **local** -- no live provider; ESML will return static help text.

    Returns
    -------
    str
        One of ``"ollama"``, ``"freeapi"``, ``"gemini"``, ``"api"``,
        ``"openai"``, or ``"local"``.

    Examples
    --------
    >>> provider = detect_available_provider()
    >>> provider in ("ollama", "freeapi", "gemini", "api", "openai", "local")
    True
    """
    # Ordered provider checks: first truthy probe wins.
    checks = [
        (_PROVIDER_OLLAMA, _probe_ollama),
        (_PROVIDER_FREEAPI, _probe_freeapi),
        (_PROVIDER_GEMINI, _gemini_key),
        (_PROVIDER_API, lambda: _api_base_url() and _api_key()),
        (_PROVIDER_OPENAI, _openai_key),
    ]
    for provider, is_available in checks:
        if is_available():
            return provider
    return _PROVIDER_LOCAL
# -- Dynamic model labeling and alias table ---------------------------------


def _normalize_size(size: str) -> str:
    """Normalize a parameter-size string: '4.3B' -> '4.3b', '134.52M' -> '135m'."""
    text = size.strip().lower()
    if text.endswith("m"):
        try:
            millions = round(float(text[:-1]))
        except ValueError:
            return text
        return f"{millions}m"
    return text


def _model_display_label(model_name: str, family: str = "", size: str = "") -> str:
    """Build a display label such as 'Gemma3:4.3b' from model metadata.

    Falls back to the raw model name when family or size is missing.
    """
    if family and size:
        return f"{family.capitalize()}:{_normalize_size(size)}"
    return model_name
def list_freeapi_models() -> list[dict[str, str]]:
    """List all available OllamaFreeAPI models from vendored JSONs.

    Returns
    -------
    list[dict[str, str]]
        Each dict has keys: ``model``, ``family``, ``size``, ``label``,
        ``alias``.  Aliases are short unique handles derived from the
        model/family names.
    """
    # ``json`` and ``Path`` come from the module-level imports; the original
    # re-imported json locally, which was redundant.
    json_dir = Path(__file__).parent / "ollama_json"
    seen: set[str] = set()
    models: list[dict[str, str]] = []
    for jf in sorted(json_dir.glob("*.json")):
        try:
            data = json.loads(jf.read_text())
            for m in data.get("props", {}).get("pageProps", {}).get("models", []):
                name = m.get("model_name") or m.get("model", "")
                if not name or name in seen:
                    continue
                seen.add(name)
                family = m.get("family", "")
                size = m.get("parameter_size", "")
                models.append(
                    {
                        "model": name,
                        "family": family,
                        "size": size,
                        "label": _model_display_label(name, family, size),
                    }
                )
        except Exception:
            # Best-effort: skip malformed vendored files rather than fail.
            logger.debug("Skipping unreadable model JSON: %s", jf)
    _assign_aliases(models)
    return models


def _assign_aliases(models: list[dict[str, str]]) -> None:
    """Assign short unique aliases in place.

    Candidates, in order: first-of-base + first-of-family, first two of
    base, first-of-base + first size digit, first three of base.  Bug fix:
    the original could silently produce *duplicate* aliases when every
    candidate collided; a numeric suffix now guarantees uniqueness.
    """
    used: set[str] = set()
    for m in models:
        base = m["model"].split(":")[0].split("/")[-1]  # e.g. "gpt-oss" or "deepseek-r1"
        fam = m["family"] or base
        digit = "".join(c for c in m["size"] if c.isdigit())[:1] or "x"
        candidates = [
            (base[0] + fam[0]).lower(),
            base[:2].lower(),
            (base[0] + digit).lower(),
            base[:3].lower(),
        ]
        alias = next((c for c in candidates if c not in used), None)
        if alias is None:
            n = 2
            while f"{candidates[-1]}{n}" in used:
                n += 1
            alias = f"{candidates[-1]}{n}"
        used.add(alias)
        m["alias"] = alias
def _build_alias_table() -> dict[str, str]:
    """Build an alias -> model_name mapping from the vendored JSON metadata."""
    table: dict[str, str] = {}
    for entry in list_freeapi_models():
        table[entry["alias"]] = entry["model"]
    return table
def detect_provider_and_model() -> tuple[str, str]:
    """Detect LLM provider and return (provider, human-readable model label).

    Returns
    -------
    tuple[str, str]
        ``(provider_key, display_label)`` — e.g. ``("freeapi", "Gemma3:4.3b")``.
    """
    provider = detect_available_provider()
    if provider == _PROVIDER_FREEAPI:
        model = _freeapi_model()
        # Prefer the vendored-metadata label when this model is known.
        for m in list_freeapi_models():
            if m["model"] == model:
                return provider, m["label"]
        return provider, model
    if provider == _PROVIDER_OLLAMA:
        return provider, f"Ollama:{_ollama_model()}"
    if provider == _PROVIDER_GEMINI:
        # Bug fix: honor the documented GEMINI_MODEL variable as well as the
        # legacy ESML_GEMINI_MODEL override (legacy still takes precedence).
        model = os.environ.get("ESML_GEMINI_MODEL", "").strip() or _gemini_model()
        return provider, f"Gemini:{model}"
    if provider == _PROVIDER_API:
        model = os.environ.get("ESML_API_MODEL", DEFAULT_API_MODEL).strip()
        return provider, f"API:{model}"
    if provider == _PROVIDER_OPENAI:
        model = os.environ.get("ESML_OPENAI_MODEL", DEFAULT_OPENAI_MODEL).strip()
        return provider, f"OpenAI:{model}"
    return provider, "local fallback (no LLM)"
def detect_model_display() -> dict[str, str]:
    """Return display info with inner (family:size) and outer (model name).

    Returns
    -------
    dict[str, str]
        Keys: ``inner``, ``outer``, ``model``, ``provider``.
        HomeScreen format: ``LLM: {inner} [{outer}]``
    """
    provider = detect_available_provider()
    if provider == _PROVIDER_FREEAPI:
        model = _freeapi_model()
        for m in list_freeapi_models():
            if m["model"] == model:
                outer = m["model"].upper().split("/")[-1]
                return {"inner": m["label"].upper(), "outer": outer, "model": m["model"], "provider": provider}
        return {"inner": model.upper(), "outer": model.upper(), "model": model, "provider": provider}
    if provider == _PROVIDER_OLLAMA:
        model = _ollama_model()
        return {"inner": "OLLAMA", "outer": model.upper(), "model": model, "provider": provider}
    if provider == _PROVIDER_GEMINI:
        # Bug fix: honor GEMINI_MODEL (documented) as well as legacy ESML_GEMINI_MODEL.
        model = os.environ.get("ESML_GEMINI_MODEL", "").strip() or _gemini_model()
        return {"inner": "GEMINI", "outer": model.upper(), "model": model, "provider": provider}
    # Bug fix: the api/openai providers previously fell through to the
    # LOCAL/FALLBACK display even though a live provider was detected.
    if provider == _PROVIDER_API:
        model = os.environ.get("ESML_API_MODEL", DEFAULT_API_MODEL).strip()
        return {"inner": "API", "outer": model.upper(), "model": model, "provider": provider}
    if provider == _PROVIDER_OPENAI:
        model = os.environ.get("ESML_OPENAI_MODEL", DEFAULT_OPENAI_MODEL).strip()
        return {"inner": "OPENAI", "outer": model.upper(), "model": model, "provider": provider}
    return {"inner": "LOCAL", "outer": "FALLBACK", "model": "", "provider": provider}
# -- Thinking word synonyms -------------------------------------------------

# Generic spinner phrases used when no query keyword matches.
_THINK_WORDS = [
    "synthesizing",
    "parsing",
    "vectorizing",
    "optimizing",
    "brewing",
    "ruminating",
    "pondering",
    "wrangling pixels",
    "consulting the scrolls",
    "envisioning",
    "distilling",
    "weaving",
    "crystallizing",
    "tracing",
    "fluxing",
    "modulating",
    "sequencing",
    "combobulating",
    "calibrating",
    "interpolating",
    "decomposing",
    "iterating",
    "compiling gradients",
    "tuning hyperparameters",
    "aligning embeddings",
]

# Keyword -> themed phrase; first substring match in the query wins.
_CONTEXT_WORDS: dict[str, str] = {
    "monte carlo": "simulating Monte Carlo",
    "markov": "simulating Markov Chains",
    "counterfactual": "estimating counterfactuals",
    "propensity": "scoring propensities",
    "bootstrap": "bootstrapping",
    "regression": "fitting regression surfaces",
    "bayesian": "sampling posteriors",
    "causal": "tracing causal paths",
    "survival": "modeling survival curves",
    "genomic": "sequencing loci",
    "epigenetic": "mapping methylation",
    "sample": "drawing samples",
    "hypothesis": "testing hypotheses",
    "dml": "cross-fitting folds",
    "forest": "growing random forests",
    "neural": "propagating activations",
    "cluster": "partitioning clusters",
    "pca": "reducing dimensions",
    "variance": "decomposing variance",
    "likelihood": "maximizing likelihood",
    "posterior": "sampling posteriors",
    "prior": "eliciting priors",
    "iv": "instrumenting variables",
    "matching": "pairing counterfactuals",
    "weight": "calibrating weights",
    "treatment": "estimating treatment effects",
    "power": "computing power curves",
    "odds": "computing odds ratios",
    "hazard": "modeling hazard rates",
    "genome": "scanning the genome",
    "methylation": "mapping CpG islands",
    "gene": "annotating gene variants",
    "protein": "folding protein structures",
    "cell": "profiling cell types",
    "drug": "screening compounds",
    "trial": "designing trial arms",
    "randomiz": "allocating treatment arms",
    "stratif": "stratifying strata",
    "confound": "adjusting for confounders",
    "bias": "diagnosing bias sources",
    "missing": "imputing missing values",
    "outlier": "flagging outliers",
    "time series": "forecasting trajectories",
    "spatial": "mapping spatial fields",
    "network": "tracing network edges",
    "graph": "traversing graph paths",
    "entropy": "measuring information entropy",
    "game theory": "solving equilibria",
    "nash": "finding Nash equilibria",
    "mechanism": "designing mechanisms",
    "auction": "simulating auctions",
}
def pick_thinking_word(query: str) -> str:
    """Pick a context-aware thinking word based on the query, or a random one."""
    import random

    lowered = query.lower()
    # First matching context keyword wins (dict insertion order).
    matched = next(
        (phrase for keyword, phrase in _CONTEXT_WORDS.items() if keyword in lowered),
        None,
    )
    return matched if matched is not None else random.choice(_THINK_WORDS)
# --------------------------------------------------------------------------- # Context building # ---------------------------------------------------------------------------
def build_esml_context(repo_root: str | Path | None = None) -> dict[str, Any]:
    """Build an LLM-friendly context dictionary from the ESML package state.

    The returned dictionary is designed to be injected into the system prompt
    so the LLM is aware of the available modules, the CPADS data contract,
    and the current working directory.

    Parameters
    ----------
    repo_root : str | Path | None
        Path to the ESML repository root.  When ``None`` the function
        attempts to resolve the root from this file's location.

    Returns
    -------
    dict[str, Any]
        A dictionary with keys ``module_list``, ``cpads_schema``, ``cwd``,
        ``repo_root``, ``function_signatures``, ``dataset_schema`` and
        ``stat_commands``.

    Examples
    --------
    >>> ctx = build_esml_context()
    >>> "module_list" in ctx and "cpads_schema" in ctx
    True
    """
    if repo_root is None:
        # Resolve from file location: llm.py -> esml/ -> py-package/ -> esml-root/
        try:
            root = str(Path(__file__).resolve().parents[2])
        except Exception:
            root = "unknown"
    else:
        root = str(Path(repo_root).resolve())

    modules = [
        {"name": spec.name, "description": spec.description}
        for spec in MODULE_SPECS.values()
    ]
    return {
        "module_list": modules,
        "cpads_schema": cpads_contract(),
        "cwd": os.getcwd(),
        "repo_root": root,
        "function_signatures": _collect_function_signatures(),
        "dataset_schema": _current_dataset_schema(),
        "stat_commands": _stat_command_summary(),
    }
# Modules scanned for RAG retrieval and signature collection.
_ESML_MODULES = [
    "esml.quant",
    "esml.causal",
    "esml.effects",
    "esml.survey",
    "esml.inference",
    "esml.did",
    "esml.rdd",
    "esml.iv",
    "esml.matching",
    "esml.survival",
    "esml.sensitivity",
    "esml.ml",
    "esml.ebac",
    "esml.sampling",
    "esml.loc",
    "esml.emissions",
    "esml.data",
    "esml.modules",
]
def get_last_traceback() -> str:
    """Return the last Python traceback, if any, for error-context injection.

    Empty string when no exception has been recorded; a one-line
    ``Type: message`` summary when only the value is available; otherwise
    the formatted traceback, truncated to its final 1000 characters.
    """
    import sys
    import traceback

    exc = getattr(sys, "last_value", None)
    if exc is None:
        return ""
    tb = getattr(sys, "last_traceback", None)
    if tb is None:
        return f"{type(exc).__name__}: {exc}"
    rendered = "".join(traceback.format_exception(type(exc), exc, tb))
    # Keep only the tail: the most recent frames are the most useful.
    return rendered[-1000:]
def _retrieve_relevant_source(query: str, max_chars: int = 1500) -> str:
    """RAG: retrieve actual source code relevant to the user's query.

    Searches esml module functions for keyword matches against the query,
    then returns the full docstring + signature of matching functions.
    This gives the LLM actual code to reference instead of hallucinating.

    Parameters
    ----------
    query : str
        The user's question.
    max_chars : int
        Max characters of source to inject.

    Returns
    -------
    str
        Relevant source code snippets, or empty string.
    """
    import importlib
    import inspect
    import re

    # Extract keywords from query (lowercase, strip punctuation)
    keywords = set(re.findall(r"[a-z_]{3,}", query.lower()))

    # Add common synonyms
    if "ipw" in keywords:
        keywords.update({"propensity", "weight", "inverse"})
    if "ate" in keywords:
        keywords.update({"treatment", "effect", "estimate"})
    if "quant" in keywords or "turboquant" in keywords:
        keywords.update({"quantize", "codebook", "rotation", "turboquant"})
    if "qjl" in keywords:
        keywords.update({"sign", "projection", "residual", "encode", "decode"})
    if "load" in keywords or "cpads" in keywords or "dataset" in keywords:
        keywords.update({"load_dataset", "dataset", "cpads", "load"})
    if "dml" in keywords:
        keywords.update({"double", "machine", "plr", "estimate"})

    matches: list[tuple[float, str]] = []
    for mod_name in _ESML_MODULES:
        try:
            mod = importlib.import_module(mod_name)
        except ImportError:
            continue
        for name in dir(mod):
            if name.startswith("_"):
                continue
            obj = getattr(mod, name, None)
            if obj is None or not callable(obj):
                continue
            obj_mod = getattr(obj, "__module__", "")
            if not obj_mod.startswith("esml"):
                continue
            # Score by keyword overlap with function name + docstring
            fn_lower = name.lower()
            doc = (getattr(obj, "__doc__", "") or "").lower()
            score = 0.0
            for kw in keywords:
                if kw in fn_lower:
                    score += 3.0  # strong match on function name
                if kw in doc[:200]:
                    score += 1.0  # match in docstring
            if score > 0:
                try:
                    sig = inspect.signature(obj)
                    full_doc = getattr(obj, "__doc__", "") or ""
                    snippet = f"{mod_name.split('.')[-1]}.{name}{sig}\n{full_doc}"
                    matches.append((score, snippet))
                except (ValueError, TypeError):
                    # Builtins / C-implemented callables without signatures.
                    pass

    if not matches:
        return ""

    # Sort by relevance score, take top matches until the budget is spent.
    matches.sort(key=lambda x: x[0], reverse=True)
    result_parts: list[str] = []
    total = 0
    for _score, snippet in matches:
        if total + len(snippet) > max_chars:
            break
        result_parts.append(snippet)
        total += len(snippet)
    return "\n---\n".join(result_parts) if result_parts else ""


def _collect_function_signatures() -> list[dict[str, str]]:
    """Introspect key esml modules for function names + one-line descriptions."""
    import importlib

    sigs: list[dict[str, str]] = []
    for mod_name in _ESML_MODULES[:14]:
        try:
            mod = importlib.import_module(mod_name)
            for name in sorted(dir(mod)):
                if name.startswith("_"):
                    continue
                obj = getattr(mod, name, None)
                if obj is None or not callable(obj):
                    continue
                if not hasattr(obj, "__doc__") or not obj.__doc__:
                    continue
                # Skip type aliases, dataclass decorators, stdlib re-exports
                obj_mod = getattr(obj, "__module__", "")
                if not obj_mod.startswith("esml"):
                    continue
                first_line = obj.__doc__.strip().split("\n")[0][:80]
                sigs.append({"fn": f"{mod_name.split('.')[-1]}.{name}", "desc": first_line})
        except ImportError:
            continue
    return sigs[:60]


def _current_dataset_schema() -> dict[str, str] | None:
    """If a DataFrame is loaded in ESMLApp, return its column schema."""
    try:
        from . import tui as _tui_mod

        app = getattr(_tui_mod, "_running_app", None)
        if app and hasattr(app, "loaded_df") and app.loaded_df is not None:
            df = app.loaded_df
            return {col: str(dtype) for col, dtype in df.dtypes.items()}
    except Exception:
        # Best-effort: no running app / no TUI module means no schema.
        pass
    return None


def _stat_command_summary() -> list[str]:
    """Return top stat command names grouped concisely."""
    try:
        from .stat_commands import list_all_commands

        return list_all_commands()[:100]
    except Exception:
        try:
            from .stat_commands import COMMAND_REGISTRY

            return sorted(COMMAND_REGISTRY.keys())[:100]
        except Exception:
            return []


def _format_context_block(context: dict[str, Any] | None) -> str:
    """Render a context dictionary into a text block for the system prompt.

    Caps total output at 3500 chars to avoid blowing the context window.
    """
    if not context:
        return ""
    parts: list[str] = []
    modules = context.get("module_list")
    if modules:
        names = ", ".join(m["name"] for m in modules)
        parts.append(f"Available ESML modules: {names}")
    schema = context.get("cpads_schema")
    if schema:
        req_vars = schema.get("required_variables", [])
        parts.append(f"CPADS required variables: {', '.join(req_vars)}")
    # Dataset schema (if a DataFrame is loaded)
    ds_schema = context.get("dataset_schema")
    if ds_schema:
        cols = [f"{c}({t})" for c, t in list(ds_schema.items())[:20]]
        parts.append(f"Loaded dataset columns: {', '.join(cols)}")
    # Function signatures (top ones)
    fn_sigs = context.get("function_signatures")
    if fn_sigs:
        sig_lines = [f"{s['fn']}: {s['desc']}" for s in fn_sigs[:25]]
        parts.append("Key functions:\n" + "\n".join(sig_lines))
    # Stat commands
    stat_cmds = context.get("stat_commands")
    if stat_cmds:
        parts.append(f"Stat commands ({len(stat_cmds)}): {', '.join(stat_cmds[:40])}")
    cwd = context.get("cwd")
    if cwd:
        parts.append(f"User working directory: {cwd}")
    # RAG: inject relevant source code for the current query
    rag = context.get("rag_source")
    if rag:
        parts.append(f"RELEVANT SOURCE CODE (use this to answer accurately):\n{rag}")
    block = "\n".join(parts)
    if len(block) > 3500:
        block = block[:3500] + "\n..."
    return block


# ---------------------------------------------------------------------------
# Chat completions helpers
# ---------------------------------------------------------------------------


def _build_messages(
    prompt: str,
    context: dict[str, Any] | None = None,
    system_prompt: str | None = None,
) -> list[dict[str, str]]:
    """Build the ``messages`` array for the chat completions payload."""
    if system_prompt is None:
        context_block = _format_context_block(context)
        system_prompt = _ESML_SYSTEM_PROMPT_TEMPLATE.format(context_block=context_block)
    return [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": prompt},
    ]


def _request_completion(
    base_url: str,
    model: str,
    messages: list[dict[str, str]],
    *,
    api_key: str | None = None,
    stream: bool = False,
    timeout: float = _REQUEST_TIMEOUT,
) -> httpx.Response:
    """Send a POST to ``/v1/chat/completions`` and return the raw response.

    Parameters
    ----------
    base_url : str
        The provider base URL (e.g., ``http://localhost:11434`` for Ollama,
        ``https://api.openai.com`` for OpenAI).
    model : str
        The model identifier to use.
    messages : list[dict[str, str]]
        The chat messages array.
    api_key : str | None
        Bearer token.  Omitted for local Ollama requests.
    stream : bool
        Whether to request server-sent-event streaming.
    timeout : float
        Request timeout in seconds.

    Returns
    -------
    httpx.Response
        The raw ``httpx`` response object.
    """
    url = f"{base_url}/v1/chat/completions"
    headers: dict[str, str] = {"Content-Type": "application/json"}
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"
    payload: dict[str, Any] = {
        "model": model,
        "messages": messages,
        "stream": stream,
    }
    if "localhost" in base_url or "127.0.0.1" in base_url:
        # Local models can be slow to generate; cap output and extend timeout.
        payload["max_tokens"] = 4096
        timeout = max(timeout, 300.0)
    return httpx.post(
        url,
        json=payload,
        headers=headers,
        timeout=timeout,
    )


def _extract_text(response: httpx.Response) -> str:
    """Extract the assistant message text from a non-streaming response."""
    data = response.json()
    choices = data.get("choices", [])
    if not choices:
        logger.warning("LLM response contained no choices: %s", data)
        return ""
    return choices[0].get("message", {}).get("content", "")


def _iter_stream(response: httpx.Response) -> Iterator[str]:
    """Yield text chunks from a non-streaming SSE response already in memory.

    This helper is kept for backward compatibility; prefer
    :func:`_stream_completion` for live streaming over an open connection.

    Yields
    ------
    str
        Each content delta string as it arrives.
    """
    for line in response.text.splitlines():
        line = line.strip()
        if not line:
            continue
        if line == "data: [DONE]":
            break
        if line.startswith("data: "):
            raw = line[len("data: ") :]
            try:
                chunk = json.loads(raw)
                delta = chunk.get("choices", [{}])[0].get("delta", {})
                content = delta.get("content", "")
                if content:
                    yield content
            except (json.JSONDecodeError, IndexError, KeyError):
                continue


def _stream_completion(
    base_url: str,
    model: str,
    messages: list[dict[str, str]],
    *,
    api_key: str | None = None,
    timeout: float = _REQUEST_TIMEOUT,
) -> Iterator[str]:
    """Stream a chat completion, keeping the HTTP connection open.

    Uses ``httpx.stream()`` so the connection stays open while the generator
    is being consumed.  The connection is closed automatically when the
    generator is exhausted or garbage-collected.

    Parameters
    ----------
    base_url : str
        The provider base URL.
    model : str
        Model identifier.
    messages : list[dict[str, str]]
        Chat messages array.
    api_key : str | None
        Bearer token, or ``None`` for local Ollama.
    timeout : float
        Request timeout in seconds.

    Yields
    ------
    str
        Each content delta as it arrives from the SSE stream.
    """
    url = f"{base_url}/v1/chat/completions"
    headers: dict[str, str] = {"Content-Type": "application/json"}
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"
    payload: dict[str, Any] = {
        "model": model,
        "messages": messages,
        "stream": True,
    }
    if "localhost" in base_url or "127.0.0.1" in base_url:
        # Local models can be slow to generate; cap output and extend timeout.
        payload["max_tokens"] = 4096
        timeout = max(timeout, 300.0)
    with httpx.stream("POST", url, json=payload, headers=headers, timeout=timeout) as resp:
        resp.raise_for_status()
        for line in resp.iter_lines():
            line = line.strip()
            if not line:
                continue
            if line == "data: [DONE]":
                return
            if line.startswith("data: "):
                raw = line[len("data: ") :]
                try:
                    chunk = json.loads(raw)
                    delta = chunk.get("choices", [{}])[0].get("delta", {})
                    content = delta.get("content", "")
                    if content:
                        yield content
                except (json.JSONDecodeError, IndexError, KeyError):
                    continue


# ---------------------------------------------------------------------------
# OllamaFreeAPI SDK-based completions
# ---------------------------------------------------------------------------


def _messages_to_prompt(messages: list[dict[str, str]]) -> str:
    """Flatten a messages array into a single prompt string for SDK providers.

    The ``ollamafreeapi`` SDK takes a flat ``prompt`` string, not an OpenAI
    ``messages`` array.  This helper concatenates system, user, and assistant
    messages into a readable prompt that preserves context.
    """
    parts: list[str] = []
    for msg in messages:
        role = msg.get("role", "user")
        content = msg.get("content", "")
        if role == "system":
            parts.append(f"[System: {content}]")
        elif role == "assistant":
            parts.append(f"Assistant: {content}")
        else:
            parts.append(content)
    return "\n\n".join(parts)


_FREEAPI_TIMEOUT = 180.0  # seconds — generous for free community servers


def _strip_think_blocks(text: str) -> str:
    """Remove DeepSeek-R1 ``<think>...</think>`` reasoning blocks from output."""
    import re

    cleaned = re.sub(r"<think>.*?</think>\s*", "", text, flags=re.DOTALL)
    return cleaned.strip()


def _freeapi_completion(
    messages: list[dict[str, str]],
    model: str | None = None,
    timeout: float | None = None,
) -> str:
    """Non-streaming completion via OllamaFreeAPI SDK.

    Wraps the blocking ``client.chat()`` call in a thread with a timeout so
    the TUI never hangs indefinitely.  Raises ``TimeoutError`` if the call
    doesn't return within *timeout* seconds so that callers (like
    ``ask_multi``) can fall through to the next provider.
    """
    import concurrent.futures

    from .fam import OllamaFreeAPI

    if timeout is None:
        timeout = _FREEAPI_TIMEOUT
    client = OllamaFreeAPI()
    prompt = _messages_to_prompt(messages)

    def _call() -> str:
        resp = client.chat(prompt=prompt, model=model or _freeapi_model(), num_predict=10000)
        return str(resp) if resp else ""

    # NOTE: Do NOT use ``with`` — ThreadPoolExecutor.__exit__ calls
    # shutdown(wait=True) which blocks until the hung thread finishes,
    # completely defeating the timeout.
    pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
    future = pool.submit(_call)
    try:
        result = future.result(timeout=timeout)
        return _strip_think_blocks(result)
    except concurrent.futures.TimeoutError:
        logger.warning("FreeAPI call timed out after %.0fs", timeout)
        raise TimeoutError(f"FreeAPI did not respond within {timeout:.0f}s") from None
    finally:
        # BUGFIX: shut down (non-blocking) in *every* exit path — success,
        # timeout, or an unexpected exception re-raised by ``future.result``.
        # Previously the pool was only shut down on the success and timeout
        # paths, leaking the executor when the worker itself raised.
        pool.shutdown(wait=False)


def _freeapi_stream(
    messages: list[dict[str, str]],
    model: str | None = None,
    chunk_timeout: float | None = None,
) -> Iterator[str]:
    """Streaming completion via OllamaFreeAPI SDK.

    Uses a background thread + queue so that each chunk has a timeout.
    If no chunk arrives within *chunk_timeout* seconds the stream ends
    gracefully instead of hanging the UI forever.  When *chunk_timeout*
    is ``None`` it is taken from ``ESML_LLM_CHUNK_TIMEOUT`` (default 90s).
    """
    import concurrent.futures
    import queue

    from .fam import OllamaFreeAPI

    if chunk_timeout is None:
        # BUGFIX: read the env var per call instead of once at import time
        # (previously evaluated in the ``def`` default), so changes to
        # ESML_LLM_CHUNK_TIMEOUT take effect without reloading the module.
        chunk_timeout = float(os.environ.get("ESML_LLM_CHUNK_TIMEOUT", "90"))

    client = OllamaFreeAPI()
    prompt = _messages_to_prompt(messages)
    _DONE = object()  # sentinel marking end-of-stream on the queue
    q: queue.Queue = queue.Queue()

    def _producer() -> None:
        # Runs in a worker thread; forwards chunks (or the exception) to the
        # queue and always terminates the stream with the sentinel.
        try:
            for chunk in client.stream_chat(
                prompt=prompt, model=model or _freeapi_model(), num_predict=10000
            ):
                if chunk:
                    q.put(str(chunk))
        except Exception as exc:
            q.put(exc)
        finally:
            q.put(_DONE)

    pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
    pool.submit(_producer)
    try:
        while True:
            try:
                item = q.get(timeout=chunk_timeout)
            except queue.Empty:
                logger.warning("FreeAPI stream stalled (no chunk for %.0fs)", chunk_timeout)
                break
            if item is _DONE:
                break
            if isinstance(item, Exception):
                logger.warning("FreeAPI stream error: %s", item)
                break
            yield item
    finally:
        pool.shutdown(wait=False)


# ---------------------------------------------------------------------------
# Local fallback
# ---------------------------------------------------------------------------

_LOCAL_FALLBACK_TEXT = """\
ESML is running in local-only mode (no LLM provider detected).

Available capabilities without an LLM:
  - esml list-modules           List all analysis modules
  - esml run-module <name>      Run a specific module against CPADS data
  - esml pipeline --all -y      Run the full analysis pipeline
  - esml assistant <question>   Ask the built-in rule-based agent

To enable AI-assisted mode, use any of the following:

1. Install Ollama (local, private, no API key):
     curl -fsSL https://ollama.com/install.sh | sh
     ollama pull gemma4:e2b   # or mistral, llama3, deepseek-r1, ...

2. Install OllamaFreeAPI (free remote models, no API key needed):
     pip install ollamafreeapi
     # That's it! Free access to llama3, mistral, qwen, deepseek, etc.

3. Set a Gemini API key (free tier available at aistudio.google.com):
     export GEMINI_API_KEY="your-key-here"
     export GEMINI_MODEL="gemini-2.0-flash"  # optional, this is the default

4. Use an OpenAI-compatible endpoint (Qwen, GPT-OSS, Mistral, Groq):
     export LLM_API_BASE_URL="https://openrouter.ai/api/v1"
     export LLM_API_KEY="your-key-here"

5. Set an OpenAI API key:
     export OPENAI_API_KEY="your-key-here"

For more information, see: https://github.com/EpiNodes/esml
"""


def _local_fallback(prompt: str) -> str:
    """Return a helpful local response when no LLM provider is available.

    Parameters
    ----------
    prompt : str
        The user's original question (used for keyword matching).

    Returns
    -------
    str
        A static help message with package usage guidance.
    """
    normalized = prompt.lower()
    sections = [_LOCAL_FALLBACK_TEXT.strip()]

    # Provide topic-specific hints based on keyword matching.
    if "cpads" in normalized or "dataset" in normalized or "data" in normalized:
        contract = cpads_contract()
        sections.append(
            "CPADS data contract:\n"
            f"  Required variables: {', '.join(contract['required_variables'])}\n"
            f"  Expected path: {contract['expected_wrangled_path']}"
        )
    if "ipw" in normalized or "propensity" in normalized or "causal" in normalized:
        sections.append(
            "Causal inference modules available: propensity-scores, "
            "causal-estimators, treatment-effects, ebac-selection-adjustment-ipw.\n"
            "Use `esml run-module <name>` to execute."
        )
    if "module" in normalized or "list" in normalized:
        names = [spec.name for spec in MODULE_SPECS.values()]
        sections.append("Implemented modules: " + ", ".join(names))
    return "\n\n".join(sections)


# ---------------------------------------------------------------------------
# Main ask() function
# ---------------------------------------------------------------------------
def ask(
    prompt: str,
    context: dict[str, Any] | None = None,
    *,
    stream: bool = False,
    model: str | None = None,
    provider: str | None = None,
    system_prompt: str | None = None,
    timeout: float = _REQUEST_TIMEOUT,
) -> str | Iterator[str]:
    """Send a prompt to the best available LLM provider and return the response.

    The provider chain is: Ollama (local) -> OllamaFreeAPI (SDK) -> Gemini ->
    generic OpenAI-compatible API -> OpenAI direct -> local fallback.  Each
    provider is tried in order; on failure the next is attempted.

    Parameters
    ----------
    prompt : str
        The user's question or instruction.
    context : dict[str, Any] | None
        Optional context dictionary (e.g., from :func:`build_esml_context`).
        Injected into the system prompt to give the LLM awareness of
        available modules, CPADS schema, and the user's working directory.
    stream : bool
        If ``True``, return an iterator of string chunks for streaming output.
        If ``False`` (default), return the full response as a single string.
    model : str | None
        Override the model identifier.  When ``None``, a sensible default
        is chosen per provider.
    provider : str | None
        Force a specific provider (``"ollama"``, ``"freeapi"``, ``"gemini"``,
        ``"api"``, ``"openai"``, ``"local"``).  When ``None``,
        :func:`detect_available_provider` is used to auto-detect.
    system_prompt : str | None
        Override the entire system prompt.  When ``None``, the standard
        ESML system prompt is built from the ``context`` parameter.
    timeout : float
        HTTP request timeout in seconds.

    Returns
    -------
    str | Iterator[str]
        The LLM response text (or a streaming iterator of text chunks).
        When all providers fail, returns a local fallback help string.

    Examples
    --------
    >>> # Non-streaming (returns full text)
    >>> response = ask("What is AIPW?")
    >>> isinstance(response, str)
    True

    >>> # Streaming
    >>> for chunk in ask("Explain TMLE", stream=True):
    ...     print(chunk, end="")
    """
    if provider is None:
        provider = detect_available_provider()

    # Local mode: skip context building entirely and return static help text.
    if provider == _PROVIDER_LOCAL:
        result = _local_fallback(prompt)
        if stream:
            return iter([result])
        return result

    # RAG: retrieve relevant source code for the query
    rag_source = _retrieve_relevant_source(prompt)
    if rag_source:
        if context is None:
            context = {}
        context["rag_source"] = rag_source

    # Inject last traceback if user seems to be debugging
    tb = get_last_traceback()
    if tb and any(
        kw in prompt.lower()
        for kw in ("error", "bug", "fix", "traceback", "fail", "broke", "crash", "debug")
    ):
        if context is None:
            context = {}
        context["rag_source"] = (context.get("rag_source", "") + f"\n\nLAST ERROR:\n{tb}").strip()

    messages = _build_messages(prompt, context=context, system_prompt=system_prompt)

    # --- SDK-based providers (no HTTP endpoint) ---
    if provider == _PROVIDER_FREEAPI:
        try:
            if stream:
                return _freeapi_stream(messages, model=model)
            else:
                return _freeapi_completion(messages, model=model)
        except Exception as exc:
            logger.warning("FreeAPI failed: %s. Falling through to HTTP providers.", exc)
            # Fall through to try HTTP-based providers below.
            if _gemini_key():
                provider = _PROVIDER_GEMINI
            elif _api_base_url() and _api_key():
                provider = _PROVIDER_API
            elif _openai_key():
                provider = _PROVIDER_OPENAI
            else:
                result = _local_fallback(prompt)
                return iter([result]) if stream else result

    # Build the ordered list of (base_url, model, api_key) to try.
    attempts: list[tuple[str, str, str | None]] = []
    if provider == _PROVIDER_OLLAMA:
        attempts.append(
            (
                _ollama_base_url(),
                model or _ollama_model(),
                None,
            )
        )
        # Fallback chain if Ollama fails at request time.
        if _probe_freeapi():
            pass  # FreeAPI is SDK-based, handled above; skip in HTTP loop.
        if _gemini_key():
            attempts.append((GEMINI_BASE_URL, model or _gemini_model(), _gemini_key()))
        if _api_base_url() and _api_key():
            attempts.append((_api_base_url(), model or DEFAULT_API_MODEL, _api_key()))  # type: ignore[arg-type]
        if _openai_key():
            attempts.append((OPENAI_BASE_URL, model or DEFAULT_OPENAI_MODEL, _openai_key()))
    elif provider == _PROVIDER_GEMINI:
        key = _gemini_key()
        if key:
            attempts.append((GEMINI_BASE_URL, model or _gemini_model(), key))
        # Fallback to generic API then OpenAI if Gemini fails.
        if _api_base_url() and _api_key():
            attempts.append((_api_base_url(), model or DEFAULT_API_MODEL, _api_key()))  # type: ignore[arg-type]
        if _openai_key():
            attempts.append((OPENAI_BASE_URL, model or DEFAULT_OPENAI_MODEL, _openai_key()))
    elif provider == _PROVIDER_API:
        base = _api_base_url()
        key = _api_key()
        if base and key:
            attempts.append((base, model or DEFAULT_API_MODEL, key))
        if _openai_key():
            attempts.append((OPENAI_BASE_URL, model or DEFAULT_OPENAI_MODEL, _openai_key()))
    elif provider == _PROVIDER_OPENAI:
        key = _openai_key()
        if key:
            attempts.append((OPENAI_BASE_URL, model or DEFAULT_OPENAI_MODEL, key))

    if not attempts:
        result = _local_fallback(prompt)
        return iter([result]) if stream else result

    last_error: Exception | None = None
    for base_url, req_model, api_key in attempts:
        try:
            logger.debug(
                "Attempting LLM request: base_url=%s model=%s stream=%s",
                base_url,
                req_model,
                stream,
            )
            if stream:
                # Use _stream_completion which keeps the httpx connection open
                # via httpx.stream() context manager while the generator is live.
                return _stream_completion(
                    base_url,
                    req_model,
                    messages,
                    api_key=api_key,
                    timeout=timeout,
                )
            else:
                resp = _request_completion(
                    base_url,
                    req_model,
                    messages,
                    api_key=api_key,
                    stream=False,
                    timeout=timeout,
                )
                resp.raise_for_status()
                return _extract_text(resp)
        except (httpx.HTTPError, httpx.TimeoutException, OSError, KeyError) as exc:
            last_error = exc
            logger.warning(
                "Provider at %s failed: %s. Trying next provider.",
                base_url,
                exc,
            )
            continue

    logger.warning(
        "All LLM providers failed. Last error: %s. Falling back to local mode.",
        last_error,
    )
    result = _local_fallback(prompt)
    return iter([result]) if stream else result
# --------------------------------------------------------------------------- # Multi-turn conversation support # ---------------------------------------------------------------------------
def ask_multi(
    messages: list[dict[str, str]],
    *,
    stream: bool = False,
    model: str | None = None,
    provider: str | None = None,
    timeout: float = _REQUEST_TIMEOUT,
) -> str | Iterator[str]:
    """Send a pre-built messages array to the best available LLM provider.

    Unlike :func:`ask`, this accepts the full ``messages`` array directly,
    enabling multi-turn conversation support.  The caller is responsible
    for constructing the system and user messages.

    Parameters
    ----------
    messages : list[dict[str, str]]
        The chat messages array (system, user, assistant turns).
    stream : bool
        If ``True``, return an iterator of string chunks.
    model : str | None
        Override the model identifier.
    provider : str | None
        Force a specific provider.  Auto-detected when ``None``.
    timeout : float
        HTTP request timeout in seconds.

    Returns
    -------
    str | Iterator[str]
        The LLM response text (or a streaming iterator).
    """
    if provider is None:
        provider = detect_available_provider()

    if provider == _PROVIDER_LOCAL:
        # Extract the last user message for the fallback.
        user_msgs = [m for m in messages if m.get("role") == "user"]
        prompt = user_msgs[-1]["content"] if user_msgs else ""
        result = _local_fallback(prompt)
        return iter([result]) if stream else result

    # --- SDK-based providers (no HTTP endpoint) ---
    if provider == _PROVIDER_FREEAPI:
        try:
            if stream:
                return _freeapi_stream(messages, model=model)
            else:
                return _freeapi_completion(messages, model=model)
        except Exception as exc:
            logger.warning("FreeAPI failed in ask_multi: %s. Falling through.", exc)
            # Re-route to the next configured HTTP provider, if any.
            if _gemini_key():
                provider = _PROVIDER_GEMINI
            elif _api_base_url() and _api_key():
                provider = _PROVIDER_API
            elif _openai_key():
                provider = _PROVIDER_OPENAI
            else:
                user_msgs = [m for m in messages if m.get("role") == "user"]
                prompt = user_msgs[-1]["content"] if user_msgs else ""
                result = _local_fallback(prompt)
                return iter([result]) if stream else result

    # Build the ordered list of (base_url, model, api_key) to try.
    attempts: list[tuple[str, str, str | None]] = []
    if provider == _PROVIDER_OLLAMA:
        attempts.append((_ollama_base_url(), model or _ollama_model(), None))
        if _gemini_key():
            attempts.append((GEMINI_BASE_URL, model or _gemini_model(), _gemini_key()))
        if _api_base_url() and _api_key():
            attempts.append((_api_base_url(), model or DEFAULT_API_MODEL, _api_key()))  # type: ignore[arg-type]
        if _openai_key():
            attempts.append((OPENAI_BASE_URL, model or DEFAULT_OPENAI_MODEL, _openai_key()))
    elif provider == _PROVIDER_GEMINI:
        key = _gemini_key()
        if key:
            attempts.append((GEMINI_BASE_URL, model or _gemini_model(), key))
        if _api_base_url() and _api_key():
            attempts.append((_api_base_url(), model or DEFAULT_API_MODEL, _api_key()))  # type: ignore[arg-type]
        if _openai_key():
            attempts.append((OPENAI_BASE_URL, model or DEFAULT_OPENAI_MODEL, _openai_key()))
    elif provider == _PROVIDER_API:
        base = _api_base_url()
        key = _api_key()
        if base and key:
            attempts.append((base, model or DEFAULT_API_MODEL, key))
        if _openai_key():
            attempts.append((OPENAI_BASE_URL, model or DEFAULT_OPENAI_MODEL, _openai_key()))
    elif provider == _PROVIDER_OPENAI:
        key = _openai_key()
        if key:
            attempts.append((OPENAI_BASE_URL, model or DEFAULT_OPENAI_MODEL, key))

    if not attempts:
        user_msgs = [m for m in messages if m.get("role") == "user"]
        prompt = user_msgs[-1]["content"] if user_msgs else ""
        result = _local_fallback(prompt)
        return iter([result]) if stream else result

    for base_url, req_model, api_key in attempts:
        try:
            if stream:
                return _stream_completion(
                    base_url,
                    req_model,
                    messages,
                    api_key=api_key,
                    timeout=timeout,
                )
            else:
                resp = _request_completion(
                    base_url,
                    req_model,
                    messages,
                    api_key=api_key,
                    stream=False,
                    timeout=timeout,
                )
                resp.raise_for_status()
                return _extract_text(resp)
        except (httpx.HTTPError, httpx.TimeoutException, OSError, KeyError) as exc:
            logger.warning("Provider at %s failed: %s", base_url, exc)
            continue

    # Every attempt failed: answer the last user turn with the local help text.
    user_msgs = [m for m in messages if m.get("role") == "user"]
    prompt = user_msgs[-1]["content"] if user_msgs else ""
    result = _local_fallback(prompt)
    return iter([result]) if stream else result
# --------------------------------------------------------------------------- # Agent availability check # ---------------------------------------------------------------------------
[docs] def agent_available() -> bool: """Return True when at least one live LLM provider is available. Returns ------- bool ``True`` if a live provider is detected, ``False`` if only local fallback is available. Examples -------- >>> isinstance(agent_available(), bool) True """ return detect_available_provider() != _PROVIDER_LOCAL
assistant_available = agent_available