Spring AI Java步伐员的AI之Spring AI(二)

打印 上一主题 下一主题

主题 1520|帖子 1520|积分 4560

历史Spring AI文章

Spring AI Java步伐员的AI之Spring AI(一)
一丶Spring AI 函数调用

定义工具函数Function

在Spring AI中,如果一个Bean实现了Function接口,那么它就是一个工具函数,并且通过@Description注解可以形貌该工具的作用是什么,如果工具有需要吸收参数,也可以通过@Schema注解来对参数进行定义,好比以下工具是用来获取指定地点的当前时间的,并且address参数用来吸收详细的地点:
  1. package com.qjc.demo.service;
  2. import io.swagger.v3.oas.annotations.media.Schema;
  3. import org.springframework.context.annotation.Description;
  4. import org.springframework.stereotype.Component;
  5. import java.time.LocalDateTime;
  6. import java.util.function.Function;
  7. /***
  8. * @projectName spring-ai-demo
  9. * @packageName com.qjc.demo.service
  10. * @author qjc
  11. * @description TODO
  12. * @Email qjc1024@aliyun.com
  13. * @date 2024-10-16 09:50
  14. **/
  15. @Component
  16. @Description("获取指定地点的当前时间")
  17. public class DateService implements Function<DateService.Request, DateService.Response> {
  18.     public record Request(@Schema(description = "地点") String address) { }
  19.     public record Response(String date) { }
  20.     @Override
  21.     public Response apply(Request request) {
  22.         System.out.println(request.address);
  23.         return new Response(String.format("%s的当前时间是%s", request.address, LocalDateTime.now()));
  24.     }
  25. }
复制代码
工具函数调用

当向大模型提问时,需要指定所要调用的工具函数,使用OpenAiChatOptions指定对应的beanName就可以了,好比:
  1. @GetMapping("/function")
  2. public String function(@RequestParam String message) {
  3.     Prompt prompt = new Prompt(message, OpenAiChatOptions.builder().withFunction("dateService").build());
  4.     Generation generation = chatClient.call(prompt).getResult();
  5.     return (generation != null) ? generation.getOutput().getContent() : "";
  6. }
复制代码
FunctionCallback工具函数

还可以直接在提问时直接定义并调用工具,好比:
  1. @GetMapping("/functionCallback")
  2. public String functionCallback(@RequestParam String message) {
  3.     Prompt prompt = new Prompt(message, OpenAiChatOptions.builder().withFunctionCallbacks(
  4.         List.of(FunctionCallbackWrapper.builder(new DateService())
  5.                 .withName("dateService")
  6.                 .withDescription("获取指定地点的当前时间").build())
  7.     ).build());
  8.     Generation generation = chatClient.call(prompt).getResult();
  9.     return (generation != null) ? generation.getOutput().getContent() : "";
  10. }
复制代码
通过这种方式,就不需要将DateService定义为Bean了,固然这样定义的工具只能functionCallback接口单独使用了,而定义Bean则可以让多个接口共享使用。
不过有时候,大模型给你的答案或工具参数可能是英文的
那么可以使用SystemMessage来设置系统提示词,好比:
  1. @GetMapping("/functionCallback")
  2. public String functionCallback(@RequestParam String message) {
  3.     SystemMessage systemMessage = new SystemMessage("请用中文回答我");
  4.     UserMessage userMessage = new UserMessage(message);
  5.     Prompt prompt = new Prompt(List.of(systemMessage, userMessage), OpenAiChatOptions.builder().withFunctionCallbacks(
  6.         List.of(FunctionCallbackWrapper.builder(new DateService())
  7.                 .withName("dateService")
  8.                 .withDescription("获取指定地点的当前时间").build())
  9.     ).build());
  10.     Generation generation = chatClient.call(prompt).getResult();
  11.     return (generation != null) ? generation.getOutput().getContent() : "";
  12. }
复制代码
这样就能控制答案了。
二丶 Spring AI 函数调用源码解析

在OpenAiChatClient的call()方法中,会进行:

  • 哀求的处理
  • 工具的调用
  • 相应的处理
  • 重试机制
好比call()方法的大要代码为:
  1. @Override
  2. public ChatResponse call(Prompt prompt) {
  3.     // 请求处理
  4.     ChatCompletionRequest request = createRequest(prompt, false);
  5.     // 重试机制
  6.     return this.retryTemplate.execute(ctx -> {
  7.    
  8.         // 请求调用
  9.         ResponseEntity<ChatCompletion> completionEntity = this.callWithFunctionSupport(request);
  10.    
  11.         // 返回响应
  12.         return new ChatResponse(...);
  13.     });
  14. }
