Skip to content

Instantly share code, notes, and snippets.

@sincerefly
Last active December 26, 2025 02:48
Show Gist options
  • Select an option

  • Save sincerefly/288808ea15d23db8c8f57896852dd465 to your computer and use it in GitHub Desktop.

Select an option

Save sincerefly/288808ea15d23db8c8f57896852dd465 to your computer and use it in GitHub Desktop.
一个使用 tanaos/tanaos-text-anonymizer-v1 模型搭配正则表达式共同识别的示例

Usage

# 创建环境
python3 -m venv venv
source venv/bin/activate

# 安装依赖
pip3 install torch transformers tabulate

# 运行实例
python3 main.py

Output

Type            Text                                      Span       Conf  Source
--------------  ----------------------------------------  -------  ------  --------
PERSON          John Smith                                23:33      0.86  Model
EMAIL           john.smith@company.com                    59:81      1     Regex
PHONE_NUMBER    +1-202-555-0199                           85:100     1     Regex
LOCATION        New York                                  116:124    0.8   Model
PERSON          Ren                                       140:143    0.98  Model
PERSON          ée                                        143:145    0.98  Model
LOCATION        Paris                                     155:160    0.97  Model
LOCATION        Los Angeles                               225:236    0.75  Model
DB_CONNECTION   postgres://admin:admin_passwd@postgre...  251:371    1     Regex
AWS_ACCESS_KEY  AKIAW18MM72SVWNPDCHD                      387:407    1     Regex
AWS_SECRET_KEY  DZmIeyI32ivK3nIOi2ic94LITNyAm0yrONCp1SNK  424:464    1     Regex
import re
from dataclasses import dataclass
from typing import List, Optional, Tuple
from transformers import pipeline
from tabulate import tabulate
# ============================================================
# 1. Configuration
# ============================================================

MODEL_NAME = "tanaos/tanaos-text-anonymizer-v1"
CONFIDENCE_THRESHOLD = 0.6

# Only trust the model's most stable labels.
ACCEPTED_MODEL_LABELS = {"PERSON", "LOCATION"}

# All spans are half-open intervals [start, end).
# Regex priority runs top to bottom; an earlier match "claims" its text region.
REGEX_PRIORITY_LIST: List[Tuple[str, re.Pattern]] = [
    ("DB_CONNECTION", re.compile(r"[a-zA-Z0-9+.-]+://\S+")),
    ("AWS_ACCESS_KEY", re.compile(r"\b(AKIA|ASIA|ABIA|ACCA)[0-9A-Z]{16}\b")),
    (
        "AWS_SECRET_KEY",
        re.compile(r"(?<![A-Za-z0-9/+=])[A-Za-z0-9/+=]{40}(?![A-Za-z0-9/+=])"),
    ),
    ("EMAIL", re.compile(r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}")),
    (
        "PHONE_NUMBER",
        re.compile(r"(?:\+?\d{1,3}[-.\s]?)?\(?\d{3}\)?[-.\s]?\d{3}[-.\s]?\d{4}"),
    ),
]
# ============================================================
# 2. Data Structures
# ============================================================

@dataclass(frozen=True)
class Entity:
    """One extracted entity (PII or secret) found in the analyzed text.

    Offsets are half-open [start, end) positions in the original string.
    """

    type: str     # entity label, e.g. "PERSON", "EMAIL", "DB_CONNECTION"
    text: str     # the matched substring
    start: int    # inclusive start offset
    end: int      # exclusive end offset
    score: float  # confidence (regex hits are reported as 1.0)
    source: str   # "Regex" | "Model"
# ============================================================
# 3. Pipeline Initialization (Lazy)
# ============================================================

# Module-level cache: the HF pipeline is constructed at most once.
_ner_pipeline = None


def get_ner_pipeline():
    """Return the token-classification pipeline, building it on first use."""
    global _ner_pipeline
    if _ner_pipeline is not None:
        return _ner_pipeline
    _ner_pipeline = pipeline(
        "token-classification",
        model=MODEL_NAME,
        aggregation_strategy="simple",
    )
    return _ner_pipeline
# ============================================================
# 4. Utility Functions
# ============================================================

def has_overlap(start: int, end: int, entities: List[Entity]) -> bool:
    """Return True if [start, end) intersects any existing entity span.

    All spans are half-open intervals, so two spans that merely touch
    (one ends where the other starts) do NOT overlap.
    """
    return any(
        max(start, ent.start) < min(end, ent.end) for ent in entities
    )
def refine_entity_boundaries(
    text: str, start: int, end: int
) -> Optional[Tuple[str, int, int]]:
    """Clean up a model-produced span.

    Strips common leading/trailing punctuation and whitespace from
    text[start:end] and recomputes the offsets to match the trimmed text.

    Returns (clean_text, new_start, new_end), or None when the span
    contained nothing but punctuation/whitespace.
    """
    raw = text[start:end]
    trimmed = raw.strip(" .,;?!-\n\"'<>")
    if not trimmed:
        # Nothing meaningful left — caller should drop this span.
        return None
    shift = raw.find(trimmed)
    new_start = start + shift
    return trimmed, new_start, new_start + len(trimmed)
# ============================================================
# 5. Extraction Logic
# ============================================================

def extract_regex_entities(text: str) -> List[Entity]:
    """Apply the prioritized regex list to *text* (deterministic, high confidence).

    Patterns run in REGEX_PRIORITY_LIST order; any match overlapping a span
    already claimed by an earlier pattern (or earlier match) is discarded.
    """
    found: List[Entity] = []
    for label, pattern in REGEX_PRIORITY_LIST:
        for match in pattern.finditer(text):
            span_start, span_end = match.span()
            # A higher-priority match already owns this region.
            if has_overlap(span_start, span_end, found):
                continue
            found.append(
                Entity(
                    type=label,
                    text=match.group(),
                    start=span_start,
                    end=span_end,
                    score=1.0,
                    source="Regex",
                )
            )
    return found
def extract_model_entities(
    text: str, existing_entities: List[Entity]
) -> List[Entity]:
    """Supplement extraction with the NER model (non-deterministic; strictly filtered).

    Keeps only labels in ACCEPTED_MODEL_LABELS with score at or above
    CONFIDENCE_THRESHOLD, trims span punctuation, and skips any span that
    overlaps *existing_entities* or an entity accepted earlier in this pass.
    """
    accepted: List[Entity] = []
    raw_results = get_ner_pipeline()(text)
    # Normalize: the pipeline may hand back a single dict instead of a list.
    if isinstance(raw_results, dict):
        raw_results = [raw_results]
    for result in raw_results:
        group = result["entity_group"]
        if group not in ACCEPTED_MODEL_LABELS:
            continue
        if result["score"] < CONFIDENCE_THRESHOLD:
            continue
        refined = refine_entity_boundaries(text, result["start"], result["end"])
        if not refined:
            continue
        cleaned, span_start, span_end = refined
        if has_overlap(span_start, span_end, existing_entities + accepted):
            continue
        accepted.append(
            Entity(
                type=group,
                text=cleaned,
                start=span_start,
                end=span_end,
                score=float(result["score"]),
                source="Model",
            )
        )
    return accepted
def extract_all_entities(text: str) -> List[Entity]:
    """Main entry point: deterministic rules first, model as a supplement."""
    rule_hits = extract_regex_entities(text)
    model_hits = extract_model_entities(text, rule_hits)
    # Present everything in document order.
    return sorted(rule_hits + model_hits, key=lambda ent: ent.start)
# ============================================================
# 6. Demo / CLI Usage
# ============================================================

if __name__ == "__main__":
    text_input = """
Hello, my name is John Smith.
You can reach me at john.smith@company.com or +1-202-555-0199.
I live in New York, and my friend Renée lives in Paris!
我的朋友张三、李四明年将要毕业,计划到北京工作,他们的老家在 Harbin,但是模型不能识别,那么他们的老家在 Los Angeles!
db_url: postgres://admin:admin_passwd@postgre-cluster.cluster-abcdefg.us-east-1.rds.amazonaws.com:4432/my_server?sslmode=disable
secret_id: AKIAW18MM72SVWNPDCHD
secret_key: DZmIeyI32ivK3nIOi2ic94LITNyAm0yrONCp1SNK
"""
    print("-" * 72)
    print("Extracted Entities:\n")

    entities = extract_all_entities(text_input)

    # One row per entity; overlong matches are elided to keep the table narrow.
    rows = [
        [
            ent.type,
            ent.text if len(ent.text) <= 40 else ent.text[:37] + "...",
            f"{ent.start}:{ent.end}",
            f"{ent.score:.2f}",
            ent.source,
        ]
        for ent in entities
    ]
    print(
        tabulate(
            rows,
            headers=["Type", "Text", "Span", "Conf", "Source"],
            tablefmt="simple",
        )
    )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment