qinfengge

Drunk, I did not know the sky lay in the water; my boat, full of clear dreams, pressed upon the Milky Way.

Spring AI (4): Continuous Dialogue

In the previous articles, we only implemented simple calls that handled a single exchange. That is neither realistic nor elegant; who holds a conversation consisting of just one sentence? Except for the one below:

image

If an AI could only handle a single sentence, all we could tell it is: go cool off for a while.
So how do we make the model hold a continuous conversation? The key is memory: the model needs to remember both the user's questions and its own outputs, so that it can give reasonable answers based on the previous context.

Prompt#

Do you remember what was said in the first chapter?

There are many kinds of prompts, and they can be used in many ways. A prompt is not just a set of keywords; it is also the key to multi-turn conversations.

Looking at how a Prompt is created, we can see that it accepts either a single message or a list of messages.

image
The MessageType of a message can take the following four values:

image

Look familiar? They correspond to the roles in a conversation:

    USER("user"),           // User's input

    ASSISTANT("assistant"), // Model's output

    SYSTEM("system"),       // Model's persona

    FUNCTION("function");   // Function call result

Think about a real conversation: you say something, the other person replies, and each reply has to fit what was said before, or the exchange makes no sense. Continuous dialogue with a model works the same way: on every turn we pass the previous context to the model so that it can follow the conversation. That is what the message list is for.
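A minimal sketch of how that list grows turn by turn. It assumes hypothetical stand-in types (Role, ChatMessage, ContextDemo) instead of Spring AI's own Message classes, so it runs on the JDK alone; Role simply mirrors the MessageType values above.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical stand-ins for Spring AI's MessageType and Message, for illustration only
enum Role { USER, ASSISTANT, SYSTEM, FUNCTION }

record ChatMessage(Role role, String content) {}

class ContextDemo {
    // The full history, re-sent to the model on every turn
    private final List<ChatMessage> history = new ArrayList<>();

    ContextDemo(String persona) {
        // The system message sets the persona and always comes first
        history.add(new ChatMessage(Role.SYSTEM, persona));
    }

    /** Append the user's input and return a snapshot of what would be sent to the model. */
    List<ChatMessage> ask(String userInput) {
        history.add(new ChatMessage(Role.USER, userInput));
        return List.copyOf(history);
    }

    /** Record the model's reply so that the next turn can see it. */
    void recordReply(String reply) {
        history.add(new ChatMessage(Role.ASSISTANT, reply));
    }
}
```

On the first turn the model receives [SYSTEM, USER]; once the reply is recorded, the second turn sends [SYSTEM, USER, ASSISTANT, USER]. Returning a copy keeps each turn's snapshot stable while the history keeps growing.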

Implementation#

With the principle explained, let's get started. We also want to build on the work from the previous articles, so next we will implement a complete feature that combines streaming output, function calling, and continuous dialogue.

First, initialize the client.

    private static final String BASEURL = "https://xxx";

    private static final String TOKEN = "sk-xxxx";

    /**
     * Create OpenAiChatClient
     * @return OpenAiChatClient
     */
    private static OpenAiChatClient getClient(){
        OpenAiApi openAiApi = new OpenAiApi(BASEURL, TOKEN);
        return new OpenAiChatClient(openAiApi, OpenAiChatOptions.builder()
                .withModel("gpt-3.5-turbo-1106")
                .withTemperature(0.8F)
                .build());
    }

Note that some older OpenAI models do not support function calling or streaming output, and some older models have a context limit of only 4K tokens. Choosing an unsuitable model may result in a 400 BAD REQUEST.

Next, we store the conversation history.
First, create a Map:

private static Map<String, List<Message>> chatMessage = new ConcurrentHashMap<>();

The key of the Map is the session ID and the value is that session's message history. Every session needs its own unique ID; otherwise histories from different sessions will get mixed up.
Then, on every turn, we pass in the session ID and the user's input, and append the input to that session's message list.

    /**
     * Return the session's message history with the user's new input appended
     * @param id      Session ID
     * @param message User's input message
     * @return Message list for this session
     */
    private List<Message> getMessages(String id, String message) {
        String systemPrompt = "{prompt}";
        SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemPrompt);

        Message userMessage = new UserMessage(message);

        Message systemMessage = systemPromptTemplate.createMessage(MapUtil.of("prompt", "you are a helpful AI assistant"));

        List<Message> messages = chatMessage.get(id);

        // If there is no history yet, start a new list with the system prompt first
        if (messages == null) {
            messages = new ArrayList<>();
            messages.add(systemMessage);
        }
        // In either case, append the user's input
        messages.add(userMessage);

        return messages;
    }

On the first round, when there is no history yet, the systemMessage is added before the user's input, which effectively gives the model its persona.
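A side note on thread safety: in the chat method further down, the assistant reply is appended from a Reactor callback, which may run on a different thread from the request, and a brand-new ArrayList only reaches the map at the end of chat(). The JDK-only sketch below shows one possible safer variant. SessionStore and Turn are hypothetical names, and Turn stands in for Spring AI's Message types; the idea is to create each history atomically with ConcurrentHashMap.computeIfAbsent and to guard each list with synchronization.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

// Hypothetical message type standing in for Spring AI's Message classes
record Turn(String role, String content) {}

class SessionStore {
    private final Map<String, List<Turn>> sessions = new ConcurrentHashMap<>();
    private final String systemPrompt;

    SessionStore(String systemPrompt) {
        this.systemPrompt = systemPrompt;
    }

    /** Get, or atomically create, a session's history seeded with the system prompt. */
    private List<Turn> history(String sessionId) {
        return sessions.computeIfAbsent(sessionId, id -> {
            List<Turn> list = Collections.synchronizedList(new ArrayList<>());
            list.add(new Turn("system", systemPrompt));
            return list;
        });
    }

    /** Append the user's input and return a copy of the history to send to the model. */
    List<Turn> addUser(String sessionId, String text) {
        List<Turn> h = history(sessionId);
        // Copying a synchronizedList requires holding its lock
        synchronized (h) {
            h.add(new Turn("user", text));
            return List.copyOf(h);
        }
    }

    /** Append the model's complete reply, e.g. from a stream-completion callback. */
    void addAssistant(String sessionId, String text) {
        history(sessionId).add(new Turn("assistant", text));
    }
}
```

With this shape, getMessages would become a call to addUser, and the stream's completion callback would call addAssistant.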

Next, register the functions.

/**
     * Initialize function calls
     * @return ChatOptions
     */
    private ChatOptions initFunc(){
        return OpenAiChatOptions.builder().withFunctionCallbacks(List.of(
                FunctionCallbackWrapper.builder(new MockWeatherService()).withName("weather").withDescription("Get the weather in location").build(),
                FunctionCallbackWrapper.builder(new WbHotService()).withName("wbHot").withDescription("Get the hot list of Weibo").build(),
                FunctionCallbackWrapper.builder(new TodayNews()).withName("todayNews").withDescription("60s watch world news").build(),
                FunctionCallbackWrapper.builder(new DailyEnglishFunc()).withName("dailyEnglish").withDescription("A daily inspirational sentence in English").build())).build();
    }

For details on function calling, see Chapter 3.
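As a quick reminder of what FunctionCallbackWrapper.builder wraps: each function is a plain java.util.function.Function from a request type to a response type. The sketch below is modeled on the weather example in the Spring AI documentation; the record fields and the fixed 30-degree reply are illustrative assumptions, not the author's actual MockWeatherService.

```java
import java.util.function.Function;

// Hypothetical mock in the style of the Spring AI docs example
class MockWeatherService implements Function<MockWeatherService.Request, MockWeatherService.Response> {

    enum Unit { C, F }

    // The model fills these fields from the JSON arguments it generates for the call
    record Request(String location, Unit unit) {}

    record Response(double temp, Unit unit) {}

    @Override
    public Response apply(Request request) {
        // A real implementation would query a weather API; this mock always reports 30 degrees
        Unit unit = request.unit() == null ? Unit.C : request.unit();
        return new Response(30.0, unit);
    }
}
```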

Finally, the output.
Since the end result is a web page, we use server-side push, specifically Server-Sent Events (SSE). For an introduction to SSE, see my earlier blog post on message pushing.
Here is an SSE utility class:

@Component
@Slf4j
public class SseEmitterUtils {
    /**
     * Current connection count
     */
    private static AtomicInteger count = new AtomicInteger(0);

    /**
     * Store SseEmitter information
     */
    private static Map<String, SseEmitter> sseEmitterMap = new ConcurrentHashMap<>();

    /**
     * Create user connection and return SseEmitter
     * @param key userId
     * @return SseEmitter
     */
    public static SseEmitter connect(String key) {
        if (sseEmitterMap.containsKey(key)) {
            return sseEmitterMap.get(key);
        }

        try {
            // Set timeout, 0 means no expiration. Default is 30 seconds
            SseEmitter sseEmitter = new SseEmitter(0L);
            // Register callbacks
            sseEmitter.onCompletion(completionCallBack(key));
            sseEmitter.onError(errorCallBack(key));
            sseEmitter.onTimeout(timeoutCallBack(key));
            sseEmitterMap.put(key, sseEmitter);
            // Increment count
            count.getAndIncrement();
            return sseEmitter;
        } catch (Exception e) {
            log.error("Error creating new SSE connection, current connection Key: {}", key, e);
        }
        return null;
    }

    /**
     * Send message to specified user
     * @param key userId
     * @param message Message content
     */
    public static void sendMessage(String key, String message) {
        if (sseEmitterMap.containsKey(key)) {
            try {
                sseEmitterMap.get(key).send(message);
            } catch (IOException e) {
                log.error("User[{}] push exception:{}", key, e.getMessage());
                remove(key);
            }
        }
    }

    /**
     * Publish message to the same group of people, requirement: key + groupId
     * @param groupId Group id
     * @param message Message content
     */
    public static void groupSendMessage(String groupId, String message) {
        if (!CollectionUtils.isEmpty(sseEmitterMap)) {
            sseEmitterMap.forEach((k, v) -> {
                try {
                    if (k.startsWith(groupId)) {
                        v.send(message, MediaType.APPLICATION_JSON);
                    }
                } catch (IOException e) {
                    log.error("User[{}] push exception:{}", k, e.getMessage());
                    remove(k);
                }
            });
        }
    }

    /**
     * Broadcast a message to all connections
     * @param message Message content
     */
    public static void batchSendMessage(String message) {
        sseEmitterMap.forEach((k, v) -> {
            try {
                v.send(message, MediaType.APPLICATION_JSON);
            } catch (IOException e) {
                log.error("User[{}] push exception:{}", k, e.getMessage());
                remove(k);
            }
        });
    }

    /**
     * Send a message to each of the given users
     * @param message Message content
     * @param ids User id collection
     */
    public static void batchSendMessage(String message, Set<String> ids) {
        ids.forEach(userId -> sendMessage(userId, message));
    }

    /**
     * Remove connection
     * @param key userId
     */
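    public static void remove(String key) {
        // Only decrement when a connection was actually removed; the completion and error
        // callbacks (and a failed send) may all call remove for the same key
        if (sseEmitterMap.remove(key) != null) {
            count.getAndDecrement();
            log.info("Removed connection: {}", key);
        }
    }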

    /**
     * Get the keys of all current connections
     * @return List of connection keys
     */
    public static List<String> getIds() {
        return new ArrayList<>(sseEmitterMap.keySet());
    }

    /**
     * Get current connection count
     * @return int
     */
    public static int getCount() {
        return count.intValue();
    }

    private static Runnable completionCallBack(String key) {
        return () -> {
            log.info("Connection ended: {}", key);
            remove(key);
        };
    }

    private static Runnable timeoutCallBack(String key) {
        return () -> {
            log.info("Connection timed out: {}", key);
            remove(key);
        };
    }

    private static Consumer<Throwable> errorCallBack(String key) {
        return throwable -> {
            log.info("Connection exception: {}", key);
            remove(key);
        };
    }
}
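For reference, this is roughly what each send call puts on the wire. In the text/event-stream format defined by the HTML standard, an event is one or more "data:" lines followed by a blank line, and a payload containing line breaks has to be split across several "data:" lines, which the browser's EventSource joins back together with "\n". The hypothetical SseFraming helper below is a JDK-only sketch of that framing; SseEmitter builds the frames for you, and whether it splits multi-line strings this way may depend on the Spring version, so it is worth checking how line breaks in model output arrive in the browser.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical helper illustrating text/event-stream framing; not part of Spring
class SseFraming {

    /** Encode one payload as a single event: one "data:" line per line of text, then a blank line. */
    static String encode(String payload) {
        StringBuilder frame = new StringBuilder();
        for (String line : payload.split("\r\n|\r|\n", -1)) {
            frame.append("data: ").append(line).append('\n');
        }
        return frame.append('\n').toString();
    }

    /** Decode one event like EventSource does: drop "data:" plus one optional space, join with "\n". */
    static String decode(String frame) {
        List<String> data = new ArrayList<>();
        for (String line : frame.split("\n")) {
            if (line.startsWith("data:")) {
                String value = line.substring(5);
                data.add(value.startsWith(" ") ? value.substring(1) : value);
            }
        }
        return String.join("\n", data);
    }
}
```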

Can we start chatting now? Not yet. Since this is a web page, we still need to work out how it should behave. The web pages of mainstream AI models mainly offer two things:

  1. Quick questions: users can ask a question directly on the homepage.
  2. Saved conversations: each conversation is unique, and users can return to any of them at any time.

The implementation:

  1. Create an endpoint that is called when the user starts a conversation from the homepage and returns a session ID.
  2. Bind all subsequent user inputs to the session ID returned in step 1, until the browser is refreshed or a new session is created.

So the first step is:

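    /**
     * Create a new session and return its ID.
     * The {message} path variable is not used; the front end calls this before its first chat request.
     */
    @GetMapping("/init/{message}")
    public String init() {
        return UUID.randomUUID().toString();
    }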

Directly return the UUID to the front end.

Finally, with the session ID in hand, we can bind each request to its session and stream the output.

@GetMapping("chat/{id}/{message}")
    public SseEmitter chat(@PathVariable String id, @PathVariable String message, HttpServletResponse response) {

        response.setHeader("Content-type", "text/html;charset=UTF-8");
        response.setCharacterEncoding("UTF-8");

        OpenAiChatClient client = getClient();
        SseEmitter emitter = SseEmitterUtils.connect(id);
        List<Message> messages = getMessages(id, message);
        System.err.println("chatMessage size: " + messages.size());
        System.err.println("chatMessage: " + chatMessage);

        if (messages.size() > MAX_MESSAGE){
            SseEmitterUtils.sendMessage(id, "Too many conversation rounds, please try again later🤔");
        }else {
            // Get the model's output stream
            Flux<ChatResponse> stream = client.stream(new Prompt(messages,initFunc()));

            // Send the messages in the stream using SSE
            Mono<String> result = stream
                    .flatMap(it -> {
                        StringBuilder sb = new StringBuilder();
                        String content = it.getResult().getOutput().getContent();
                        Optional.ofNullable(content).ifPresent(r -> {
                            SseEmitterUtils.sendMessage(id, content);
                            sb.append(content);
                        });
                        return Mono.just(sb.toString());
                    })
                    // Concatenate messages into a string
                    .reduce((a, b) -> a + b)
                    .defaultIfEmpty("");

            // When the stream completes, append the full reply to the history as an AssistantMessage
            result.subscribe(finalContent -> messages.add(new AssistantMessage(finalContent)));

            // Save the history for this session (the reply is appended later, by the callback above)
            chatMessage.put(id, messages);

        }
        return emitter;

    }

Here is what the method does:

  1. It sets the response encoding to UTF-8 to prevent garbled text.
  2. It uses SseEmitterUtils to connect to the session's emitter.
  3. It uses getMessages to get the session's message history with the new input appended.
  4. It checks the size of the history against MAX_MESSAGE; once the limit is exceeded, the model is no longer called, mainly to keep costs down.

private static final Integer MAX_MESSAGE = 10;

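Here MAX_MESSAGE is 10, which in practice allows 5 rounds of dialogue. The check counts the messages in the history, and each round adds two of them (the user's input and the model's output), so divide by 2. Counting the system message as well, the sixth question brings the list to 12 messages and is the first one rejected.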
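The arithmetic can be checked with a tiny sketch. RoundLimit is a hypothetical helper, and it assumes the flow above: the history starts with one system message, getMessages appends the user's input before the size check, and every earlier round completed with one assistant reply.

```java
// Hypothetical helper mirroring the size check in chat(); assumes one system message at the start
class RoundLimit {
    static final int MAX_MESSAGE = 10;

    /** History size at the check in round n (1-based): system + 2 per earlier round + the new input. */
    static int sizeAtCheck(int round) {
        return 1 + 2 * (round - 1) + 1;
    }

    /** Whether round n would still be sent to the model. */
    static boolean allowed(int round) {
        return sizeAtCheck(round) <= MAX_MESSAGE;
    }

    /** The last round that passes the check. */
    static int maxRounds() {
        int round = 1;
        while (allowed(round + 1)) {
            round++;
        }
        return round;
    }
}
```

Here sizeAtCheck(5) is 10, which still passes, while sizeAtCheck(6) is 12, so maxRounds() comes out to 5.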

A sample of the debug output in the second round, just before the model answers:

chatMessage: {e2578f9e-8d71-4531-a6af-400a80fb6569=[SystemMessage{content='you are a helpful AI assistant', properties={}, messageType=SYSTEM}, UserMessage{content='Hello', properties={}, messageType=USER}, AssistantMessage{content='Hello! How can I assist you?', properties={}, messageType=ASSISTANT}, UserMessage{content='Who are you?', properties={}, messageType=USER}]}

Then comes the model's output:

Flux<ChatResponse> stream = client.stream(new Prompt(messages,initFunc()));

We call stream with a Prompt that carries both the message history and the function options.
Each piece of content from the resulting stream is then pushed to the session with SseEmitterUtils.sendMessage(id, content);.

The last step is to add the model's output to the history as well, so the model knows what it has already answered.
Without it, the model tends to answer all of the earlier user inputs again.
For example, in the first round we ask "Introduce Hangzhou" and get a normal answer.
In the second round we ask "What famous attractions are there in Hangzhou?" Because the model cannot tell whether it already answered the first question, it tends to answer both at once, as if asked "Introduce Hangzhou, and what famous attractions are there in Hangzhou?"
The same thing happens in the third and fourth rounds.

So how do we get the complete output from the stream?
At first I called stream.subscribe and appended each piece of content to a StringBuilder, but the StringBuilder was always empty when I read it. After asking Claude, I learned that Flux is asynchronous: subscribe returns right away, before any content has arrived. In the end I reduced the stream to a Mono instead.

In this code, the flatMap callback creates a new StringBuilder sb for each response, appends that response's content to it (if there is any), and returns a Mono that emits sb.toString().
flatMap flattens these Monos back into a stream of strings, and the reduce operator then combines them into a single Mono<String>. The argument of reduce is a merging function that combines the accumulated value and the current value into a new value; here it is plain string concatenation.
Finally, we subscribe to this Mono and add the final content to messages in its callback. If the stream produced no responses, reduce would complete without emitting anything and the callback would never run, so defaultIfEmpty("") makes it emit an empty string instead.
This way we correctly collect the full content of the streaming response and add it to messages.
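The underlying lesson, that the complete text only exists once the stream signals completion, is not specific to Reactor. Here is a JDK-only analogy that uses java.util.concurrent.Flow and SubmissionPublisher in place of Flux; StreamAccumulator is a hypothetical name, and in the real code the completion step is where the AssistantMessage would be added to the history.

```java
import java.util.List;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Flow;
import java.util.concurrent.SubmissionPublisher;

// Hypothetical JDK-only analogy for collecting a streamed reply; Reactor is not used here
class StreamAccumulator {

    /** Publish the chunks asynchronously and complete the future with their concatenation. */
    static CompletableFuture<String> collect(List<String> chunks) {
        CompletableFuture<String> result = new CompletableFuture<>();
        try (SubmissionPublisher<String> publisher = new SubmissionPublisher<>()) {
            publisher.subscribe(new Flow.Subscriber<String>() {
                // One builder shared by all onNext calls for this subscription
                private final StringBuilder sb = new StringBuilder();

                @Override
                public void onSubscribe(Flow.Subscription subscription) {
                    subscription.request(Long.MAX_VALUE);
                }

                @Override
                public void onNext(String chunk) {
                    sb.append(chunk);
                }

                @Override
                public void onError(Throwable error) {
                    result.completeExceptionally(error);
                }

                @Override
                public void onComplete() {
                    // Only here is the full text available; this is where the history would be updated
                    result.complete(sb.toString());
                }
            });
            // SubmissionPublisher rejects null items, so skip chunks without content
            chunks.stream().filter(Objects::nonNull).forEach(publisher::submit);
        } // close() signals onComplete after the buffered items are delivered
        return result;
    }
}
```

Reading the builder right after subscribe would race with delivery, just like the first attempt with Flux; waiting on the future (or acting inside onComplete) avoids that.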

Finally, the task is complete! 😎

Oh, I almost forgot the front end. Since I'm not very familiar with front-end development, I used openui to generate the page and styles, and then had Claude write the request logic. This is what I ended up with:

<!doctype html>
<html>

<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <script src="https://cdn.tailwindcss.com"></script>
</head>

<body class="bg-zinc-100 dark:bg-zinc-800 min-h-screen p-4">
    <div class="flex flex-col h-full">
        <div id="messages" class="flex-1 overflow-y-auto p-4 space-y-4">
            <div class="flex items-end">
                <img src="https://placehold.co/40x40" alt="avatar" class="rounded-full">
                <div class="ml-2 p-2 bg-white dark:bg-zinc-700 rounded-lg max-w-xs">Hi~(⁄ ⁄•⁄ω⁄•⁄ ⁄)⁄</div>
            </div>
        </div>
        <div class="p-2">
            <input type="text" id="messageInput" placeholder="Please enter a message..."
                class="w-full p-2 rounded-lg border-2 border-zinc-300 dark:border-zinc-600 focus:outline-none focus:border-blue-500 dark:focus:border-blue-400">
            <button onclick="sendMessage()"
                class="mt-2 w-full bg-blue-500 hover:bg-blue-600 dark:bg-blue-600 dark:hover:bg-blue-700 text-white p-2 rounded-lg">Send</button>
        </div>
    </div>
    <script>
        let sessionId; // Used to store session ID

        // Send HTTP request and handle response
        function sendHTTPRequest(url, method = 'GET', body = null) {
            return new Promise((resolve, reject) => {
                const xhr = new XMLHttpRequest();
                xhr.open(method, url, true);
                xhr.onload = () => {
                    if (xhr.status >= 200 && xhr.status < 300) {
                        resolve(xhr.response);
                    } else {
                        reject(xhr.statusText);
                    }
                };
                xhr.onerror = () => reject(xhr.statusText);
                if (body) {
                    xhr.setRequestHeader('Content-Type', 'application/json');
                    xhr.send(JSON.stringify(body));
                } else {
                    xhr.send();
                }
            });
        }

        // Handle SSE stream returned by the server
        function handleSSEStream(stream) {
            console.log('Stream started');
            console.log(stream);
            const messagesContainer = document.getElementById('messages');
            const responseDiv = document.createElement('div');
            responseDiv.className = 'flex items-end';
            responseDiv.innerHTML = `
    <img src="https://placehold.co/40x40" alt="avatar" class="rounded-full">
    <div class="ml-2 p-2 bg-white dark:bg-zinc-700 rounded-lg max-w-xs"></div>
  `;
            messagesContainer.appendChild(responseDiv);

            const messageContentDiv = responseDiv.querySelector('div');

            // Listen for 'message' events, triggered when the backend sends new data
            stream.onmessage = function (event) {
                const data = event.data;
                console.log('Received data:', data);
                messageContentDiv.textContent += data;
                messagesContainer.scrollTop = messagesContainer.scrollHeight;
            };
        }

        // Send message
        function sendMessage() {
            const input = document.getElementById('messageInput');
            const message = input.value.trim();
            if (message) {
                const messagesContainer = document.getElementById('messages');
                const newMessageDiv = document.createElement('div');
                newMessageDiv.className = 'flex items-end justify-end';
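                newMessageDiv.innerHTML = `
          <div class="mr-2 p-2 bg-green-200 dark:bg-green-700 rounded-lg max-w-xs"></div>
          <img src="https://placehold.co/40x40" alt="avatar" class="rounded-full">
        `;
                // Set the user's text with textContent so it is displayed as text, not parsed as HTML
                newMessageDiv.querySelector('div').textContent = message;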
                messagesContainer.appendChild(newMessageDiv);
                input.value = '';
                messagesContainer.scrollTop = messagesContainer.scrollHeight;

                // When sending the first message, call init to get a session ID
                // (use the sessionId variable declared above, not this.sessionId)
                if (!sessionId) {
                    console.log('init');
                    sendHTTPRequest(`http://127.0.0.1:8868/pro/init/${encodeURIComponent(message)}`, 'GET')
                        .then(response => {
                            sessionId = response; // Store session ID
                            return handleSSEStream(new EventSource(`http://127.0.0.1:8868/pro/chat/${sessionId}/${encodeURIComponent(message)}`))
                        });

                } else {
                    // Subsequent requests are sent directly to the chat interface
                    handleSSEStream(new EventSource(`http://127.0.0.1:8868/pro/chat/${sessionId}/${encodeURIComponent(message)}`))
                }
            }
        }
    </script>
</body>

</html>

Final Effect#

image

PS: The front end could be improved further, for example by showing past conversations or rendering the output as Markdown. If you're interested, you can use AI tools to modify it.