复制代码
哀求处理

哀求处理焦点是把Prompt对象转换成ChatCompletionRequest对象,包罗Prompt中设置的SystemMessage、UserMessage和工具函数。
如果接纳Bean的方式来使用工具函数,其底层实在对应的仍然是FunctionCallback,在OpenAiAutoConfiguration主动设置中,定义了一个FunctionCallbackContext的Bean,该Bean提供了一个getFunctionCallback()方法,用来生成beanName对应的FunctionCallback对象,源码为:
  1. public FunctionCallback getFunctionCallback(@NonNull String beanName, @Nullable String defaultDescription) {
  2.     // 获取Bean类型
  3.     Type beanType = FunctionContextUtils.findType(this.applicationContext.getBeanFactory(), beanName);
  4.     if (beanType == null) {
  5.         throw new IllegalArgumentException(
  6.             "Functional bean with name: " + beanName + " does not exist in the context.");
  7.     }
  8.     // Bean类型必须是Function类型
  9.     if (!Function.class.isAssignableFrom(FunctionTypeUtils.getRawType(beanType))) {
  10.         throw new IllegalArgumentException(
  11.             "Function call Bean must be of type Function. Found: " + beanType.getTypeName());
  12.     }
  13.     // 获取Function的第一个泛型的类型,比如Request
  14.     Type functionInputType = TypeResolverHelper.getFunctionArgumentType(beanType, 0);
  15.     Class<?> functionInputClass = FunctionTypeUtils.getRawType(functionInputType);
  16.    
  17.     String functionName = beanName;
  18.     String functionDescription = defaultDescription;
  19.     if (!StringUtils.hasText(functionDescription)) {
  20.         // 获取@Description设置的描述信息
  21.         // Look for a Description annotation on the bean
  22.         Description descriptionAnnotation = applicationContext.findAnnotationOnBean(beanName, Description.class);
  23.         if (descriptionAnnotation != null) {
  24.             functionDescription = descriptionAnnotation.value();
  25.         }
  26.         // 获取Request参数前的@JsonClassDescription设置的描述信息
  27.         if (!StringUtils.hasText(functionDescription)) {
  28.             // Look for a JsonClassDescription annotation on the input class
  29.             JsonClassDescription jsonClassDescriptionAnnotation = functionInputClass
  30.             .getAnnotation(JsonClassDescription.class);
  31.             if (jsonClassDescriptionAnnotation != null) {
  32.                 functionDescription = jsonClassDescriptionAnnotation.value();
  33.             }
  34.         }
  35.         if (!StringUtils.hasText(functionDescription)) {
  36.             throw new IllegalStateException("Could not determine function description."
  37.                                             + "Please provide a description either as a default parameter, via @Description annotation on the bean "
  38.                                             + "or @JsonClassDescription annotation on the input class.");
  39.         }
  40.     }
  41.     // 获取Bean对象
  42.     Object bean = this.applicationContext.getBean(beanName);
  43.     // 构建为FunctionCallback对象
  44.     if (bean instanceof Function<?, ?> function) {
  45.         return FunctionCallbackWrapper.builder(function)
  46.         .withName(functionName)
  47.         .withSchemaType(this.schemaType)
  48.         .withDescription(functionDescription)
  49.         .withInputType(functionInputClass)
  50.         .build();
  51.     }
  52.     else {
  53.         throw new IllegalArgumentException("Bean must be of type Function");
  54.     }
  55. }
复制代码
以上代码的焦点逻辑为:

  • 获取Bean类型
  • 获取Function的第一个泛型的类型,好比Request
  • 获取@Description设置的形貌信息
  • 构造FunctionCallback对象
在OpenAiChatClient就会注入FunctionCallbackContext这个Bean对象,从而使得OpenAiChatClient可以通过Prompt中指定的beanName获取到对应的FunctionCallback对象。
以是,在createRequest()方法中,就可以得到从FunctionCallbackContext找到的或者直接在Prompt对象中设置的FunctionCallback对象,然后将FunctionCallback对象转成OpenAiApi.FunctionTool对象,最终将FunctionTool设置到ChatCompletionRequest中。
哀求调用

哀求调用源码如下:
  1. protected Resp callWithFunctionSupport(Req request) {
  2.     Resp response = this.doChatCompletion(request);
  3.     return this.handleFunctionCallOrReturn(request, response);
  4. }
复制代码

  • 先发送哀求得到相应
  • 解析相应是否需要调用工具还是直接返回
doChatCompletion()方法比较简单,就是直接把哀求发送给OpenAi,紧张的是handleFunctionCallOrReturn()方法。
handleFunctionCallOrReturn()方法需要解析相应,好比判定OpenAi返回的相应中是否需要调用工具,好比:
  1. if (!this.isToolFunctionCall(response)) {
  2.     return response;
  3. }
复制代码
OpenAi中,如果一个相应的finishReason为TOOL_CALLS则表示,当前相应实在是OpenAi的一个工具调用哀求。
然后就去执行工具:
  1. protected ChatCompletionRequest doCreateToolResponseRequest(ChatCompletionRequest previousRequest,
  2.                                                             ChatCompletionMessage responseMessage, List<ChatCompletionMessage> conversationHistory) {
  3.     // 遍历每个要调用的工具
  4.     // Every tool-call item requires a separate function call and a response (TOOL)
  5.     // message.
  6.     for (ToolCall toolCall : responseMessage.toolCalls()) {
  7.         // 工具名和参数
  8.         var functionName = toolCall.function().name();
  9.         String functionArguments = toolCall.function().arguments();
  10.         if (!this.functionCallbackRegister.containsKey(functionName)) {
  11.             throw new IllegalStateException("No function callback found for function name: " + functionName);
  12.         }
  13.         // 找到FunctionCallback并进行调用,得到工具执行结果
  14.         String functionResponse = this.functionCallbackRegister.get(functionName)
  15.             .call(functionArguments);
  16.         // 将工具执行结果添加到对话历史
  17.         // Add the function response to the conversation.
  18.         conversationHistory
  19.         .add(new ChatCompletionMessage(functionResponse, Role.TOOL, functionName, toolCall.id(), null));
  20.     }
  21.     // 构造新的请求,将工具执行结果传递给OpenAi
  22.     // Recursively call chatCompletionWithTools until the model doesn't call a
  23.     // functions anymore.
  24.     ChatCompletionRequest newRequest = new ChatCompletionRequest(conversationHistory, false);
  25.     newRequest = ModelOptionsUtils.merge(newRequest, previousRequest, ChatCompletionRequest.class);
  26.     return newRequest;
  27. }
复制代码
以上源码的焦点流程为:

  • 遍历每个要调用的工具
  • 根据工具名找到FunctionCallback并进行调用,得到工具执行结果
  • 将工具执行结果添加到对话历史
  • 构造新的哀求,将工具执行结果转达给OpenAi
得到新的哀求对象后,又会调用callWithFunctionSupport()方法,以是这里出现了递归调用。
函数调用

当调用FunctionCallback的call方法时,就是在执行函数调用:
  1. @Override
  2. public String call(String functionArguments) {
  3.     // 将OpenAi给的请求参数转成指定类,比如Request
  4.     // Convert the tool calls JSON arguments into a Java function request object.
  5.     I request = fromJson(functionArguments, inputType);
  6.     // 然后执行apply方法
  7.     // extend conversation with function response.
  8.     return this.andThen(this.responseConverter).apply(request);
  9. }
复制代码
从这里可以发现,对于工具执行结果,还可以设置responseConverter来进行处理,好比:
  1. @GetMapping("/functionCallback")
  2. public String functionCallback(@RequestParam String message) {
  3.     SystemMessage systemMessage = new SystemMessage("请用中文回答我");
  4.     UserMessage userMessage = new UserMessage(message);
  5.     Prompt prompt = new Prompt(List.of(systemMessage, userMessage), OpenAiChatOptions.builder().withFunctionCallbacks(
  6.         List.of(FunctionCallbackWrapper.builder(new DateService())
  7.                 .withName("dateService")
  8.                 .withDescription("获取指定地点的当前时间")
  9.                 .withResponseConverter(response -> "2024年10月16日09:22")
  10.                 .build())
  11.     ).build());
  12.     Generation generation = chatClient.call(prompt).getResult();
  13.     return (generation != null) ? generation.getOutput().getContent() : "";
  14. }
复制代码
这样做,最终函数执行结果被我固定成了"2024年10月16日09:22,因此最终OpenAi给我答案也是
交互流程图


为什么OpenAiChatClient不在当地先直接执行工具,然后再哀求OpenAiServer呢?
以上场景比较简单,实际上的思想是:把OpenAi当做一个大脑,通过第一次哀求告诉OpenAi我的需求使命,以及我们提供了哪些工具,然后由OpenAi:

  • 先理解使命
  • 然后制定策略,也就是OpenAi要完成使命,需要调用哪些工具,并且调用这些工具的详细参数是什么,调用工具的次序是什么,这些都由OpenAi来进行分析
  • 然后OpenAi就向OpenAiChatClient发送工具调用哀求,并得到工具执行结果
  • 然后OpenAi再基于使命和工具执行结果进行分析,看是否能完成使命了,还是需要继续调用工具。
  • 如果能完成使命了,那就直接把使命的执行结果返回给OpenAiChatClient。
三丶 案例

需求:获取今天注册的新用户信息。
定义获取当前时间工具:
  1. package com.qjc.demo.service;
  2. import io.swagger.v3.oas.annotations.media.Schema;
  3. import org.springframework.context.annotation.Description;
  4. import org.springframework.stereotype.Component;
  5. import java.time.LocalDateTime;
  6. import java.util.function.Function;
  7. /***
  8. * @projectName spring-ai-demo
  9. * @packageName com.qjc.demo.service
  10. * @author qjc
  11. * @description TODO
  12. * @Email qjc1024@aliyun.com
  13. * @date 2024-10-16 10:01
  14. **/
  15. @Component
  16. @Description("获取当前时间")
  17. public class DateService implements Function<DateService.Request, String> {
  18.     public record Request(String noUse) { }
  19.     @Override
  20.     public String apply(Request request) {
  21.         System.out.println("执行DateService工具");
  22.         return LocalDateTime.now().toString();
  23.     }
  24. }
复制代码
定义获取用户信息服务:
  1. package com.qjc.demo.service;
  2. import org.springframework.context.annotation.Description;
  3. import org.springframework.stereotype.Component;
  4. import java.util.List;
  5. import java.util.function.Function;
  6. /***
  7. * @projectName spring-ai-demo
  8. * @packageName com.qjc.demo.service
  9. * @author qjc
  10. * @description TODO
  11. * @Email qjc1024@aliyun.com
  12. * @date 2024-10-16 10:05
  13. **/
  14. @Component
  15. @Description("获取指定时间的注册用户")
  16. public class UserService implements Function<UserService.Request, List<UserService.User>> {
  17.     public record Request(String date) { }
  18.     @Override
  19.     public List<User> apply(Request request) {
  20.         System.out.println("执行OrderService工具, 入参为:" + request.date);
  21.         return List.of(new User("小齐", "2024年10月16号"), new User("宇将军", "2024年10月16号"));
  22.     }
  23.     class User {
  24.         private String username;
  25.         private String registrationDate;
  26.         public User(String username, String registrationDate) {
  27.             this.username = username;
  28.             this.registrationDate = registrationDate;
  29.         }
  30.         public String getUsername() {
  31.             return username;
  32.         }
  33.         public void setUsername(String username) {
  34.             this.username = username;
  35.         }
  36.         public String getRegistrationDate() {
  37.             return registrationDate;
  38.         }
  39.         public void setRegistrationDate(String registrationDate) {
  40.             this.registrationDate = registrationDate;
  41.         }
  42.     }
  43. }
复制代码
定义哀求接口:
  1. @GetMapping("/user")
  2. public String user(@RequestParam String message) {
  3.     SystemMessage systemMessage = new SystemMessage("将结果按JSON格式返回");
  4.     UserMessage userMessage = new UserMessage(message);
  5.     Prompt prompt = new Prompt(List.of(systemMessage, userMessage), OpenAiChatOptions.builder()
  6.                                .withFunctions(Set.of("dateService", "userService"))
  7.                                .build());
  8.     Generation generation = chatClient.call(prompt).getResult();
  9.     return (generation != null) ? generation.getOutput().getContent() : "";
  10. }
复制代码
总结

我感觉非常爽啊,我的大Spring 函数调用,没有冗余代码。

免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。

本帖子中包含更多资源

您需要 登录 才可以下载或查看,没有账号?立即注册

x
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

干翻全岛蛙蛙

论坛元老
这个人很懒什么都没写!
快速回复 返回顶部 返回列表