Created
June 26, 2025 10:06
-
-
Save zenz/22adb486ba64f7d4c4e5b540009d7d15 to your computer and use it in GitHub Desktop.
Example of using the River online machine-learning library to learn a house-heating controller (choosing the water supply temperature).
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import socket | |
| import time | |
| import struct | |
| import zlib | |
| import signal | |
| import binascii | |
| import json | |
| import pickle | |
| import sys | |
| import getopt | |
| from itertools import cycle | |
| from river import linear_model, preprocessing, compose | |
# ========== Configuration constants ==========
# Multicast group joined by init_socket(), and the struct layout of the
# membership request handed to IP_ADD_MEMBERSHIP.
MCAST_GRP = "224.0.1.3"
PACK_FORMAT = "=4sl"
# Default UDP port (overridable with -p) and max bytes read per recvfrom().
MCAST_PORT = 4211
BUFFER_SIZE = 512
# File the learned model is pickled to / loaded from.
MODEL_PATH = "model.pkl"
# Sender name placed in the "dev" field of outgoing commands.
MY_NAME = "aircube-learn"
# Minimum number of seconds between two evaluations in process_packet().
INTERVAL = 300
# Bounds (°C) for tct, the water-supply temperature chosen by get_best_tct().
TCT_MIN = 35
TCT_MAX = 80

# ========== Global state ==========
SOCK = None  # UDP socket, created in main()
model = None  # learning model, loaded in main()
last_eval_time = 0  # time.time() of the most recent evaluation
# Remembered between evaluations by process_packet().
last_input = None
last_target_room_temp = None
last_sent_tct = None
last_crt = None
prev_cct_1 = prev_cct_2 = None
# ========== Utility functions ==========
def xor_crypt(a: str, b: str):
    """XOR every character of ``a`` with the repeating key ``b``.

    XOR is its own inverse, so the same call both encrypts and decrypts.
    The result has len(a) characters, except that an empty key yields "".
    """
    key_stream = cycle(b)
    pieces = []
    for ch, k in zip(a, key_stream):
        pieces.append(chr(ord(ch) ^ ord(k)))
    return "".join(pieces)
def unpack_data(data: bytes, secret: str):
    """Parse and decrypt one datagram.

    Layout read below: byte 0 = message type, byte 1 = payload length,
    bytes 2-3 = padding, bytes 4-7 = CRC-32 of the *encrypted* payload stored
    little-endian, bytes 8.. = XOR-encrypted ASCII payload.

    Returns ``(msgtype, datalen, decrypted_bytes, crc_received, crc_computed)``.
    Both CRCs are 8-digit lowercase hex strings, so callers can compare them
    with ``==``. Inputs shorter than the 8-byte header return
    ``(0, 0, b"", "", "")``.

    Raises UnicodeDecodeError if the payload bytes are not ASCII.
    """
    if len(data) < 8:
        return 0, 0, b"", "", ""
    msgtype, datalen = struct.unpack("BB2x", data[:4])
    # Stored little-endian; reversing yields the conventional big-endian hex.
    crc1 = binascii.hexlify(data[4:8][::-1]).decode()
    realdata = data[8:8 + datalen]
    # Zero-pad to 8 digits to match crc1: hex() drops leading zeros, which made
    # every valid packet whose CRC is < 0x10000000 (~1 in 16) fail the check.
    crc2 = f"{zlib.crc32(realdata):08x}"
    decrypted = xor_crypt(realdata.decode("ascii"), secret).encode("ascii")
    return msgtype, datalen, decrypted, crc1, crc2
def pack_data(msgtype: int, message: str, secret: str):
    """Encrypt ``message`` with ``secret`` and frame it as a datagram.

    Layout: [msgtype: u8][len(message): u8][2 zero pad bytes]
    [CRC-32 of the encrypted payload, little-endian u32][encrypted ASCII payload].

    Raises UnicodeEncodeError if the encrypted text is not ASCII, and
    struct.error if msgtype or len(message) does not fit in one byte.
    """
    payload = xor_crypt(message, secret).encode("ascii")
    checksum = zlib.crc32(payload).to_bytes(4, "little")
    return b"".join((
        struct.pack("BB2x", msgtype, len(message)),
        checksum,
        payload,
    ))
def load_or_init_model(path=None):
    """Return the model pickled at ``path`` (default: MODEL_PATH), or a fresh one.

    A new river pipeline (StandardScaler followed by LinearRegression) is
    returned when the file does not exist or cannot be loaded. A load failure
    is reported with its error instead of being reported as "not found".

    NOTE: pickle.load can execute arbitrary code. Only load model files this
    program wrote itself, never files from an untrusted source.
    """
    model_path = MODEL_PATH if path is None else path
    try:
        with open(model_path, "rb") as f:
            loaded = pickle.load(f)
    except FileNotFoundError:
        print("⚠️ 未找到模型,将新建")
    except Exception as e:
        # Corrupt/incompatible/unreadable file: say so rather than "not found".
        print(f"⚠️ 模型加载失败({e}),将新建")
    else:
        # Announce success only once unpickling has actually worked.
        print(f"✅ 已加载模型: {model_path}")
        return loaded
    return compose.Pipeline(
        preprocessing.StandardScaler(),
        linear_model.LinearRegression(),
    )
def save_model(model, path=None):
    """Pickle ``model`` to ``path`` (default: MODEL_PATH), replacing it atomically.

    The data is written to ``<path>.tmp`` first and then moved over the target
    with os.replace(). An interruption mid-write therefore cannot leave a
    truncated model file behind; load_or_init_model() would fail to load such
    a file and start again from an untrained model.

    Errors are printed and swallowed on purpose: graceful_shutdown() calls
    this right before exiting and must still get to exit.
    """
    import os

    target = MODEL_PATH if path is None else path
    tmp_path = f"{target}.tmp"
    try:
        with open(tmp_path, "wb") as f:
            pickle.dump(model, f)
        os.replace(tmp_path, target)
        print("💾 模型已保存")
    except Exception as e:
        print(f"❌ 模型保存失败: {e}")
        # Best effort: don't leave a half-written temp file lying around.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
def graceful_shutdown(signum, frame):
    """Signal handler: close the socket (if any), save the model, exit with status 0.

    ``signum`` and ``frame`` are the standard signal-handler arguments and are
    unused. sys.exit raises SystemExit, which is not an Exception subclass, so
    the main loop's ``except Exception`` does not swallow it.
    """
    open_sock = SOCK
    if open_sock:
        open_sock.close()
    print("\n🛑 程序中断,正在保存模型...")
    save_model(model)
    sys.exit(0)
# ========== UDP setup ==========
def init_socket(port):
    """Create a UDP socket bound to ``port`` on all interfaces and joined to MCAST_GRP.

    Also sets address reuse (and port reuse where the platform supports it)
    and a multicast TTL of 10. If any setup step fails, the socket is closed
    and the error is re-raised, so no file descriptor leaks.
    """
    sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP)
    try:
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        # SO_REUSEPORT is platform-dependent (absent e.g. on Windows); the
        # unguarded access raised AttributeError there.
        if hasattr(socket, "SO_REUSEPORT"):
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        sock.bind(("0.0.0.0", port))
        # ip_mreq: multicast group address + local interface (any).
        mreq = struct.pack(PACK_FORMAT, socket.inet_aton(MCAST_GRP), socket.INADDR_ANY)
        sock.setsockopt(socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP, mreq)
        sock.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, 10)
    except BaseException:
        sock.close()
        raise
    return sock
# ========== Prediction ==========
def get_best_tct(features, trt, max_limit):
    """Pick the tct whose predicted outcome best matches the target ``trt``.

    Every integer tct in [TCT_MIN, max_limit] is tried on a copy of
    ``features`` and scored as ``|model prediction - trt| + 0.05 * tct``. The
    small tct term favours lower temperatures when predictions are equally
    close. Candidates the model returns None for are skipped. Returns the
    lowest-scoring tct (the smallest one on ties), or None if nothing was scored.
    """
    def score_of(candidate):
        predicted = model.predict_one({**features, "tct": candidate})
        if predicted is None:
            return None
        return abs(predicted - trt) + 0.05 * candidate

    winner = None
    lowest = float("inf")
    for candidate in range(TCT_MIN, max_limit + 1):
        score = score_of(candidate)
        if score is not None and score < lowest:
            winner, lowest = candidate, score
    return winner
# ========== Main processing loop ==========
def process_packet(payload, dev, secret, addr, port, trt, max_limit):
    """Run one control step for a decoded packet from the device.

    Runs at most once every INTERVAL seconds; otherwise it returns immediately.
    A step does the following:
      1. Trains the global ``model`` on the previous step's features, which
         include the tct that was commanded, against the ``crt`` reported now.
      2. Picks a new tct with get_best_tct(), or 0 when ``crt`` already
         reaches ``trt``.
      3. Sends a command to ``addr[0]:port`` when the choice changed. When the
         choice is TCT_MIN or TCT_MAX, it also sends if the device reports a
         different tct.
      4. Remembers this step's features for the next round.

    Args:
        payload: decoded packet dict; "odt", "crt", "cct", "tct", "tcm" are
            read, missing keys default to 0.
        dev: device name placed in the command's "tar" field.
        secret: XOR key passed to pack_data().
        addr: sender address from recvfrom(); only addr[0] is used.
        port: UDP port the command is sent to.
        trt: target room temperature.
        max_limit: upper bound of the tct search.
    """
    global last_eval_time, last_input, last_target_room_temp
    global prev_cct_1, prev_cct_2, last_sent_tct, last_crt

    odt = int(payload.get("odt", 0))
    crt = round(float(payload.get("crt", 0)), 1)
    cct = int(payload.get("cct", 0))
    tct = int(payload.get("tct", 0))
    tcm = int(payload.get("tcm", 0))

    now = time.time()
    if now - last_eval_time < INTERVAL:
        return
    print(f"\n📥 odt={odt}, crt={crt}, trt={trt}, cct={cct}")

    # Learn: previous step's inputs -> room temperature observed now. The
    # sample is consumed so that, if something below raises, the next packet
    # does not train on the same sample a second time.
    if last_input is not None and last_target_room_temp is not None:
        model.learn_one(last_input, crt)
        last_input = None

    delta_crt = crt - last_crt if last_crt is not None else 0
    features = {
        "odt": odt,
        "crt": crt,
        "trt": trt,
        "cct": cct,
        "cct_prev1": prev_cct_1 or 0,
        "cct_prev2": prev_cct_2 or 0,
        "delta_crt": delta_crt,
    }

    if crt >= trt:
        print(f"🌡️ 当前室温 {crt}°C ≥ 目标 {trt}°C,暂停加热")
        best_tct = 0
    else:
        best_tct = get_best_tct(features, trt, max_limit)
        # Only meaningful here: a paused step (best_tct == 0) is not "not ready".
        print(f"🌟 推荐 tct ≈ {best_tct}°C" if best_tct else "⚠️ 模型未准备好")

    if best_tct is None:
        # No usable prediction: send nothing (comparing None with TCT_MIN would
        # raise TypeError). Record the tct the device itself reports.
        applied_tct = tct
    else:
        # 0 means "heating off"; the same value goes into the command.
        applied_tct = best_tct if best_tct >= TCT_MIN else 0
        should_send = (best_tct != last_sent_tct) or (best_tct in (TCT_MIN, TCT_MAX) and tct != best_tct)
        if should_send:
            command = {
                "dev": MY_NAME, "tar": dev, "pwr": 5,
                "tct": applied_tct,
            }
            expected_tcm = 1 if best_tct >= TCT_MIN else 0
            if tcm != expected_tcm:
                command["tcm"] = expected_tcm
            msg = json.dumps(command, separators=(',', ':'))
            SOCK.sendto(pack_data(1, msg, secret), (addr[0], port))
            print(f"📤 指令发送至 {addr[0]}:{port} → {msg}")
            last_sent_tct = best_tct
        else:
            print(f"ℹ️ 推荐温度未变(tct={best_tct}),未发送")

    # Commit this step. "tct" must be part of the training sample: without it
    # the model never learns how tct affects the next crt, and get_best_tct's
    # search degenerates to always returning TCT_MIN. (Assumes the device
    # follows the last command sent.)
    last_input = {**features, "tct": applied_tct}
    last_target_room_temp = trt
    last_crt = crt
    prev_cct_2, prev_cct_1 = prev_cct_1, cct
    last_eval_time = now
# ========== Main function ==========
def main(argv):
    """Parse command-line options, then receive and handle packets forever.

    Options:
        -d/--device <name>   device whose packets are handled (required)
        -s/--secret <key>    XOR key for decoding packets / encoding commands (required)
        -p/--port <port>     UDP port to listen on and send commands to (default MCAST_PORT)
        -t/--target <temp>   target room temperature (default 18.0)
        -m/--max <temp>      upper bound of the tct search (default 60);
                             must lie within [TCT_MIN, TCT_MAX]
        -h/--help            print usage and return

    Unknown options or non-numeric values print an error plus usage and return.
    """
    global SOCK, model
    usage = "Usage: learning.py [-p <port>] -d <device> -s <secret> [-t <target>] [-m <max>]"
    dev, sec_key, trt, max_limit = "", "", 18.0, 60
    port = MCAST_PORT
    try:
        # "h" takes no argument: the former "h:" spec made a bare -h raise an
        # uncaught GetoptError ("option -h requires argument").
        opts, _ = getopt.getopt(
            argv,
            "hp:d:s:t:m:",
            ["help", "port=", "device=", "secret=", "target=", "max="],
        )
        for opt, arg in opts:
            if opt in ("-h", "--help"):
                print(usage)
                return
            elif opt in ("-p", "--port"):
                port = int(arg)
            elif opt in ("-d", "--device"):
                dev = arg
            elif opt in ("-s", "--secret"):
                sec_key = arg
            elif opt in ("-t", "--target"):
                trt = float(arg)
            elif opt in ("-m", "--max"):
                max_limit = int(arg)
    except (getopt.GetoptError, ValueError) as e:
        print(f"❌ 参数错误: {e}")
        print(usage)
        return

    if not dev or not sec_key:
        print("❌ 必须指定设备和秘钥")
        return
    if not (TCT_MIN <= max_limit <= TCT_MAX):
        print(f"❌ 最大供水温度需在 {TCT_MIN}°C ~ {TCT_MAX}°C 之间")
        return

    print(f"🚀 启动监听 → dev={dev}, port={port}, trt={trt}, max={max_limit}")
    SOCK = init_socket(port)
    model = load_or_init_model()
    # Persist the model on Ctrl-C and on a plain `kill` / service stop alike.
    signal.signal(signal.SIGINT, graceful_shutdown)
    signal.signal(signal.SIGTERM, graceful_shutdown)

    while True:
        try:
            data, addr = SOCK.recvfrom(BUFFER_SIZE)
            if not data:
                continue
            _, _, realdata, crc1, crc2 = unpack_data(data, sec_key)
            if crc1 != crc2:
                continue  # checksum mismatch: drop the datagram
            decoded = realdata.decode("ascii", errors="ignore")
            if dev not in decoded:
                continue  # cheap pre-filter before parsing JSON
            payload = json.loads(decoded)
            # Exact sender check. The substring test alone also accepted other
            # devices whose names contain `dev`, and packets that merely
            # mention it (e.g. in "tar"), and processed them as this device's.
            if not isinstance(payload, dict) or payload.get("dev") != dev:
                continue
            process_packet(payload, dev, sec_key, addr, port, trt, max_limit)
        except Exception as e:
            print(f"❌ 异常: {e}")
            time.sleep(0.1)
# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__":
    cli_args = sys.argv[1:]
    main(cli_args)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment