259 lines
8.4 KiB
Python
259 lines
8.4 KiB
Python
"""
|
|
claude-code-proxy: Ollama-compatible HTTP facade for `claude -p`.

Exposes a subset of the Ollama API on http://127.0.0.1:11435 and translates
each request into a `claude -p` subprocess invocation. This lets external
tools that already speak Ollama (Open WebUI, AnythingLLM, n8n nodes, etc.)
talk to Claude Code instead of a local Ollama instance.

Endpoints:
    GET  /              health check
    GET  /api/version   Ollama version stub
    GET  /api/tags      list "models" (so clients can validate)
    POST /api/show      model details stub
    POST /api/generate  single-shot prompt -> response
    POST /api/chat      multi-message conversation -> response

Both /api/generate and /api/chat honour the `stream` flag in the request
body (Ollama default is True). When true, responses are emitted as
NDJSON chunks; when false, a single JSON object is returned.

Environment variables:
    CLAUDE_BIN                path to claude CLI (default: "claude")
    CLAUDE_PROXY_CONCURRENCY  max concurrent claude subprocesses (default: 3)
    CLAUDE_PROXY_MODEL        name advertised in /api/tags (default: "claude-code")
    CLAUDE_PROXY_TIMEOUT      per-request timeout in seconds (default: 300)
    CLAUDE_CODE_OAUTH_TOKEN   long-lived auth token, inherited by claude subprocess
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import os
|
|
import time
|
|
from datetime import datetime, timezone
|
|
from typing import Any, AsyncIterator
|
|
|
|
from fastapi import FastAPI, Request
|
|
from fastapi.responses import JSONResponse, StreamingResponse
|
|
|
|
# --- Configuration ----------------------------------------------------------
|
|
|
|
def _env_int(name: str, default: int) -> int:
    """Read a strictly positive integer setting from the environment.

    Args:
        name: Environment variable to read.
        default: Value returned when the variable is not set.

    Returns:
        The parsed integer value.

    Raises:
        ValueError: If the variable is set but is not an integer >= 1.
    """
    raw = os.environ.get(name)
    if raw is None:
        return default
    try:
        value = int(raw)
    except ValueError:
        raise ValueError(f"{name} must be a positive integer, got {raw!r}") from None
    if value < 1:
        raise ValueError(f"{name} must be a positive integer, got {raw!r}")
    return value


# Path (or bare name resolved via PATH) of the claude CLI executable.
CLAUDE_BIN = os.environ.get("CLAUDE_BIN", "claude")
# Must be >= 1: this value sizes the asyncio.Semaphore guarding subprocess
# launches, and a semaphore of 0 could never be acquired.
CONCURRENCY = _env_int("CLAUDE_PROXY_CONCURRENCY", 3)
# Model name reported to clients and used when a request omits "model".
DEFAULT_MODEL = os.environ.get("CLAUDE_PROXY_MODEL", "claude-code")
# Per-request limit, in seconds, on how long a claude subprocess may run.
TIMEOUT_SECONDS = _env_int("CLAUDE_PROXY_TIMEOUT", 300)
|
|
|
|
# Root logging setup: timestamp, level, then the message.
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(message)s",
    level=logging.INFO,
)
log = logging.getLogger("claude-proxy")

# ASGI application object that the route decorators register against.
app = FastAPI(title="claude-code-proxy")

# Bounds the number of claude subprocesses that may run at the same time.
_semaphore = asyncio.Semaphore(CONCURRENCY)
|
|
|
|
|
|
# --- Helpers ----------------------------------------------------------------
|
|
|
|
def _now_iso() -> str:
    """Return the current UTC time in ISO-8601 form with a ``Z`` suffix."""
    stamp = datetime.now(timezone.utc).isoformat()
    # isoformat() spells UTC as "+00:00"; swap in the shorter "Z" designator.
    return stamp.replace("+00:00", "Z")
|
|
|
|
|
|
async def _run_claude(prompt: str) -> str:
    """Run `claude -p <prompt>` and return its stdout as text.

    At most CONCURRENCY invocations run at once; additional callers wait on
    the module-level semaphore.

    Args:
        prompt: Text handed to the CLI as the argument of ``-p``.

    Returns:
        The subprocess's stdout decoded as UTF-8 (undecodable bytes replaced).

    Raises:
        RuntimeError: If the subprocess runs longer than TIMEOUT_SECONDS or
            exits with a non-zero status (the message carries up to 1000
            characters of stderr).
    """
    async with _semaphore:
        log.info("claude -p invoked (prompt %d chars)", len(prompt))
        # NOTE(review): the prompt travels as a single argv element, so no
        # shell is involved; a prompt that starts with "-" might still be
        # read as an option by the CLI -- confirm against the claude CLI.
        proc = await asyncio.create_subprocess_exec(
            CLAUDE_BIN, "-p", prompt,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
            env=os.environ.copy(),
        )
        try:
            stdout, stderr = await asyncio.wait_for(
                proc.communicate(), timeout=TIMEOUT_SECONDS
            )
        except asyncio.TimeoutError as exc:
            raise RuntimeError(
                f"claude -p timed out after {TIMEOUT_SECONDS}s"
            ) from exc
        finally:
            # Runs on timeout and also when this task is cancelled (e.g. the
            # HTTP client went away): never leave the child running behind us.
            if proc.returncode is None:
                try:
                    proc.kill()
                except ProcessLookupError:
                    # The child exited between the check and the kill.
                    pass
                await proc.wait()

        if proc.returncode != 0:
            err = stderr.decode("utf-8", errors="replace")[:1000]
            raise RuntimeError(f"claude -p exited {proc.returncode}: {err}")

        return stdout.decode("utf-8", errors="replace")
|
|
|
|
|
|
def _build_prompt_from_messages(messages: list[dict]) -> str:
    """Flatten OpenAI/Ollama-style messages into a single prompt string.

    System messages are gathered, in order, under a leading ``[System]``
    header. Every other message becomes a ``User: ...`` or ``Assistant: ...``
    turn (any role other than ``"user"`` is labelled ``Assistant``), and a
    trailing ``Assistant:`` cue is appended for the model to continue from.

    Args:
        messages: Message dicts carrying optional ``role`` and ``content``
            keys. A missing role defaults to ``"user"``; missing or ``None``
            content is treated as empty text.

    Returns:
        The combined prompt text.
    """
    system_parts: list[str] = []
    turns: list[str] = []
    for message in messages:
        role = message.get("role", "user")
        content = message.get("content")
        # Normalise once so system and conversation messages are handled the
        # same way: absent/None -> "", anything else -> its string form.
        text = "" if content is None else str(content)
        if role == "system":
            system_parts.append(text)
            continue
        speaker = "User" if role == "user" else "Assistant"
        turns.append(f"{speaker}: {text}")
    turns.append("Assistant:")

    body = "\n\n".join(turns)
    if system_parts:
        return "[System]\n" + "\n\n".join(system_parts) + "\n\n" + body
    return body
|
|
|
|
|
|
# --- Streaming generators ---------------------------------------------------
|
|
|
|
async def _stream_generate(base: dict, text: str) -> AsyncIterator[bytes]:
    """Emit Ollama-style NDJSON for /api/generate: incremental chunks then done.

    Args:
        base: Fields copied into every frame.
        text: Complete response text, sliced into 64-character pieces.

    Yields:
        UTF-8 encoded JSON lines: one ``done: False`` frame per slice carrying
        it under ``response``, then a closing ``done: True`` frame with an
        empty ``response``, ``done_reason`` and ``total_duration``
        (nanoseconds spent emitting the frames).
    """
    chunk_size = 64
    # Monotonic clock: a wall-clock adjustment must not distort the duration.
    started = time.monotonic()
    for offset in range(0, len(text), chunk_size):
        frame = {
            **base,
            "response": text[offset:offset + chunk_size],
            "done": False,
        }
        yield (json.dumps(frame) + "\n").encode("utf-8")
        # Hand control back to the event loop between frames.
        await asyncio.sleep(0)
    final = {
        **base,
        "response": "",
        "done": True,
        "done_reason": "stop",
        "total_duration": int((time.monotonic() - started) * 1e9),
    }
    yield (json.dumps(final) + "\n").encode("utf-8")
|
|
|
|
|
|
async def _stream_chat(base: dict, text: str) -> AsyncIterator[bytes]:
    """Emit Ollama-style NDJSON for /api/chat: each frame carries a message.

    Args:
        base: Fields copied into every frame.
        text: Complete assistant reply, sliced into 64-character pieces.

    Yields:
        UTF-8 encoded JSON lines: one ``done: False`` frame per slice whose
        ``message`` is an assistant message holding that slice, then a closing
        ``done: True`` frame with empty message content, ``done_reason`` and
        ``total_duration`` (nanoseconds spent emitting the frames).
    """
    chunk_size = 64
    # Monotonic clock: a wall-clock adjustment must not distort the duration.
    started = time.monotonic()
    for offset in range(0, len(text), chunk_size):
        piece = text[offset:offset + chunk_size]
        frame = {
            **base,
            "message": {"role": "assistant", "content": piece},
            "done": False,
        }
        yield (json.dumps(frame) + "\n").encode("utf-8")
        # Hand control back to the event loop between frames.
        await asyncio.sleep(0)
    final = {
        **base,
        "message": {"role": "assistant", "content": ""},
        "done": True,
        "done_reason": "stop",
        "total_duration": int((time.monotonic() - started) * 1e9),
    }
    yield (json.dumps(final) + "\n").encode("utf-8")
|
|
|
|
|
|
# --- Routes -----------------------------------------------------------------
|
|
|
|
@app.get("/")
async def root() -> dict:
    """Health check: return a fixed status payload naming this service."""
    payload = dict(status="ok", service="claude-code-proxy")
    return payload
|
|
|
|
|
|
@app.get("/api/version")
async def version() -> dict:
    """Return the fixed version string this proxy reports to clients."""
    reported = "0.1.0-claude-proxy"
    return {"version": reported}
|
|
|
|
|
|
@app.get("/api/tags")
async def tags() -> dict:
    """Return an Ollama-style model list with one entry for the proxied model.

    Many clients call this endpoint to verify the server before using it.
    """
    details = {
        "parent_model": "",
        "format": "claude",
        "family": "claude",
        "families": ["claude"],
        "parameter_size": "unknown",
        "quantization_level": "none",
    }
    entry = {
        "name": DEFAULT_MODEL,
        "model": DEFAULT_MODEL,
        "modified_at": _now_iso(),
        "size": 0,
        "digest": "sha256:claude-code",
        "details": details,
    }
    return {"models": [entry]}
|
|
|
|
|
|
@app.post("/api/show")
async def show(req: Request) -> dict:
    """Return stub model details for the requested model name.

    The name is taken from the request's ``model`` field, falling back to the
    older ``name`` field and then to DEFAULT_MODEL. Because this endpoint is
    only a stub, an empty, malformed or non-object body is treated the same
    as an empty request instead of failing.

    Args:
        req: Incoming request; its JSON body may carry ``model`` or ``name``.

    Returns:
        A dict with ``modelfile``, ``parameters``, ``template`` and
        ``details`` entries.
    """
    try:
        body = await req.json()
    except ValueError:
        # json.JSONDecodeError is a ValueError: empty or invalid body.
        body = {}
    if not isinstance(body, dict):
        body = {}

    name = body.get("model") or body.get("name") or DEFAULT_MODEL
    return {
        "modelfile": f"FROM {name}",
        "parameters": "",
        "template": "",
        "details": {
            "format": "claude",
            "family": "claude",
            "parameter_size": "unknown",
            "quantization_level": "none",
        },
    }
|
|
|
|
|
|
@app.post("/api/generate")
async def generate(req: Request) -> Any:
    """Handle a single-shot prompt: run claude once and return its output.

    Reads ``model``, ``prompt``, ``system`` and ``stream`` from the JSON
    body. A non-empty ``system`` value is prepended under a ``[System]``
    header. The full output is obtained first; with ``stream`` true (the
    default) it is then replayed as NDJSON frames, otherwise it is returned
    as one JSON object.

    Args:
        req: Incoming request carrying the JSON body.

    Returns:
        A StreamingResponse, a plain dict, or a JSONResponse holding an
        ``error`` message (status 400 for an unusable body, 500 when the
        claude invocation fails).
    """
    try:
        body = await req.json()
    except ValueError:
        # json.JSONDecodeError is a ValueError: the body was not valid JSON.
        return JSONResponse(
            {"error": "request body must be valid JSON"}, status_code=400
        )
    if not isinstance(body, dict):
        return JSONResponse(
            {"error": "request body must be a JSON object"}, status_code=400
        )

    model = body.get("model", DEFAULT_MODEL)
    prompt = body.get("prompt", "")
    system = body.get("system")
    stream = bool(body.get("stream", True))

    full_prompt = f"[System]\n{system}\n\n{prompt}" if system else prompt

    # Monotonic clock: a wall-clock adjustment must not distort the duration.
    started = time.monotonic()
    try:
        text = await _run_claude(full_prompt)
    except Exception as e:
        # Request boundary: log the traceback and report the failure to the
        # client instead of letting it surface as an unhandled server error.
        log.exception("claude invocation failed")
        return JSONResponse({"error": str(e)}, status_code=500)

    base = {"model": model, "created_at": _now_iso()}
    if stream:
        return StreamingResponse(
            _stream_generate(base, text),
            media_type="application/x-ndjson",
        )
    return {
        **base,
        "response": text,
        "done": True,
        "done_reason": "stop",
        "total_duration": int((time.monotonic() - started) * 1e9),
    }
|
|
|
|
|
|
@app.post("/api/chat")
async def chat(req: Request) -> Any:
    """Handle a multi-message conversation: flatten it, run claude, reply.

    Reads ``model``, ``messages`` and ``stream`` from the JSON body. The
    messages are flattened into one prompt. The full reply is obtained
    first; with ``stream`` true (the default) it is then replayed as NDJSON
    frames, otherwise it is returned as one JSON object.

    Args:
        req: Incoming request carrying the JSON body.

    Returns:
        A StreamingResponse, a plain dict, or a JSONResponse holding an
        ``error`` message (status 400 for an unusable body or ``messages``
        value, 500 when the claude invocation fails).
    """
    try:
        body = await req.json()
    except ValueError:
        # json.JSONDecodeError is a ValueError: the body was not valid JSON.
        return JSONResponse(
            {"error": "request body must be valid JSON"}, status_code=400
        )
    if not isinstance(body, dict):
        return JSONResponse(
            {"error": "request body must be a JSON object"}, status_code=400
        )

    model = body.get("model", DEFAULT_MODEL)
    messages = body.get("messages", [])
    stream = bool(body.get("stream", True))

    # Prompt building calls .get() on every message, so reject anything that
    # is not a list of JSON objects before it can raise.
    if not isinstance(messages, list) or not all(
        isinstance(m, dict) for m in messages
    ):
        return JSONResponse(
            {"error": "'messages' must be a list of objects"}, status_code=400
        )

    prompt = _build_prompt_from_messages(messages)

    # Monotonic clock: a wall-clock adjustment must not distort the duration.
    started = time.monotonic()
    try:
        text = await _run_claude(prompt)
    except Exception as e:
        # Request boundary: log the traceback and report the failure to the
        # client instead of letting it surface as an unhandled server error.
        log.exception("claude invocation failed")
        return JSONResponse({"error": str(e)}, status_code=500)

    base = {"model": model, "created_at": _now_iso()}
    if stream:
        return StreamingResponse(
            _stream_chat(base, text),
            media_type="application/x-ndjson",
        )
    return {
        **base,
        "message": {"role": "assistant", "content": text},
        "done": True,
        "done_reason": "stop",
        "total_duration": int((time.monotonic() - started) * 1e9),
    }
|