# /// script
# requires-python = ">=3.13"
# dependencies = [
#     "fastapi",
#     "huggingface-hub",
#     "mediapipe",
#     "numpy",
#     "onnxruntime",
#     "pillow",
#     "python-multipart",
#     "uvicorn[standard]",
# ]
# ///
"""
顔写真バリデーション サブAPIサーバ (single script)
- POST /validate に画像を送ると、顔写真として適切かどうかを判定して返す
- 判定基準:
1. anime_real_cls: アニメ/イラストでないこと (real)
2. Face Detection: 顔がちょうど1つ検出されること
3. Face Landmarker: ランドマークが取得できること(検出の裏付け)
起動:
uv run server_single_script.py
uv run server_single_script.py --host 0.0.0.0 --port 8100 --workers 4
"""
from __future__ import annotations

import argparse
import io
import logging
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from pathlib import Path

import mediapipe as mp
import numpy as np
import onnxruntime as ort
from fastapi import FastAPI, File, HTTPException, UploadFile
from huggingface_hub import hf_hub_download
from PIL import Image, UnidentifiedImageError
logger = logging.getLogger("face-validator")
# ---------------------------------------------------------------------------
# Model management (one set per worker process)
# ---------------------------------------------------------------------------
BaseOptions = mp.tasks.BaseOptions
FaceDetector = mp.tasks.vision.FaceDetector
FaceDetectorOptions = mp.tasks.vision.FaceDetectorOptions
FaceLandmarker = mp.tasks.vision.FaceLandmarker
FaceLandmarkerOptions = mp.tasks.vision.FaceLandmarkerOptions
MODELS_DIR = Path("models")
def ensure_models() -> dict[str, Path]:
    """Download the required model files if missing and return their paths."""
    MODELS_DIR.mkdir(exist_ok=True)
    files: dict[str, Path] = {}
    # MediaPipe Face Detection (BlazeFace short-range)
    p = MODELS_DIR / "blaze_face_short_range.tflite"
    if not p.exists():
        import urllib.request

        urllib.request.urlretrieve(
            "https://storage.googleapis.com/mediapipe-models/face_detector/blaze_face_short_range/float16/latest/blaze_face_short_range.tflite",
            p,
        )
    files["face_detector"] = p
    # MediaPipe Face Landmarker
    p = MODELS_DIR / "face_landmarker.task"
    if not p.exists():
        import urllib.request

        urllib.request.urlretrieve(
            "https://storage.googleapis.com/mediapipe-models/face_landmarker/face_landmarker/float16/latest/face_landmarker.task",
            p,
        )
    files["face_landmarker"] = p
    # anime_real_cls (ONNX); hf_hub_download caches the file and returns its local path
    files["anime_real_cls"] = Path(
        hf_hub_download("deepghs/anime_real_cls", "mobilenetv3_v1.4_dist/model.onnx")
    )
    return files

@dataclass
class Models:
    face_detector: FaceDetector = field(repr=False)
    face_landmarker: FaceLandmarker = field(repr=False)
    anime_real_session: ort.InferenceSession = field(repr=False)
    anime_real_labels: list[str] = field(default_factory=lambda: ["anime", "real"])


_models: Models | None = None

def load_models() -> Models:
    files = ensure_models()
    detector = FaceDetector.create_from_options(
        FaceDetectorOptions(
            base_options=BaseOptions(model_asset_path=str(files["face_detector"])),
            min_detection_confidence=0.5,
        )
    )
    landmarker = FaceLandmarker.create_from_options(
        FaceLandmarkerOptions(
            base_options=BaseOptions(model_asset_path=str(files["face_landmarker"])),
            num_faces=5,
            min_face_detection_confidence=0.5,
            min_face_presence_confidence=0.5,
        )
    )
    anime_session = ort.InferenceSession(
        str(files["anime_real_cls"]),
        providers=["CPUExecutionProvider"],
    )
    return Models(
        face_detector=detector,
        face_landmarker=landmarker,
        anime_real_session=anime_session,
    )

# ---------------------------------------------------------------------------
# Inference logic
# ---------------------------------------------------------------------------
def classify_anime_real(
    session: ort.InferenceSession, pil_img: Image.Image
) -> dict[str, float]:
    """Return anime/real scores from the anime_real_cls model."""
    img = pil_img.resize((384, 384))
    arr = np.array(img, dtype=np.float32) / 255.0
    # HWC -> CHW, add batch dimension
    arr = np.transpose(arr, (2, 0, 1))[np.newaxis, ...]
    (output,) = session.run(None, {"input": arr})
    # softmax over the logits
    exp = np.exp(output[0] - np.max(output[0]))
    probs = exp / exp.sum()
    return {"anime": float(probs[0]), "real": float(probs[1])}

@dataclass
class ValidationResult:
    valid: bool
    reasons: list[str]
    anime_real: dict[str, float]
    face_count: int
    face_detection_confidence: float | None
    landmark_count: int | None

def validate_face_photo(models: Models, image_bytes: bytes) -> ValidationResult:
    """Take raw image bytes and return a face photo validation result."""
    reasons: list[str] = []
    # Decode the image
    pil_img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    rgb_array = np.array(pil_img)
    # 1) anime_real_cls
    ar_scores = classify_anime_real(models.anime_real_session, pil_img)
    is_real = ar_scores.get("real", 0) > ar_scores.get("anime", 0)
    if not is_real:
        reasons.append(
            f"anime_real_cls classified the image as an illustration (anime={ar_scores['anime']:.4f})"
        )
    # 2) Face Detection
    mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=rgb_array)
    det_result = models.face_detector.detect(mp_image)
    face_count = len(det_result.detections)
    det_confidence = None
    if face_count == 0:
        reasons.append("No face was detected")
    elif face_count > 1:
        reasons.append(f"Multiple faces were detected ({face_count})")
    else:
        det_confidence = det_result.detections[0].categories[0].score
    # 3) Face Landmarker
    lm_result = models.face_landmarker.detect(mp_image)
    lm_count = None
    if lm_result.face_landmarks:
        if len(lm_result.face_landmarks) == 1:
            lm_count = len(lm_result.face_landmarks[0])
        else:
            reasons.append(
                f"Face Landmarker detected multiple faces ({len(lm_result.face_landmarks)})"
            )
            lm_count = len(lm_result.face_landmarks[0])
    else:
        if face_count > 0:
            reasons.append("A face was detected but no landmarks could be extracted")
    return ValidationResult(
        valid=len(reasons) == 0,
        reasons=reasons,
        anime_real=ar_scores,
        face_count=face_count,
        face_detection_confidence=det_confidence,
        landmark_count=lm_count,
    )
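
# Minimal sketch of exercising the validator without the HTTP layer, e.g. for
# smoke-testing the model downloads ("photo.jpg" is a hypothetical local file):
#
#   models = load_models()
#   result = validate_face_photo(models, Path("photo.jpg").read_bytes())
#   print(result.valid, result.reasons)
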
# ---------------------------------------------------------------------------
# FastAPI
# ---------------------------------------------------------------------------
@asynccontextmanager
async def lifespan(app: FastAPI):
    global _models
    logger.info("Loading models...")
    _models = load_models()
    logger.info("Models loaded.")
    yield
    if _models:
        _models.face_detector.close()
        _models.face_landmarker.close()
        _models = None


app = FastAPI(title="Face Photo Validator", lifespan=lifespan)

@app.post("/validate")
async def validate(file: UploadFile = File(...)):
    assert _models is not None
    image_bytes = await file.read()
    try:
        result = validate_face_photo(_models, image_bytes)
    except UnidentifiedImageError:
        # Return 400 instead of crashing with a 500 when the upload is not an image
        raise HTTPException(status_code=400, detail="Uploaded file is not a valid image")
    return {
        "valid": result.valid,
        "reasons": result.reasons,
        "details": {
            "anime_real": result.anime_real,
            "face_count": result.face_count,
            "face_detection_confidence": result.face_detection_confidence,
            "landmark_count": result.landmark_count,
        },
    }

@app.get("/health")
async def health():
    return {"status": "ok", "models_loaded": _models is not None}
# ---------------------------------------------------------------------------
# entrypoint
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import uvicorn

    parser = argparse.ArgumentParser(description="Face Photo Validator API Server")
    parser.add_argument("--host", default="127.0.0.1")
    parser.add_argument("--port", type=int, default=8100)
    parser.add_argument("--workers", type=int, default=4)
    args = parser.parse_args()
    # uvicorn requires an import string ("module:attr") rather than the app
    # object when workers > 1, hence "server_single_script:app".
    uvicorn.run(
        "server_single_script:app",
        host=args.host,
        port=args.port,
        workers=args.workers,
    )