# /// script
# requires-python = ">=3.13"
# dependencies = [
#     "fastapi",
#     "huggingface-hub",
#     "mediapipe",
#     "numpy",
#     "onnxruntime",
#     "pillow",
#     "python-multipart",
#     "uvicorn[standard]",
# ]
# ///
| """ | |
| 顔写真バリデーション サブAPIサーバ (single script) | |
| - POST /validate に画像を送ると、顔写真として適切かどうかを判定して返す | |
| - 判定基準: | |
| 1. anime_real_cls: アニメ/イラストでないこと (real) | |
| 2. Face Detection: 顔がちょうど1つ検出されること | |
| 3. Face Landmarker: ランドマークが取得できること(検出の裏付け) | |
| 起動: | |
| uv run server_single_script.py | |
| uv run server_single_script.py --host 0.0.0.0 --port 8100 --workers 4 | |
| """ | |
from __future__ import annotations

import argparse
import io
import logging
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from pathlib import Path

import mediapipe as mp
import numpy as np
import onnxruntime as ort
from fastapi import FastAPI, File, UploadFile
from huggingface_hub import hf_hub_download
from PIL import Image

logger = logging.getLogger("face-validator")

# ---------------------------------------------------------------------------
# Model management (one instance per worker process)
# ---------------------------------------------------------------------------
BaseOptions = mp.tasks.BaseOptions
FaceDetector = mp.tasks.vision.FaceDetector
FaceDetectorOptions = mp.tasks.vision.FaceDetectorOptions
FaceLandmarker = mp.tasks.vision.FaceLandmarker
FaceLandmarkerOptions = mp.tasks.vision.FaceLandmarkerOptions

MODELS_DIR = Path("models")


def ensure_models() -> dict[str, Path]:
    """Download the required model files if they are missing and return their paths."""
    MODELS_DIR.mkdir(exist_ok=True)
    files: dict[str, Path] = {}
    # MediaPipe Face Detection
    p = MODELS_DIR / "blaze_face_short_range.tflite"
    if not p.exists():
        import urllib.request
        urllib.request.urlretrieve(
            "https://storage.googleapis.com/mediapipe-models/face_detector/blaze_face_short_range/float16/latest/blaze_face_short_range.tflite",
            p,
        )
    files["face_detector"] = p
    # MediaPipe Face Landmarker
    p = MODELS_DIR / "face_landmarker.task"
    if not p.exists():
        import urllib.request
        urllib.request.urlretrieve(
            "https://storage.googleapis.com/mediapipe-models/face_landmarker/face_landmarker/float16/latest/face_landmarker.task",
            p,
        )
    files["face_landmarker"] = p
    # anime_real_cls (ONNX)
    files["anime_real_cls"] = Path(
        hf_hub_download("deepghs/anime_real_cls", "mobilenetv3_v1.4_dist/model.onnx")
    )
    return files


@dataclass
class Models:
    face_detector: FaceDetector = field(repr=False)
    face_landmarker: FaceLandmarker = field(repr=False)
    anime_real_session: ort.InferenceSession = field(repr=False)
    anime_real_labels: list[str] = field(default_factory=lambda: ["anime", "real"])


_models: Models | None = None


def load_models() -> Models:
    """Build the MediaPipe detectors and the ONNX Runtime session."""
    files = ensure_models()
    detector = FaceDetector.create_from_options(
        FaceDetectorOptions(
            base_options=BaseOptions(model_asset_path=str(files["face_detector"])),
            min_detection_confidence=0.5,
        )
    )
    landmarker = FaceLandmarker.create_from_options(
        FaceLandmarkerOptions(
            base_options=BaseOptions(model_asset_path=str(files["face_landmarker"])),
            num_faces=5,
            min_face_detection_confidence=0.5,
            min_face_presence_confidence=0.5,
        )
    )
    anime_session = ort.InferenceSession(
        str(files["anime_real_cls"]),
        providers=["CPUExecutionProvider"],
    )
    return Models(
        face_detector=detector,
        face_landmarker=landmarker,
        anime_real_session=anime_session,
    )


# ---------------------------------------------------------------------------
# Inference logic
# ---------------------------------------------------------------------------
def classify_anime_real(
    session: ort.InferenceSession, pil_img: Image.Image
) -> dict[str, float]:
    """Return anime/real scores from anime_real_cls."""
    img = pil_img.resize((384, 384))
    arr = np.array(img, dtype=np.float32) / 255.0
    # HWC -> CHW, add batch
    arr = np.transpose(arr, (2, 0, 1))[np.newaxis, ...]
    (output,) = session.run(None, {"input": arr})
    # softmax
    exp = np.exp(output[0] - np.max(output[0]))
    probs = exp / exp.sum()
    return {"anime": float(probs[0]), "real": float(probs[1])}


@dataclass
class ValidationResult:
    valid: bool
    reasons: list[str]
    anime_real: dict[str, float]
    face_count: int
    face_detection_confidence: float | None
    landmark_count: int | None


def validate_face_photo(models: Models, image_bytes: bytes) -> ValidationResult:
    """Take raw image bytes and return the face photo validation result."""
    reasons: list[str] = []
    # Decode the image
    pil_img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    rgb_array = np.array(pil_img)
    # 1) anime_real_cls
    ar_scores = classify_anime_real(models.anime_real_session, pil_img)
    is_real = ar_scores.get("real", 0) > ar_scores.get("anime", 0)
    if not is_real:
        reasons.append(
            f"anime_real_cls classified the image as anime/illustration (anime={ar_scores['anime']:.4f})"
        )
    # 2) Face Detection
    mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=rgb_array)
    det_result = models.face_detector.detect(mp_image)
    face_count = len(det_result.detections)
    det_confidence = None
    if face_count == 0:
        reasons.append("No face was detected")
    elif face_count > 1:
        reasons.append(f"Multiple faces were detected ({face_count})")
    else:
        det_confidence = det_result.detections[0].categories[0].score
    # 3) Face Landmarker
    lm_result = models.face_landmarker.detect(mp_image)
    lm_count = None
    if lm_result.face_landmarks:
        if len(lm_result.face_landmarks) == 1:
            lm_count = len(lm_result.face_landmarks[0])
        else:
            reasons.append(
                f"Face Landmarker detected multiple faces ({len(lm_result.face_landmarks)})"
            )
            lm_count = len(lm_result.face_landmarks[0])
    else:
        if face_count > 0:
            reasons.append("A face was detected but landmarks could not be extracted")
    return ValidationResult(
        valid=len(reasons) == 0,
        reasons=reasons,
        anime_real=ar_scores,
        face_count=face_count,
        face_detection_confidence=det_confidence,
        landmark_count=lm_count,
    )


# ---------------------------------------------------------------------------
# FastAPI
# ---------------------------------------------------------------------------
@asynccontextmanager
async def lifespan(app: FastAPI):
    global _models
    logger.info("Loading models...")
    _models = load_models()
    logger.info("Models loaded.")
    yield
    if _models:
        _models.face_detector.close()
        _models.face_landmarker.close()
        _models = None


app = FastAPI(title="Face Photo Validator", lifespan=lifespan)


@app.post("/validate")
async def validate(file: UploadFile = File(...)):
    assert _models is not None
    image_bytes = await file.read()
    result = validate_face_photo(_models, image_bytes)
    return {
        "valid": result.valid,
        "reasons": result.reasons,
        "details": {
            "anime_real": result.anime_real,
            "face_count": result.face_count,
            "face_detection_confidence": result.face_detection_confidence,
            "landmark_count": result.landmark_count,
        },
    }


@app.get("/health")
async def health():
    return {"status": "ok", "models_loaded": _models is not None}


# ---------------------------------------------------------------------------
# entrypoint
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import uvicorn

    parser = argparse.ArgumentParser(description="Face Photo Validator API Server")
    parser.add_argument("--host", default="127.0.0.1")
    parser.add_argument("--port", type=int, default=8100)
    parser.add_argument("--workers", type=int, default=4)
    args = parser.parse_args()
    uvicorn.run(
        "server_single_script:app",
        host=args.host,
        port=args.port,
        workers=args.workers,
    )
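
For reference, a minimal client sketch for exercising the /validate endpoint. This is an illustration rather than part of the gist: it assumes the server is already running on 127.0.0.1:8100 (the script's default), that the requests package is installed on the client side, and that photo.jpg is a placeholder file name.

# Hypothetical client sketch: POST a local image to the running validator
# (assumes `requests` is installed and the server listens on 127.0.0.1:8100;
# photo.jpg is a placeholder).
import requests

with open("photo.jpg", "rb") as f:
    resp = requests.post(
        "http://127.0.0.1:8100/validate",
        files={"file": ("photo.jpg", f, "image/jpeg")},
    )
resp.raise_for_status()
data = resp.json()
print("valid:", data["valid"])
for reason in data["reasons"]:
    print("-", reason)
print("details:", data["details"])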