import re
from dataclasses import dataclass
from typing import List, Optional, Tuple

from transformers import pipeline
from tabulate import tabulate


# ============================================================
# 1. Configuration
# ============================================================

MODEL_NAME = "tanaos/tanaos-text-anonymizer-v1"
CONFIDENCE_THRESHOLD = 0.6

# Only trust the model's most reliable labels
ACCEPTED_MODEL_LABELS = {"PERSON", "LOCATION"}

# All spans are half-open intervals [start, end).
# Regex patterns are applied in priority order, top to bottom; an earlier
# match "claims" its region of text and blocks lower-priority matches there.
REGEX_PRIORITY_LIST: List[Tuple[str, re.Pattern]] = [
    (
        "DB_CONNECTION",
        re.compile(r"[a-zA-Z0-9+.-]+://\S+"),
    ),
    (
        "AWS_ACCESS_KEY",
        re.compile(r"\b(AKIA|ASIA|ABIA|ACCA)[0-9A-Z]{16}\b"),
    ),
    (
        "AWS_SECRET_KEY",
        re.compile(r"(?<![A-Za-z0-9/+=])[A-Za-z0-9/+=]{40}(?![A-Za-z0-9/+=])"),
    ),
    (
        "EMAIL",
        re.compile(r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}"),
    ),
    (
        "PHONE_NUMBER",
        re.compile(
            r"(?:\+?\d{1,3}[-.\s]?)?\(?\d{3}\)?[-.\s]?\d{3}[-.\s]?\d{4}"
        ),
    ),
]
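# Note: DB_CONNECTION intentionally matches any URI scheme (including plain
# http/https URLs), and AWS_SECRET_KEY is only a heuristic: any standalone
# 40-character base64-like token matches, so occasional false positives on
# hashes or similar tokens of that length are expected.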
# ============================================================
# 2. Data Structures
# ============================================================

@dataclass(frozen=True)
class Entity:
    type: str
    text: str
    start: int
    end: int
    score: float
    source: str  # "Regex" | "Model"
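# Illustrative example (hypothetical values): a regex hit for "jane@example.org"
# at characters 12-28 would be stored as
# Entity(type="EMAIL", text="jane@example.org", start=12, end=28, score=1.0, source="Regex").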
# ============================================================
# 3. Pipeline Initialization (Lazy)
# ============================================================

_ner_pipeline = None


def get_ner_pipeline():
    """Lazily build and cache the Hugging Face NER pipeline on first use."""
    global _ner_pipeline
    if _ner_pipeline is None:
        _ner_pipeline = pipeline(
            "token-classification",
            model=MODEL_NAME,
            aggregation_strategy="simple",
        )
    return _ner_pipeline
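# The first call downloads MODEL_NAME from the Hugging Face Hub (then uses the
# local cache). With aggregation_strategy="simple", each prediction is a dict
# with an "entity_group" label, a "score", and character offsets "start"/"end"
# into the input string, which is what extract_model_entities reads below.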
# ============================================================
# 4. Utility Functions
# ============================================================

def has_overlap(start: int, end: int, entities: List[Entity]) -> bool:
    """Check whether the new span overlaps any existing entity (half-open intervals)."""
    for e in entities:
        if max(start, e.start) < min(end, e.end):
            return True
    return False
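# Because spans are half-open, touching spans do not overlap: for [0, 5) and
# [5, 9), max(0, 5) < min(5, 9) is 5 < 5, which is False.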
def refine_entity_boundaries(
    text: str, start: int, end: int
) -> Optional[Tuple[str, int, int]]:
    """
    Clean up a span returned by the model:
    - strip common leading/trailing punctuation
    - recompute start / end accordingly
    """
    raw_span = text[start:end]
    clean_span = raw_span.strip(" .,;?!-\n\"'<>")

    if not clean_span:
        return None

    offset = raw_span.find(clean_span)
    real_start = start + offset
    real_end = real_start + len(clean_span)

    return clean_span, real_start, real_end
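# Illustrative example (hypothetical input): with text = "Hi, John Smith." and
# a model span of start=4, end=15 (i.e. "John Smith."), the trailing period is
# stripped and the function returns ("John Smith", 4, 14).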
# ============================================================
# 5. Extraction Logic
# ============================================================

def extract_regex_entities(text: str) -> List[Entity]:
    """Run regex extraction in priority order (deterministic, high confidence)."""
    entities: List[Entity] = []

    for label, pattern in REGEX_PRIORITY_LIST:
        for m in pattern.finditer(text):
            start, end = m.start(), m.end()

            if has_overlap(start, end, entities):
                continue

            entities.append(
                Entity(
                    type=label,
                    text=m.group(),
                    start=start,
                    end=end,
                    score=1.0,
                    source="Regex",
                )
            )

    return entities

def extract_model_entities(
    text: str, existing_entities: List[Entity]
) -> List[Entity]:
    """Supplement the regex pass with model predictions (non-deterministic, strictly filtered)."""
    entities: List[Entity] = []
    ner = get_ner_pipeline()
    results = ner(text)
    if isinstance(results, dict):
        results = [results]

    for r in results:
        label = r["entity_group"]

        if label not in ACCEPTED_MODEL_LABELS:
            continue
        if r["score"] < CONFIDENCE_THRESHOLD:
            continue

        refined = refine_entity_boundaries(text, r["start"], r["end"])
        if not refined:
            continue

        clean_text, start, end = refined

        if has_overlap(start, end, existing_entities + entities):
            continue

        entities.append(
            Entity(
                type=label,
                text=clean_text,
                start=start,
                end=end,
                score=float(r["score"]),
                source="Model",
            )
        )

    return entities

def extract_all_entities(text: str) -> List[Entity]:
    """Main entry point: rules first, model as a supplement."""
    regex_entities = extract_regex_entities(text)
    model_entities = extract_model_entities(text, regex_entities)

    all_entities = regex_entities + model_entities
    all_entities.sort(key=lambda e: e.start)

    return all_entities

# ============================================================
# 6. Demo / CLI Usage
# ============================================================

if __name__ == "__main__":
    # Mixed-language sample: English PII, a Chinese sentence (friends 张三 and
    # 李四 graduating next year and planning to work in 北京, with hometowns
    # Harbin / Los Angeles), plus fake database and AWS credentials.
    text_input = """
    Hello, my name is John Smith.
    You can reach me at john.smith@company.com or +1-202-555-0199.
    I live in New York, and my friend Renée lives in Paris!

    我的朋友张三、李四明年将要毕业,计划到北京工作,他们的老家在 Harbin,但是模型不能识别,那么他们的老家在 Los Angeles!

    db_url: postgres://admin:admin_passwd@postgre-cluster.cluster-abcdefg.us-east-1.rds.amazonaws.com:4432/my_server?sslmode=disable
    secret_id: AKIAW18MM72SVWNPDCHD
    secret_key: DZmIeyI32ivK3nIOi2ic94LITNyAm0yrONCp1SNK
    """

    print("-" * 72)
    print("Extracted Entities:\n")

    entities = extract_all_entities(text_input)

    table = [
        [
            e.type,
            e.text if len(e.text) <= 40 else e.text[:37] + "...",
            f"{e.start}:{e.end}",
            f"{e.score:.2f}",
            e.source,
        ]
        for e in entities
    ]

    print(
        tabulate(
            table,
            headers=["Type", "Text", "Span", "Conf", "Source"],
            tablefmt="simple",
        )
    )