Created
January 6, 2026 03:14
-
-
Save BHznJNs/cf11352f210084841e28b9fb2ff76f35 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import asyncio | |
| from contextlib import AsyncExitStack | |
| from typing import Any | |
| from urllib.parse import parse_qs, urlparse | |
| import httpx | |
| from mcp import ClientSession, StdioServerParameters | |
| from mcp.client.auth import OAuthClientProvider, TokenStorage | |
| from mcp.client.stdio import stdio_client | |
| from mcp.client.streamable_http import streamable_http_client | |
| from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken | |
| from mcp.types import CallToolResult, Tool | |
| from mcp.types import Content as ToolResult | |
| from pydantic import AnyUrl | |
class InMemoryTokenStorage(TokenStorage):
    """Demo token storage that keeps OAuth state only in process memory.

    Nothing is persisted: the stored tokens and client registration are lost
    when the object (or the process) goes away.
    """

    def __init__(self):
        # Both slots start empty until the matching setter is called.
        self.client_info: OAuthClientInformationFull | None = None
        self.tokens: OAuthToken | None = None

    # -- OAuth tokens -------------------------------------------------------

    async def set_tokens(self, tokens: OAuthToken) -> None:
        """Remember *tokens*, replacing anything stored previously."""
        self.tokens = tokens

    async def get_tokens(self) -> OAuthToken | None:
        """Return the stored tokens, or None if nothing was saved yet."""
        return self.tokens

    # -- Client registration ------------------------------------------------

    async def set_client_info(self, client_info: OAuthClientInformationFull) -> None:
        """Remember *client_info*, replacing anything stored previously."""
        self.client_info = client_info

    async def get_client_info(self) -> OAuthClientInformationFull | None:
        """Return the stored client information, or None if nothing was saved yet."""
        return self.client_info
class LocalMcpClient:
    """MCP client for a server launched as a local subprocess over stdio.

    Usage: ``await connect()``, then ``list_available_tools()`` / ``call()``,
    and finally ``await disconnect()`` to close the session and the stdio
    transport.
    """

    def __init__(self, params: StdioServerParameters):
        """Store the launch parameters; no I/O happens until ``connect()``."""
        self._params: StdioServerParameters = params
        self._session: ClientSession | None = None
        # Owns the transport and session contexts so they are closed together,
        # in reverse order of entry, by disconnect().
        self._exit_stack: AsyncExitStack | None = None

    async def connect(self) -> None:
        """Start the stdio transport, open a ClientSession and initialize it.

        An existing connection is closed first so reconnecting does not leak
        it. If any step fails, including cancellation, every context entered
        so far is closed before the exception propagates. In that case the
        client is left disconnected.
        """
        await self.disconnect()
        # The ``async with`` closes everything on *any* exception (including
        # CancelledError, which ``except Exception`` would miss). pop_all()
        # then transfers ownership out of the block only after full success.
        async with AsyncExitStack() as stack:
            read_stream, write_stream = await stack.enter_async_context(
                stdio_client(self._params)
            )
            session = await stack.enter_async_context(
                ClientSession(read_stream, write_stream)
            )
            _ = await session.initialize()
            self._exit_stack = stack.pop_all()
            self._session = session

    async def list_available_tools(self) -> list[Tool]:
        """Return the tools advertised by the connected server.

        Raises:
            RuntimeError: if ``connect()`` has not completed successfully.
        """
        if not self._session:
            raise RuntimeError("MCP Session 未建立,请先调用 connect()")
        result = await self._session.list_tools()
        return result.tools

    async def call(
        self, tool_name: str, arguments: dict[str, Any] | None = None
    ) -> list[ToolResult]:
        """Invoke *tool_name* with *arguments* and return the result content.

        NOTE(review): the CallToolResult's error flag is not inspected here;
        callers only see the returned content list.

        Raises:
            RuntimeError: if ``connect()`` has not completed successfully.
        """
        if not self._session:
            raise RuntimeError("MCP Session 未建立")
        response: CallToolResult = await self._session.call_tool(
            tool_name, arguments=arguments
        )
        return response.content

    async def disconnect(self) -> None:
        """Close the session and transport; safe to call when not connected.

        The state is cleared *before* closing, so the client is reliably
        marked disconnected even if closing raises.
        """
        stack, self._exit_stack = self._exit_stack, None
        self._session = None
        if stack is not None:
            await stack.aclose()
class RemoteMcpClient:
    """MCP client for a remote server reached over streamable HTTP.

    Usage: ``await connect()``, then ``list_available_tools()`` / ``call()``,
    and finally ``await disconnect()`` to close the session and transport.
    """

    def __init__(self, url: str, client: httpx.AsyncClient | None = None):
        """Store the endpoint and optional HTTP client; no I/O until ``connect()``.

        Args:
            url: MCP endpoint URL.
            client: Optional pre-configured ``httpx.AsyncClient`` (e.g. carrying
                OAuth auth). It is passed to the transport as-is; this class
                never calls ``aclose()`` on it, so the caller remains
                responsible for closing it.
        """
        self._url: str = url
        self._custom_client: httpx.AsyncClient | None = client
        self._session: ClientSession | None = None
        # Owns the transport and session contexts so they are closed together,
        # in reverse order of entry, by disconnect().
        self._exit_stack: AsyncExitStack | None = None

    async def connect(self) -> None:
        """Open the HTTP transport, start a ClientSession and initialize it.

        An existing connection is closed first so reconnecting does not leak
        it. If any step fails, including cancellation, every context entered
        so far is closed before the exception propagates. In that case the
        client is left disconnected.
        """
        await self.disconnect()
        # The ``async with`` closes everything on *any* exception (including
        # CancelledError, which ``except Exception`` would miss). pop_all()
        # then transfers ownership out of the block only after full success.
        async with AsyncExitStack() as stack:
            # The third element yielded by the transport is unused here.
            read_stream, write_stream, _ = await stack.enter_async_context(
                streamable_http_client(self._url, http_client=self._custom_client)
            )
            session = await stack.enter_async_context(
                ClientSession(read_stream, write_stream)
            )
            _ = await session.initialize()
            self._exit_stack = stack.pop_all()
            self._session = session

    async def list_available_tools(self) -> list[Tool]:
        """Return the tools advertised by the connected server.

        Raises:
            RuntimeError: if ``connect()`` has not completed successfully.
        """
        if not self._session:
            raise RuntimeError("MCP Session 未建立,请先调用 connect()")
        result = await self._session.list_tools()
        return result.tools

    async def call(
        self, tool_name: str, arguments: dict[str, Any] | None = None
    ) -> list[ToolResult]:
        """Invoke *tool_name* with *arguments* and return the result content.

        NOTE(review): the CallToolResult's error flag is not inspected here;
        callers only see the returned content list.

        Raises:
            RuntimeError: if ``connect()`` has not completed successfully.
        """
        if not self._session:
            raise RuntimeError("MCP Session 未建立")
        response: CallToolResult = await self._session.call_tool(
            tool_name, arguments=arguments
        )
        return response.content

    async def disconnect(self) -> None:
        """Close the session and transport; safe to call when not connected.

        The state is cleared *before* closing, so the client is reliably
        marked disconnected even if closing raises.
        """
        stack, self._exit_stack = self._exit_stack, None
        self._session = None
        if stack is not None:
            await stack.aclose()
async def handle_redirect(auth_url: str) -> None:
    """Print the OAuth authorization URL so the user can open it manually."""
    message = "Visit: {}".format(auth_url)
    print(message)
| async def handle_callback() -> tuple[str, str | None]: | |
| callback_url = input("Paste callback URL: ") | |
| params = parse_qs(urlparse(callback_url).query) | |
| return params["code"][0], params.get("state", [None])[0] | |
async def main():
    """Demo: connect to the example remote MCP server using OAuth.

    Lists the server's tools, calls the ``echo`` tool and prints any text
    content in its result. The MCP session and the HTTP client are closed
    even if a step fails.
    """
    server_url = "https://example-server.modelcontextprotocol.io/mcp"
    oauth_auth = OAuthClientProvider(
        server_url=server_url,
        client_metadata=OAuthClientMetadata(
            client_name="Example MCP Client",
            redirect_uris=[AnyUrl("http://localhost:3000/callback")],
            grant_types=["authorization_code", "refresh_token"],
            response_types=["code"],
            scope="user",
            token_endpoint_auth_method="none",
        ),
        storage=InMemoryTokenStorage(),
        redirect_handler=handle_redirect,
        callback_handler=handle_callback,
    )
    # This function creates the HTTP client, so it must also close it.
    # aclose() is used rather than ``async with`` so the transport remains
    # free to manage the client's context itself.
    custom_client = httpx.AsyncClient(auth=oauth_auth, follow_redirects=True)
    client = RemoteMcpClient(server_url, client=custom_client)
    try:
        await client.connect()
        tools = await client.list_available_tools()
        print(f"发现工具: {[t.name for t in tools]}")
        print("正在调用 'echo'...")
        result = await client.call("echo", arguments={"message": "Hello, World!"})
        # Print the result; only text content is shown, other types are skipped.
        for content in result:
            if content.type == "text":
                print(f"Server 回应: {content.text}")
    finally:
        # Close the MCP session before the HTTP client it runs on.
        try:
            await client.disconnect()
        finally:
            await custom_client.aclose()


if __name__ == "__main__":
    asyncio.run(main())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment