- """
- AI服务提供商基类和多平台实现
- """
- from abc import ABC, abstractmethod
- from typing import List, Dict, Optional, AsyncGenerator
- import openai
- import anthropic
- import dashscope
- import httpx
- from app.core.config import settings
class BaseLLMProvider(ABC):
    """Abstract base class for LLM providers."""

    @abstractmethod
    async def chat_completion(
        self,
        messages: List[Dict[str, str]],
        api_key: str,
        model: Optional[str] = None,
        **kwargs,
    ) -> str:
        """Non-streaming chat completion."""
        ...

    @abstractmethod
    async def stream_completion(
        self,
        messages: List[Dict[str, str]],
        api_key: str,
        model: Optional[str] = None,
        **kwargs,
    ) -> AsyncGenerator[str, None]:
        """Streaming chat completion."""
        ...
class OpenAIProvider(BaseLLMProvider):
    """OpenAI provider."""

    async def chat_completion(
        self,
        messages: List[Dict[str, str]],
        api_key: str,
        model: str = "gpt-3.5-turbo",
        **kwargs,
    ) -> str:
        client = openai.AsyncOpenAI(api_key=api_key)
        response = await client.chat.completions.create(
            model=model,
            messages=messages,
            **kwargs,
        )
        return response.choices[0].message.content

    async def stream_completion(
        self,
        messages: List[Dict[str, str]],
        api_key: str,
        model: str = "gpt-3.5-turbo",
        **kwargs,
    ) -> AsyncGenerator[str, None]:
        client = openai.AsyncOpenAI(api_key=api_key)
        stream = await client.chat.completions.create(
            model=model,
            messages=messages,
            stream=True,
            **kwargs,
        )
        async for chunk in stream:
            # Some stream events carry an empty choices list, so guard
            # before indexing.
            if chunk.choices and chunk.choices[0].delta.content:
                yield chunk.choices[0].delta.content
class ClaudeProvider(BaseLLMProvider):
    """Anthropic Claude provider."""

    async def chat_completion(
        self,
        messages: List[Dict[str, str]],
        api_key: str,
        model: str = "claude-3-sonnet-20240229",
        **kwargs,
    ) -> str:
        client = anthropic.AsyncAnthropic(api_key=api_key)
        # The Anthropic API takes the system prompt as a separate parameter,
        # so split it out of the OpenAI-style message list.
        system_msg = next((m["content"] for m in messages if m["role"] == "system"), None)
        user_messages = [m for m in messages if m["role"] != "system"]
        max_tokens = kwargs.pop("max_tokens", 1024)
        response = await client.messages.create(
            model=model,
            messages=user_messages,
            max_tokens=max_tokens,
            # Only pass `system` when one was actually provided.
            **({"system": system_msg} if system_msg else {}),
            **kwargs,
        )
        return response.content[0].text

    async def stream_completion(
        self,
        messages: List[Dict[str, str]],
        api_key: str,
        model: str = "claude-3-sonnet-20240229",
        **kwargs,
    ) -> AsyncGenerator[str, None]:
        client = anthropic.AsyncAnthropic(api_key=api_key)
        system_msg = next((m["content"] for m in messages if m["role"] == "system"), None)
        user_messages = [m for m in messages if m["role"] != "system"]
        max_tokens = kwargs.pop("max_tokens", 1024)
        async with client.messages.stream(
            model=model,
            messages=user_messages,
            max_tokens=max_tokens,
            **({"system": system_msg} if system_msg else {}),
            **kwargs,
        ) as stream:
            async for text in stream.text_stream:
                yield text
class QwenProvider(BaseLLMProvider):
    """Alibaba Tongyi Qianwen (DashScope) provider."""

    async def chat_completion(
        self,
        messages: List[Dict[str, str]],
        api_key: str,
        model: str = "qwen-turbo",
        **kwargs,
    ) -> str:
        dashscope.api_key = api_key
        # dashscope.Generation.call is synchronous (awaiting it directly
        # raises a TypeError), so run it in a worker thread instead.
        response = await asyncio.to_thread(
            dashscope.Generation.call,
            model=model,
            messages=messages,
            result_format="message",
            **kwargs,
        )
        return response.output.choices[0].message.content

    async def stream_completion(
        self,
        messages: List[Dict[str, str]],
        api_key: str,
        model: str = "qwen-turbo",
        **kwargs,
    ) -> AsyncGenerator[str, None]:
        dashscope.api_key = api_key
        responses = dashscope.Generation.call(
            model=model,
            messages=messages,
            result_format="message",
            stream=True,
            # Without incremental_output each chunk repeats the full text so
            # far; with it, each chunk carries only the newly generated delta.
            incremental_output=True,
            **kwargs,
        )
        # Note: this is a blocking iterator, so the event loop stalls between
        # chunks; see the helper sketch below for a non-blocking alternative.
        for response in responses:
            if response.status_code == 200:
                yield response.output.choices[0].message.content
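
# A minimal sketch of the non-blocking alternative mentioned above: pull each
# chunk of a blocking iterator through a worker thread so other coroutines can
# run between chunks. This helper is an illustrative addition, not part of the
# original module.
async def _iter_blocking(sync_iterable) -> AsyncGenerator:
    """Yield items from a blocking iterator without stalling the event loop."""
    iterator = iter(sync_iterable)
    sentinel = object()
    while True:
        item = await asyncio.to_thread(next, iterator, sentinel)
        if item is sentinel:
            return
        yield item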
class ErnieProvider(BaseLLMProvider):
    """Baidu ERNIE Bot (Wenxin Yiyan) provider."""

    async def chat_completion(
        self,
        messages: List[Dict[str, str]],
        api_key: str,
        model: str = "ernie-bot-turbo",
        **kwargs,
    ) -> str:
        # The ERNIE API authenticates with a short-lived access_token
        # exchanged for the API key / secret key pair.
        access_token = await self._get_access_token(api_key)
        async with httpx.AsyncClient() as client:
            response = await client.post(
                # The final path segment must be the endpoint name Baidu's
                # docs assign to the model, which can differ from the
                # marketing model name.
                f"https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/{model}",
                params={"access_token": access_token},
                json={"messages": messages, **kwargs},
            )
            data = response.json()
            return data["result"]

    async def stream_completion(
        self,
        messages: List[Dict[str, str]],
        api_key: str,
        model: str = "ernie-bot-turbo",
        **kwargs,
    ) -> AsyncGenerator[str, None]:
        access_token = await self._get_access_token(api_key)
        async with httpx.AsyncClient() as client:
            async with client.stream(
                "POST",
                f"https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/{model}",
                params={"access_token": access_token},
                json={"messages": messages, "stream": True, **kwargs},
            ) as response:
                # The stream is server-sent events: each payload line is
                # prefixed with "data: " followed by a JSON object.
                async for line in response.aiter_lines():
                    if line.startswith("data: "):
                        data = json.loads(line[6:])
                        if "result" in data:
                            yield data["result"]

    async def _get_access_token(self, api_key: str) -> str:
        """Fetch a Baidu access_token (simplified; production code should cache it)."""
        # Expected api_key format: "api_key:secret_key"
        api_key_part, secret_key = api_key.split(":", 1)
        async with httpx.AsyncClient() as client:
            response = await client.get(
                "https://aip.baidubce.com/oauth/2.0/token",
                params={
                    "grant_type": "client_credentials",
                    "client_id": api_key_part,
                    "client_secret": secret_key,
                },
            )
            return response.json()["access_token"]
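
# A sketch of the caching that _get_access_token's docstring asks for: Baidu
# access tokens are valid for roughly 30 days, so fetching a fresh one on
# every request wastes a round trip. The cache name, TTL, and module-level
# placement below are illustrative assumptions, not part of the original code.
_token_cache: Dict[str, tuple] = {}  # api_key -> (token, expires_at)

async def _get_cached_access_token(provider: ErnieProvider, api_key: str) -> str:
    """Return a cached token, refreshing it when it is close to expiry."""
    import time
    cached = _token_cache.get(api_key)
    if cached is not None and cached[1] > time.time():
        return cached[0]
    token = await provider._get_access_token(api_key)
    # Refresh a day early rather than trusting the token right up to expiry.
    _token_cache[api_key] = (token, time.time() + 29 * 24 * 3600)
    return token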
# Factory function
def get_llm_provider(provider_name: str) -> BaseLLMProvider:
    """Return the provider instance for the given platform name."""
    providers = {
        "openai": OpenAIProvider(),
        "claude": ClaudeProvider(),
        "qwen": QwenProvider(),
        "ernie": ErnieProvider(),
    }
    if provider_name not in providers:
        raise ValueError(f"Unsupported AI platform: {provider_name}")
    return providers[provider_name]
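
# Illustrative usage sketch (the provider name, messages, and placeholder API
# key below are assumptions for demonstration, not real call sites):
if __name__ == "__main__":
    async def _demo() -> None:
        provider = get_llm_provider("openai")
        async for delta in provider.stream_completion(
            messages=[{"role": "user", "content": "Hello"}],
            api_key="sk-...",  # replace with a real key
        ):
            print(delta, end="", flush=True)

    asyncio.run(_demo())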