/*
 * Decompiled with CFR 0.152.
 */
package io.modelcontextprotocol.server.transport;

import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.modelcontextprotocol.spec.McpError;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpServerSession;
import io.modelcontextprotocol.spec.McpServerTransport;
import io.modelcontextprotocol.spec.McpServerTransportProvider;
import io.modelcontextprotocol.util.Assert;
import jakarta.servlet.AsyncContext;
import jakarta.servlet.ServletException;
import jakarta.servlet.annotation.WebServlet;
import jakarta.servlet.http.HttpServlet;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.regex.Pattern;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

/**
 * Servlet-based {@link McpServerTransportProvider} that uses Server-Sent Events (SSE) for
 * server-to-client traffic and HTTP {@code POST} for client-to-server traffic.
 *
 * <p>A {@code GET} whose URI ends with the SSE endpoint opens a long-lived, asynchronous
 * event stream, creates a {@link McpServerSession} through the configured session factory
 * and immediately sends an {@code endpoint} event whose data is
 * {@code baseUrl + messageEndpoint + "?sessionId=" + id}. Clients {@code POST} JSON-RPC
 * messages to that URL; messages from the server to a session are written to its stream
 * as {@code message} events.
 *
 * <p>{@link #setSessionFactory} must be called before the first SSE request; otherwise
 * session creation fails with a {@link NullPointerException}.
 */
@WebServlet(asyncSupported=true)
public class HttpServletSseServerTransportProvider extends HttpServlet implements McpServerTransportProvider {
    private static final Logger logger = LoggerFactory.getLogger(HttpServletSseServerTransportProvider.class);
    public static final String UTF_8 = "UTF-8";
    public static final String APPLICATION_JSON = "application/json";
    public static final String FAILED_TO_SEND_ERROR_RESPONSE = "Failed to send error response: {}";
    public static final String DEFAULT_SSE_ENDPOINT = "/sse";
    public static final String MESSAGE_EVENT_TYPE = "message";
    public static final String ENDPOINT_EVENT_TYPE = "endpoint";
    public static final String DEFAULT_BASE_URL = "";
    /** Any line terminator recognised by the SSE format (CRLF, CR or LF). */
    private static final Pattern LINE_BREAK = Pattern.compile("\r\n|\r|\n");
    private final ObjectMapper objectMapper;
    private final String baseUrl;
    private final String messageEndpoint;
    private final String sseEndpoint;
    /** Active sessions keyed by the session id handed out in the {@code endpoint} event. */
    private final Map<String, McpServerSession> sessions = new ConcurrentHashMap<String, McpServerSession>();
    /** Set once {@link #closeGracefully()} starts; new GET/POST requests are then rejected with 503. */
    private final AtomicBoolean isClosing = new AtomicBoolean(false);
    private McpServerSession.Factory sessionFactory;

    /**
     * Creates a provider with the default (empty) base URL.
     *
     * @param objectMapper    mapper used for JSON (de)serialization
     * @param messageEndpoint path suffix that accepts client {@code POST} messages
     * @param sseEndpoint     path suffix that opens the SSE stream
     */
    public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint, String sseEndpoint) {
        this(objectMapper, DEFAULT_BASE_URL, messageEndpoint, sseEndpoint);
    }

    /**
     * Creates a provider.
     *
     * @param objectMapper    mapper used for JSON (de)serialization
     * @param baseUrl         prefix prepended to the message endpoint in the {@code endpoint} event
     * @param messageEndpoint path suffix that accepts client {@code POST} messages
     * @param sseEndpoint     path suffix that opens the SSE stream
     */
    public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, String sseEndpoint) {
        this.objectMapper = objectMapper;
        this.baseUrl = baseUrl;
        this.messageEndpoint = messageEndpoint;
        this.sseEndpoint = sseEndpoint;
    }

    /**
     * Creates a provider with the default base URL and the default SSE endpoint ({@value #DEFAULT_SSE_ENDPOINT}).
     *
     * @param objectMapper    mapper used for JSON (de)serialization
     * @param messageEndpoint path suffix that accepts client {@code POST} messages
     */
    public HttpServletSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint) {
        this(objectMapper, messageEndpoint, DEFAULT_SSE_ENDPOINT);
    }

    @Override
    public void setSessionFactory(McpServerSession.Factory sessionFactory) {
        this.sessionFactory = sessionFactory;
    }

    /**
     * Sends a notification to every active session. A failure for one session is logged
     * and does not prevent delivery to the others.
     *
     * @param method notification method name
     * @param params notification parameters
     * @return a {@link Mono} that completes once every session has been attempted
     */
    @Override
    public Mono<Void> notifyClients(String method, Object params) {
        if (this.sessions.isEmpty()) {
            logger.debug("No active sessions to broadcast message to");
            return Mono.empty();
        }
        logger.debug("Attempting to broadcast message to {} active sessions", this.sessions.size());
        return Flux.fromIterable(this.sessions.values())
                .flatMap(session -> session.sendNotification(method, params)
                        .doOnError(e -> logger.error("Failed to send message to session {}: {}", session.getId(), e.getMessage()))
                        .onErrorComplete())
                .then();
    }

    /**
     * Opens an SSE stream for a new session and sends it the {@code endpoint} event.
     * Responds 404 for URIs not ending with the SSE endpoint and 503 while shutting down.
     */
    @Override
    protected void doGet(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException {
        String requestURI = request.getRequestURI();
        if (!requestURI.endsWith(this.sseEndpoint)) {
            response.sendError(404);
            return;
        }
        if (this.isClosing.get()) {
            response.sendError(503, "Server is shutting down");
            return;
        }
        response.setContentType("text/event-stream");
        response.setCharacterEncoding(UTF_8);
        response.setHeader("Cache-Control", "no-cache");
        response.setHeader("Connection", "keep-alive");
        response.setHeader("Access-Control-Allow-Origin", "*");
        String sessionId = UUID.randomUUID().toString();
        AsyncContext asyncContext = request.startAsync();
        // 0 disables the async timeout: the stream stays open until the session closes it.
        asyncContext.setTimeout(0L);
        try {
            PrintWriter writer = response.getWriter();
            HttpServletMcpSessionTransport sessionTransport = new HttpServletMcpSessionTransport(sessionId, asyncContext, writer);
            McpServerSession session = this.sessionFactory.create(sessionTransport);
            this.sessions.put(sessionId, session);
            this.sendEvent(writer, ENDPOINT_EVENT_TYPE, this.baseUrl + this.messageEndpoint + "?sessionId=" + sessionId);
        }
        catch (IOException | RuntimeException e) {
            // Without this cleanup a failed handshake (e.g. the client already disconnected)
            // would leave a dead session registered and the async request never completed.
            this.sessions.remove(sessionId);
            this.completeQuietly(asyncContext, sessionId);
            throw e;
        }
    }

    /**
     * Accepts a JSON-RPC message for the session named by the {@code sessionId} query
     * parameter and hands it to that session, blocking until it has been handled.
     * Responds 503 while shutting down, 404 for a foreign URI or unknown session, 400 when
     * {@code sessionId} is missing, 200 on success and 500 (with a JSON error body) when
     * processing fails.
     */
    @Override
    protected void doPost(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException {
        if (this.isClosing.get()) {
            response.sendError(503, "Server is shutting down");
            return;
        }
        if (!request.getRequestURI().endsWith(this.messageEndpoint)) {
            response.sendError(404);
            return;
        }
        String sessionId = request.getParameter("sessionId");
        if (sessionId == null) {
            this.writeJsonError(response, 400, "Session ID missing in message endpoint");
            return;
        }
        McpServerSession session = this.sessions.get(sessionId);
        if (session == null) {
            this.writeJsonError(response, 404, "Session not found: " + sessionId);
            return;
        }
        try {
            String body = readBody(request.getReader());
            McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(this.objectMapper, body);
            session.handle(message).block();
            response.setStatus(200);
        }
        catch (Exception e) {
            // Trailing throwable argument makes SLF4J log the stack trace as well.
            logger.error("Error processing message: {}", e.getMessage(), e);
            try {
                this.writeJsonError(response, 500, e.getMessage());
            }
            catch (IOException ex) {
                logger.error(FAILED_TO_SEND_ERROR_RESPONSE, ex.getMessage());
                response.sendError(500, "Error processing message");
            }
        }
    }

    /**
     * Stops accepting new requests and gracefully closes every active session.
     *
     * @return a {@link Mono} that completes once all sessions have been closed
     */
    @Override
    public Mono<Void> closeGracefully() {
        this.isClosing.set(true);
        logger.debug("Initiating graceful shutdown with {} active sessions", this.sessions.size());
        return Flux.fromIterable(this.sessions.values()).flatMap(McpServerSession::closeGracefully).then();
    }

    /**
     * Writes one SSE event and flushes it.
     *
     * <p>The frame is built first and then written, flushed and checked while holding the
     * writer's monitor. Concurrent sends to the same stream therefore cannot interleave.
     * Every line of {@code data} becomes its own {@code data:} field, as the SSE format
     * requires. Single-line data produces exactly {@code "event: T\ndata: D\n\n"}.
     *
     * @param writer    the stream's writer
     * @param eventType value of the {@code event:} field
     * @param data      event payload
     * @throws IOException if the writer reports an error after flushing (client disconnected)
     */
    private void sendEvent(PrintWriter writer, String eventType, String data) throws IOException {
        StringBuilder frame = new StringBuilder("event: ").append(eventType).append('\n');
        // limit -1 keeps trailing empty lines, so "" still yields a single empty data line.
        for (String line : LINE_BREAK.split(data, -1)) {
            frame.append("data: ").append(line).append('\n');
        }
        frame.append('\n');
        synchronized (writer) {
            writer.write(frame.toString());
            writer.flush();
            if (writer.checkError()) {
                throw new IOException("Client disconnected");
            }
        }
    }

    /**
     * Writes {@code message}, wrapped in an {@link McpError} and serialized with the
     * configured ObjectMapper, as a UTF-8 JSON response body with the given status.
     *
     * @throws IOException if serialization or writing fails
     */
    private void writeJsonError(HttpServletResponse response, int status, String message) throws IOException {
        response.setContentType(APPLICATION_JSON);
        response.setCharacterEncoding(UTF_8);
        response.setStatus(status);
        // Cast keeps resolution on the McpError(Object) constructor, as in the original code.
        String jsonError = this.objectMapper.writeValueAsString(new McpError((Object)message));
        PrintWriter writer = response.getWriter();
        writer.write(jsonError);
        writer.flush();
    }

    /** Reads the whole request body verbatim (line terminators included). */
    private static String readBody(BufferedReader reader) throws IOException {
        StringBuilder body = new StringBuilder();
        char[] buffer = new char[4096];
        int read;
        while ((read = reader.read(buffer)) != -1) {
            body.append(buffer, 0, read);
        }
        return body.toString();
    }

    /**
     * Completes {@code asyncContext}. A failure, such as a context that was already
     * completed, is logged rather than thrown.
     */
    private void completeQuietly(AsyncContext asyncContext, String sessionId) {
        try {
            asyncContext.complete();
            logger.debug("Successfully completed async context for session {}", sessionId);
        }
        catch (Exception e) {
            logger.warn("Failed to complete async context for session {}: {}", sessionId, e.getMessage());
        }
    }

    /** Gracefully closes all sessions (blocking) before the servlet is taken out of service. */
    @Override
    public void destroy() {
        this.closeGracefully().block();
        super.destroy();
    }

    /** @return a new {@link Builder} */
    public static Builder builder() {
        return new Builder();
    }

    /** Per-session transport that writes outgoing messages as SSE {@code message} events. */
    private class HttpServletMcpSessionTransport implements McpServerTransport {
        private final String sessionId;
        private final AsyncContext asyncContext;
        private final PrintWriter writer;

        HttpServletMcpSessionTransport(String sessionId, AsyncContext asyncContext, PrintWriter writer) {
            this.sessionId = sessionId;
            this.asyncContext = asyncContext;
            this.writer = writer;
            logger.debug("Session transport {} initialized with SSE writer", sessionId);
        }

        /**
         * Serializes {@code message} and sends it as a {@code message} event. On failure,
         * the error is logged, the session is unregistered and its stream is completed; the
         * returned {@link Mono} still completes normally (best-effort delivery).
         */
        @Override
        public Mono<Void> sendMessage(McpSchema.JSONRPCMessage message) {
            return Mono.fromRunnable(() -> {
                try {
                    String jsonText = objectMapper.writeValueAsString(message);
                    sendEvent(this.writer, MESSAGE_EVENT_TYPE, jsonText);
                    logger.debug("Message sent to session {}", this.sessionId);
                }
                catch (Exception e) {
                    logger.error("Failed to send message to session {}: {}", this.sessionId, e.getMessage());
                    sessions.remove(this.sessionId);
                    completeQuietly(this.asyncContext, this.sessionId);
                }
            });
        }

        @Override
        public <T> T unmarshalFrom(Object data, TypeReference<T> typeRef) {
            return objectMapper.convertValue(data, typeRef);
        }

        /** Asynchronous variant of {@link #close()}. */
        @Override
        public Mono<Void> closeGracefully() {
            return Mono.fromRunnable(() -> {
                logger.debug("Closing session transport: {}", this.sessionId);
                close();
            });
        }

        /** Unregisters the session and completes its SSE stream; failures are only logged. */
        @Override
        public void close() {
            sessions.remove(this.sessionId);
            completeQuietly(this.asyncContext, this.sessionId);
        }
    }

    /**
     * Builder for {@link HttpServletSseServerTransportProvider}. The message endpoint is
     * required; the ObjectMapper, base URL and SSE endpoint have defaults.
     */
    public static class Builder {
        private ObjectMapper objectMapper = new ObjectMapper();
        private String baseUrl = DEFAULT_BASE_URL;
        private String messageEndpoint;
        private String sseEndpoint = DEFAULT_SSE_ENDPOINT;

        public Builder objectMapper(ObjectMapper objectMapper) {
            Assert.notNull(objectMapper, "ObjectMapper must not be null");
            this.objectMapper = objectMapper;
            return this;
        }

        public Builder baseUrl(String baseUrl) {
            Assert.notNull(baseUrl, "Base URL must not be null");
            this.baseUrl = baseUrl;
            return this;
        }

        public Builder messageEndpoint(String messageEndpoint) {
            Assert.hasText(messageEndpoint, "Message endpoint must not be empty");
            this.messageEndpoint = messageEndpoint;
            return this;
        }

        public Builder sseEndpoint(String sseEndpoint) {
            Assert.hasText(sseEndpoint, "SSE endpoint must not be empty");
            this.sseEndpoint = sseEndpoint;
            return this;
        }

        /**
         * @return the configured provider
         * @throws IllegalStateException if no message endpoint (or ObjectMapper) is set
         */
        public HttpServletSseServerTransportProvider build() {
            if (this.objectMapper == null) {
                throw new IllegalStateException("ObjectMapper must be set");
            }
            if (this.messageEndpoint == null) {
                throw new IllegalStateException("MessageEndpoint must be set");
            }
            return new HttpServletSseServerTransportProvider(this.objectMapper, this.baseUrl, this.messageEndpoint, this.sseEndpoint);
        }
    }
}

