Skip to content

Instantly share code, notes, and snippets.

@zenz
Created June 26, 2025 10:06
Show Gist options
  • Select an option

  • Save zenz/22adb486ba64f7d4c4e5b540009d7d15 to your computer and use it in GitHub Desktop.

Select an option

Save zenz/22adb486ba64f7d4c4e5b540009d7d15 to your computer and use it in GitHub Desktop.
Online learning of house-heating control (supply water temperature) with the River framework.
import binascii
import getopt
import json
import os
import pickle
import signal
import socket
import struct
import sys
import time
import zlib
from itertools import cycle

from river import compose, linear_model, preprocessing
# ========== Configuration constants ==========
# Multicast group and default UDP port the device traffic arrives on.
MCAST_GRP: str = "224.0.1.3"
MCAST_PORT: int = 4211
# Maximum number of bytes read per recvfrom() call.
BUFFER_SIZE: int = 512
# struct format of the ip_mreq passed to IP_ADD_MEMBERSHIP:
# 4-byte group address followed by the interface as a standard-size long.
PACK_FORMAT: str = "=4sl"
# Pickle file the model is loaded from and saved to.
MODEL_PATH: str = "model.pkl"
# Value placed in the "dev" field of outgoing commands.
MY_NAME: str = "aircube-learn"
# Minimum number of seconds between two control/learning steps.
INTERVAL: int = 300
# Allowed range for tct. TCT_MIN is also where the search starts;
# recommendations below it are sent as tct 0.
TCT_MIN: int = 35
TCT_MAX: int = 80

# ========== Mutable module state ==========
# UDP socket created by main(); closed by the signal handler.
SOCK = None
# time.time() of the last control step (0 means it never ran).
last_eval_time = 0
# Feature dict of the previous step; used as the training input once the
# resulting room temperature has been observed.
last_input = None
last_target_room_temp = None
# cct values of the previous two steps (lag features).
prev_cct_1 = None
prev_cct_2 = None
# Last tct value sent to the device.
last_sent_tct = None
# Room temperature seen at the previous step (for delta_crt).
last_crt = None
# River model in use; assigned by main().
model = None
# ========== Helper functions ==========
def xor_crypt(a: str, b: str):
    """XOR every character of *a* with the repeating key *b*.

    XOR is its own inverse, so the same call both encrypts and decrypts.
    Returns "" when either argument is empty.
    """
    return "".join(chr(ord(x) ^ ord(y)) for x, y in zip(a, cycle(b)))


def unpack_data(data: bytes, secret: str):
    """Parse and decrypt one datagram in the layout written by ``pack_data``.

    Layout: byte 0 message type, byte 1 payload length, 2 padding bytes,
    4-byte little-endian CRC32 of the encrypted payload, then the
    XOR-encrypted ASCII payload.

    Returns:
        ``(msgtype, datalen, decrypted, crc1, crc2)``. ``crc1`` is the CRC
        carried in the packet and ``crc2`` the CRC computed over the received
        payload. Both are 8-digit lowercase hex strings, so callers can
        compare them with ``==``. Packets shorter than the 8-byte header yield
        ``(0, 0, b"", "", "")``; note that both CRC strings are then equal.

    Raises:
        UnicodeDecodeError / UnicodeEncodeError if the payload or the key
        is not ASCII.
    """
    if len(data) < 8:
        return 0, 0, b"", "", ""
    msgtype, datalen = struct.unpack("BB2x", data[:4])
    crc1 = binascii.hexlify(data[4:8][::-1]).decode()
    realdata = data[8:8 + datalen]
    # Zero-pad to 8 digits like crc1. hex() drops leading zeros, so any CRC
    # below 0x10000000 (about 1 packet in 16) would never match.
    crc2 = f"{zlib.crc32(realdata):08x}"
    decrypted = xor_crypt(realdata.decode("ascii"), secret).encode("ascii")
    return msgtype, datalen, decrypted, crc1, crc2


