"""
AI service provider base class and multi-platform implementations.
"""
from abc import ABC, abstractmethod
from typing import AsyncGenerator, Dict, List, Optional
import json

import anthropic
import dashscope
import httpx
import openai

from app.core.config import settings
  11. class BaseLLMProvider(ABC):
  12. """AI服务提供商抽象基类"""
  13. @abstractmethod
  14. async def chat_completion(
  15. self,
  16. messages: List[Dict[str, str]],
  17. api_key: str,
  18. model: str = None,
  19. **kwargs
  20. ) -> str:
  21. """同步对话补全"""
  22. pass
  23. @abstractmethod
  24. async def stream_completion(
  25. self,
  26. messages: List[Dict[str, str]],
  27. api_key: str,
  28. model: str = None,
  29. **kwargs
  30. ) -> AsyncGenerator[str, None]:
  31. """流式对话补全"""
  32. pass
  33. class OpenAIProvider(BaseLLMProvider):
  34. """OpenAI服务提供商"""
  35. async def chat_completion(
  36. self,
  37. messages: List[Dict[str, str]],
  38. api_key: str,
  39. model: str = "gpt-3.5-turbo",
  40. **kwargs
  41. ) -> str:
  42. client = openai.AsyncOpenAI(api_key=api_key)
  43. response = await client.chat.completions.create(
  44. model=model,
  45. messages=messages,
  46. **kwargs
  47. )
  48. return response.choices[0].message.content
  49. async def stream_completion(
  50. self,
  51. messages: List[Dict[str, str]],
  52. api_key: str,
  53. model: str = "gpt-3.5-turbo",
  54. **kwargs
  55. ) -> AsyncGenerator[str, None]:
  56. client = openai.AsyncOpenAI(api_key=api_key)
  57. stream = await client.chat.completions.create(
  58. model=model,
  59. messages=messages,
  60. stream=True,
  61. **kwargs
  62. )
  63. async for chunk in stream:
  64. if chunk.choices[0].delta.content:
  65. yield chunk.choices[0].delta.content
  66. class ClaudeProvider(BaseLLMProvider):
  67. """Claude服务提供商"""
  68. async def chat_completion(
  69. self,
  70. messages: List[Dict[str, str]],
  71. api_key: str,
  72. model: str = "claude-3-sonnet-20240229",
  73. **kwargs
  74. ) -> str:
  75. client = anthropic.AsyncAnthropic(api_key=api_key)
  76. # 提取system消息
  77. system_msg = next((m["content"] for m in messages if m["role"] == "system"), None)
  78. user_messages = [m for m in messages if m["role"] != "system"]
  79. response = await client.messages.create(
  80. model=model,
  81. system=system_msg,
  82. messages=user_messages,
  83. max_tokens=kwargs.get("max_tokens", 1024),
  84. **{k: v for k, v in kwargs.items() if k != "max_tokens"}
  85. )
  86. return response.content[0].text
  87. async def stream_completion(
  88. self,
  89. messages: List[Dict[str, str]],
  90. api_key: str,
  91. model: str = "claude-3-sonnet-20240229",
  92. **kwargs
  93. ) -> AsyncGenerator[str, None]:
  94. client = anthropic.AsyncAnthropic(api_key=api_key)
  95. system_msg = next((m["content"] for m in messages if m["role"] == "system"), None)
  96. user_messages = [m for m in messages if m["role"] != "system"]
  97. async with client.messages.stream(
  98. model=model,
  99. system=system_msg,
  100. messages=user_messages,
  101. max_tokens=kwargs.get("max_tokens", 1024),
  102. ) as stream:
  103. async for text in stream.text_stream:
  104. yield text
  105. class QwenProvider(BaseLLMProvider):
  106. """通义千问服务提供商"""
  107. async def chat_completion(
  108. self,
  109. messages: List[Dict[str, str]],
  110. api_key: str,
  111. model: str = "qwen-turbo",
  112. **kwargs
  113. ) -> str:
  114. dashscope.api_key = api_key
  115. response = await dashscope.Generation.call(
  116. model=model,
  117. messages=messages,
  118. result_format="message",
  119. **kwargs
  120. )
  121. return response.output.choices[0].message.content
  122. async def stream_completion(
  123. self,
  124. messages: List[Dict[str, str]],
  125. api_key: str,
  126. model: str = "qwen-turbo",
  127. **kwargs
  128. ) -> AsyncGenerator[str, None]:
  129. dashscope.api_key = api_key
  130. responses = dashscope.Generation.call(
  131. model=model,
  132. messages=messages,
  133. result_format="message",
  134. stream=True,
  135. **kwargs
  136. )
  137. for response in responses:
  138. if response.status_code == 200:
  139. yield response.output.choices[0].message.content
  140. class ErnieProvider(BaseLLMProvider):
  141. """文心一言服务提供商"""
  142. async def chat_completion(
  143. self,
  144. messages: List[Dict[str, str]],
  145. api_key: str,
  146. model: str = "ernie-bot-turbo",
  147. **kwargs
  148. ) -> str:
  149. # 文心一言需要先获取access_token
  150. access_token = await self._get_access_token(api_key)
  151. async with httpx.AsyncClient() as client:
  152. response = await client.post(
  153. f"https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/{model}",
  154. params={"access_token": access_token},
  155. json={"messages": messages, **kwargs}
  156. )
  157. data = response.json()
  158. return data["result"]
  159. async def stream_completion(
  160. self,
  161. messages: List[Dict[str, str]],
  162. api_key: str,
  163. model: str = "ernie-bot-turbo",
  164. **kwargs
  165. ) -> AsyncGenerator[str, None]:
  166. access_token = await self._get_access_token(api_key)
  167. async with httpx.AsyncClient() as client:
  168. async with client.stream(
  169. "POST",
  170. f"https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/{model}",
  171. params={"access_token": access_token},
  172. json={"messages": messages, "stream": True, **kwargs}
  173. ) as response:
  174. async for line in response.aiter_lines():
  175. if line.startswith("data: "):
  176. import json
  177. data = json.loads(line[6:])
  178. if "result" in data:
  179. yield data["result"]
  180. async def _get_access_token(self, api_key: str) -> str:
  181. """获取百度access_token(简化实现,实际需要缓存)"""
  182. # api_key格式: "api_key:secret_key"
  183. api_key_part, secret_key = api_key.split(":")
  184. async with httpx.AsyncClient() as client:
  185. response = await client.get(
  186. "https://aip.baidubce.com/oauth/2.0/token",
  187. params={
  188. "grant_type": "client_credentials",
  189. "client_id": api_key_part,
  190. "client_secret": secret_key
  191. }
  192. )
  193. return response.json()["access_token"]
  194. # 工厂函数
  195. def get_llm_provider(provider_name: str) -> BaseLLMProvider:
  196. """获取AI服务提供商实例"""
  197. providers = {
  198. "openai": OpenAIProvider(),
  199. "claude": ClaudeProvider(),
  200. "qwen": QwenProvider(),
  201. "ernie": ErnieProvider(),
  202. }
  203. if provider_name not in providers:
  204. raise ValueError(f"不支持的AI平台: {provider_name}")
  205. return providers[provider_name]