def pack_data(msgtype: int, message: str, secret: str):
    """Encrypt *message* with *secret* and frame it for ``unpack_data``.

    Raises:
        struct.error if *msgtype* or the encrypted length exceeds 255
        (both are one-byte header fields).
        UnicodeEncodeError if the encrypted text is not ASCII.
    """
    encrypted = xor_crypt(message, secret).encode("ascii")
    crc = zlib.crc32(encrypted).to_bytes(4, "little")
    # The length field describes the bytes that actually follow the header.
    header = struct.pack("BB2x", msgtype, len(encrypted))
    return header + crc + encrypted
def load_or_init_model(path=None):
    """Load the pickled model, or build a fresh River pipeline if that fails.

    Args:
        path: Pickle file to read; defaults to ``MODEL_PATH`` (resolved at
            call time).

    Returns:
        The unpickled object, or a new
        ``Pipeline(StandardScaler(), LinearRegression())`` when the file is
        missing or cannot be unpickled.
    """
    path = MODEL_PATH if path is None else path
    try:
        # NOTE(security): unpickling can execute arbitrary code; only load
        # model files written by this program.
        with open(path, "rb") as f:
            loaded = pickle.load(f)
    except FileNotFoundError:
        print("⚠️ 未找到模型,将新建")
    except Exception as e:
        # Report a corrupt/incompatible file separately from a missing one.
        print(f"⚠️ 模型加载失败({e}),将新建")
    else:
        # Announce success only after unpickling actually worked.
        print(f"✅ 已加载模型: {path}")
        return loaded
    return compose.Pipeline(
        preprocessing.StandardScaler(),
        linear_model.LinearRegression()
    )
def save_model(model, path=None):
    """Pickle *model* to disk atomically; report success or failure.

    The data is written to ``<path>.tmp``, flushed to disk, and then moved
    over *path* with ``os.replace``. A failed or interrupted save therefore
    leaves the previous model file intact instead of a truncated pickle.
    Errors are printed, not raised (best-effort, e.g. from the shutdown
    handler).

    Args:
        model: Object to pickle.
        path: Destination file; defaults to ``MODEL_PATH`` (resolved at call
            time).
    """
    path = MODEL_PATH if path is None else path
    tmp_path = f"{path}.tmp"
    try:
        with open(tmp_path, "wb") as f:
            pickle.dump(model, f)
            f.flush()
            os.fsync(f.fileno())
        os.replace(tmp_path, path)
        print("💾 模型已保存")
    except Exception as e:
        print(f"❌ 模型保存失败: {e}")
        # Do not leave a partial temp file behind.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
def graceful_shutdown(signum, frame):
    """Signal handler: close the socket, save the model, and exit with status 0.

    *signum* and *frame* are the standard ``signal.signal`` handler arguments
    and are not used.
    """
    active_sock = SOCK
    if active_sock:
        active_sock.close()
    print("\n🛑 程序中断,正在保存模型...")
    save_model(model)
    raise SystemExit(0)
# ========== UDP socket setup ==========
def init_socket(port):
    """Create a UDP socket bound to *port* and joined to ``MCAST_GRP``.

    The socket binds on all interfaces, joins the multicast group on the
    default interface (``INADDR_ANY``), and sets a multicast TTL of 10. If any
    setup step fails, the socket is closed before the exception propagates,
    so the descriptor does not leak.

    Args:
        port: UDP port to bind.

    Returns:
        The configured ``socket.socket``.
    """
    sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP)
    try:
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        # SO_REUSEPORT is not defined on every platform (e.g. Windows).
        if hasattr(socket, "SO_REUSEPORT"):
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        sock.bind(("0.0.0.0", port))
        mreq = struct.pack(PACK_FORMAT, socket.inet_aton(MCAST_GRP), socket.INADDR_ANY)
        sock.setsockopt(socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP, mreq)
        sock.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, 10)
    except BaseException:
        sock.close()
        raise
    return sock
# ========== Prediction ==========
def get_best_tct(features, trt, max_limit):
    """Pick the tct whose predicted room temperature best matches *trt*.

    Each candidate in ``range(TCT_MIN, max_limit + 1)`` is scored as
    ``abs(prediction - trt) + 0.05 * candidate``; the small penalty favours
    lower tct. Ties go to the lower candidate. Returns ``None`` when no
    candidate produced a prediction (or the range is empty).
    """
    chosen = None
    lowest_cost = float("inf")
    for candidate in range(TCT_MIN, max_limit + 1):
        predicted_crt = model.predict_one({**features, "tct": candidate})
        if predicted_crt is not None:
            cost = abs(predicted_crt - trt) + 0.05 * candidate
            if cost < lowest_cost:
                chosen, lowest_cost = candidate, cost
    return chosen
# ========== Main processing step ==========
def process_packet(payload, dev, secret, addr, port, trt, max_limit):
    """Run one control/learning step for a status packet from *dev*.

    Does nothing if fewer than ``INTERVAL`` seconds have passed since the
    last step. Otherwise it:

    1. trains the model on the previous step's features (including the tct
       commanded for that interval) against the room temperature seen now;
    2. builds the current feature dict and picks a recommended tct
       (0 = pause heating when ``crt >= trt``);
    3. sends a command when the recommendation changed, or when a boundary
       recommendation (``TCT_MIN``/``TCT_MAX``) differs from the device's
       reported tct.

    Args:
        payload: Decoded status dict; reads ``odt``, ``crt``, ``cct``,
            ``tct``, ``tcm`` (missing keys default to 0).
        dev: Device name, sent as the command's ``tar`` field.
        secret: Key used to encrypt the outgoing command.
        addr: Sender address from ``recvfrom``; the command goes to ``addr[0]``.
        port: UDP port the command is sent to.
        trt: Target room temperature.
        max_limit: Upper bound for the recommended tct.
    """
    global last_eval_time, last_input, last_target_room_temp
    global prev_cct_1, prev_cct_2, last_sent_tct, last_crt

    odt = int(payload.get("odt", 0))
    crt = round(float(payload.get("crt", 0)), 1)
    cct = int(payload.get("cct", 0))
    tct = int(payload.get("tct", 0))
    tcm = int(payload.get("tcm", 0))
    now = time.time()
    if now - last_eval_time < INTERVAL:
        return
    print(f"\n📥 odt={odt}, crt={crt}, trt={trt}, cct={cct}")

    # Supervised update: the previous step's inputs led to the room
    # temperature observed now. Use `is not None` so trt == 0 still learns.
    if last_input is not None and last_target_room_temp is not None:
        model.learn_one(last_input, crt)

    delta_crt = crt - last_crt if last_crt is not None else 0
    last_crt = crt
    features = {
        "odt": odt,
        "crt": crt,
        "trt": trt,
        "cct": cct,
        "cct_prev1": prev_cct_1 or 0,
        "cct_prev2": prev_cct_2 or 0,
        "delta_crt": delta_crt,
    }

    if crt >= trt:
        print(f"🌡️ 当前室温 {crt}°C ≥ 目标 {trt}°C,暂停加热")
        best_tct = 0
    else:
        best_tct = get_best_tct(features, trt, max_limit)
        # Only None means "no prediction"; 0 is a legitimate (pause) value.
        print(f"🌟 推荐 tct ≈ {best_tct}°C" if best_tct is not None else "⚠️ 模型未准备好")

    if best_tct is None:
        # Nothing to recommend: leave the device as it is and record the
        # setpoint it reports as the one in force.
        applied_tct = tct
    else:
        should_send = (best_tct != last_sent_tct) or (
            best_tct in (TCT_MIN, TCT_MAX) and tct != best_tct
        )
        if should_send:
            heating_on = best_tct >= TCT_MIN
            command = {
                "dev": MY_NAME, "tar": dev, "pwr": 5,
                "tct": best_tct if heating_on else 0,
            }
            expected_tcm = 1 if heating_on else 0
            if tcm != expected_tcm:
                command["tcm"] = expected_tcm
            msg = json.dumps(command, separators=(',', ':'))
            SOCK.sendto(pack_data(1, msg, secret), (addr[0], port))
            print(f"📤 指令发送至 {addr[0]}:{port} → {msg}")
            last_sent_tct = best_tct
        else:
            print(f"ℹ️ 推荐温度未变(tct={best_tct}),未发送")
        applied_tct = best_tct

    # Store the inputs *including* the tct commanded for the coming interval.
    # get_best_tct varies "tct"; if it is never part of a training sample,
    # the model cannot learn how tct affects room temperature.
    last_input = {**features, "tct": applied_tct}
    last_target_room_temp = trt
    prev_cct_2, prev_cct_1 = prev_cct_1, cct
    last_eval_time = now
# ========== Entry point ==========
def _parse_args(argv):
    """Parse command-line options.

    Returns:
        ``(port, dev, sec_key, trt, max_limit)`` on success, or ``None`` when
        the program should stop (help requested or invalid arguments). In that
        case a message has already been printed.
    """
    usage = "Usage: learning.py [-p <port>] -d <device> -s <secret> [-t <target>] [-m <max>]"
    dev, sec_key, trt, max_limit = "", "", 18.0, 60
    port = None
    try:
        # "h" takes no argument, so a bare -h prints usage instead of
        # raising GetoptError.
        opts, _ = getopt.getopt(
            argv, "hp:d:s:t:m:",
            ["help", "port=", "device=", "secret=", "target=", "max="],
        )
        for opt, arg in opts:
            if opt in ("-h", "--help"):
                print(usage)
                return None
            if opt in ("-p", "--port"):
                port = int(arg)
            elif opt in ("-d", "--device"):
                dev = arg
            elif opt in ("-s", "--secret"):
                sec_key = arg
            elif opt in ("-t", "--target"):
                trt = float(arg)
            elif opt in ("-m", "--max"):
                max_limit = int(arg)
    except (getopt.GetoptError, ValueError) as e:
        print(f"❌ 参数错误: {e}")
        print(usage)
        return None
    if not dev or not sec_key:
        print("❌ 必须指定设备和秘钥")
        return None
    if not (TCT_MIN <= max_limit <= TCT_MAX):
        print(f"❌ 最大供水温度需在 {TCT_MIN}°C ~ {TCT_MAX}°C 之间")
        return None
    if port is None:
        port = MCAST_PORT
    return port, dev, sec_key, trt, max_limit


def _decode_packet(data, dev, secret):
    """Return the JSON status dict of a valid packet from *dev*, else None.

    Rejects empty datagrams, CRC mismatches, and packets whose JSON ``"dev"``
    field is not exactly *dev*. A plain substring test would also accept
    e.g. commands addressed to the device (``"tar": dev``) from other
    controllers. NOTE(review): assumes the device's status packets carry
    ``"dev": <name>``, as the previous ``"dev":"<dev>",`` stripping implied.

    Raises:
        json.JSONDecodeError (and errors from ``unpack_data``) on malformed data.
    """
    if not data:
        return None
    _, _, realdata, crc1, crc2 = unpack_data(data, secret)
    if crc1 != crc2:
        return None
    decoded = realdata.decode("ascii", errors="ignore")
    if dev not in decoded:  # cheap pre-filter before JSON parsing
        return None
    payload = json.loads(decoded)
    if not isinstance(payload, dict) or payload.get("dev") != dev:
        return None
    return payload


def main(argv):
    """Parse *argv*, then receive packets forever and run control steps.

    SIGINT and SIGTERM are both routed to ``graceful_shutdown`` so the model
    is saved on exit. Per-packet errors are printed, and listening continues.
    """
    global SOCK, model
    args = _parse_args(argv)
    if args is None:
        return
    port, dev, sec_key, trt, max_limit = args
    print(f"🚀 启动监听 → dev={dev}, port={port}, trt={trt}, max={max_limit}")
    SOCK = init_socket(port)
    model = load_or_init_model()
    signal.signal(signal.SIGINT, graceful_shutdown)
    # Without this, `kill` / service managers would stop the program
    # without saving the model.
    signal.signal(signal.SIGTERM, graceful_shutdown)
    while True:
        try:
            data, addr = SOCK.recvfrom(BUFFER_SIZE)
            payload = _decode_packet(data, dev, sec_key)
            if payload is not None:
                process_packet(payload, dev, sec_key, addr, port, trt, max_limit)
        except Exception as e:
            # Top-level boundary: report and keep listening.
            print(f"❌ 异常: {e}")
            time.sleep(0.1)


if __name__ == "__main__":
    main(sys.argv[1:])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment