"""Chat data models and inline-attribution formatters (``bigdata_client.models.chat``)."""

from abc import ABC, abstractmethod
from datetime import datetime, timezone
from typing import Iterator, Optional, Union

from pydantic import BaseModel, Field, computed_field, field_validator

from bigdata_client.connection_protocol import BigdataConnectionProtocol
from bigdata_client.constants import MAX_CHAT_QUESTION_LENGTH, MIN_CHAT_QUESTION_LENGTH
from bigdata_client.enum_utils import StrEnum
from bigdata_client.exceptions import BigdataClientChatInvalidQuestion


[docs] class ChatScope(StrEnum): EARNING_CALLS = "transcripts" FILES = "files" NEWS = "news" REGULATORY_FILINGS = "filings"
class ChatSource(BaseModel):
    """Represents a source in a chat message.

    Attributes:
        id: Identifier of the source.
        headline: Headline of the source; used as link text by
            ``MarkdownLinkFormatter``.
        url: Link to the source, or ``None`` when unavailable.
        document_scope: Scope of the source document, or ``None``.
        rp_provider_id: Provider identifier, or ``None``.
    """

    id: str
    headline: str
    # These are Optional but carry no default, so (pydantic v2 semantics)
    # they must still be supplied explicitly; ``None`` is an accepted value.
    url: Optional[str]
    document_scope: Optional[str]
    rp_provider_id: Optional[str]
class InlineAttributionFormatter(ABC):
    """Abstract interface for rendering inline attributions in chat messages.

    Concrete subclasses decide which text stands in for each cited source.
    """

    @abstractmethod
    def format(self, index: int, source: ChatSource) -> str:
        """Render a single inline attribution.

        Args:
            index (int): Position of this attribution in the attribution list.
            source (ChatSource): The source being attributed.

        Returns:
            str: Text representing the attribution.
        """
class DefaultFormatter(InlineAttributionFormatter):
    """Default formatter: renders each attribution as a numbered reference marker."""

    def format(self, index: int, source: ChatSource) -> str:
        """Render an inline attribution in the default reference style.

        Args:
            index (int): Position of this attribution in the attribution list.
            source (ChatSource): The attributed source (not used by this
                formatter; only the index is rendered).

        Returns:
            str: The marker ``:ref[<index>]`` in backticks, followed by a space.
        """
        marker = f"`:ref[{index}]`"
        return marker + " "
class MarkdownLinkFormatter(InlineAttributionFormatter):
    """Formatter for inline attributions in chat messages that uses Markdown links."""

    def __init__(
        self, headline_length: Optional[int] = None, skip_empty_urls: bool = True
    ):
        """
        Initialize the MarkdownLinkFormatter.

        Args:
            headline_length (Optional[int]): Maximum number of headline
                characters used as link text. ``None`` (the default) or ``0``
                keeps the whole headline.
            skip_empty_urls (bool): If True (the default), a source without a
                URL produces no output. If False, it is still rendered, as a
                Markdown link with an empty target.
        """
        self.headline_length = headline_length
        self.skip_empty_urls = skip_empty_urls

    def format(self, index: int, source: ChatSource) -> str:
        """
        Format an inline attribution as a Markdown link.

        Args:
            index (int): The index of the attribution within the list of
                attributions (not used by this formatter).
            source (ChatSource): The inline attribution to format.

        Returns:
            str: ``"[headline](url) "``, or ``""`` when the source has no URL
            and ``skip_empty_urls`` is enabled.
        """
        headline = source.headline
        if self.headline_length:
            headline = headline[: self.headline_length]
        url = source.url or ""
        # Previously the skip happened unconditionally, ignoring the
        # constructor flag; now it honours ``skip_empty_urls``.
        if not url and self.skip_empty_urls:
            return ""
        return f"[{headline}]({url}) "
class ChatInteraction(BaseModel):
    """Represents a single interaction with chat.

    Attributes:
        question: The question that was asked.
        answer: The answer text, with inline references already rendered.
        interaction_timestamp: Timestamp of the interaction, kept as a string.
        date_created: When this interaction object was created.
        last_updated: When this interaction object was last updated.
        scope: Scope the question was asked with; unrecognised scope strings
            are stored as ``None``.
        sources: Sources cited in the answer.
    """

    question: str
    answer: str
    interaction_timestamp: str
    date_created: datetime
    last_updated: datetime
    scope: Optional[ChatScope] = None
    sources: list[ChatSource] = Field(default_factory=list)

    @field_validator("scope", mode="before")
    @classmethod
    def validate_scope(cls, value):
        """Coerce a raw scope string into a ChatScope before validation.

        Strings that are not a valid ChatScope value become ``None`` instead of
        failing validation; non-string inputs are passed through unchanged.
        """
        if not isinstance(value, str):
            return value
        try:
            return ChatScope(value)
        except ValueError:
            return None
def handle_ws_chat_response(
    chat, question, scope, complete_message, sources, formatter
):
    """Build a ChatInteraction from a completed chat response and record it.

    Args:
        chat: The Chat the interaction belongs to; the new interaction is
            appended to its ``_interactions`` list.
        question: The question that was asked.
        scope: Scope the question was asked with (a ChatScope, its string
            value, or None).
        complete_message: The final response message. Its
            ``content_block["value"]`` holds the raw answer (``""`` if absent)
            and its ``interaction_timestamp`` is copied onto the interaction.
        sources: ChatSource objects collected from the response, used to
            resolve references in the answer.
        formatter: InlineAttributionFormatter used to render those references.

    Returns:
        ChatInteraction: The newly created (and recorded) interaction.
    """
    # Local import, as in the original (presumably avoids a circular import).
    from bigdata_client.api.chat import ChatInteraction as ApiChatInteraction

    raw_answer = complete_message.content_block.get("value", "")
    parsed_answer = ApiChatInteraction._parse_references(
        raw_answer, sources, formatter
    )
    # Same naive UTC value as the deprecated ``datetime.utcnow()``.
    now = datetime.now(timezone.utc).replace(tzinfo=None)
    interaction = ChatInteraction(
        question=question,
        answer=parsed_answer,
        interaction_timestamp=complete_message.interaction_timestamp,
        date_created=now,
        last_updated=now,
        scope=scope,
        sources=sources,
    )
    chat._interactions.append(interaction)
    return interaction
class StreamingChatInteraction(ChatInteraction):
    """
    Represents a streaming interaction in a chat session.

    Iterating over an instance sends the question through the chat's API
    connection and yields the answer chunk by chunk. Text chunks are returned
    with inline references rendered by the formatter; every other message
    yields ``""``. When the complete message arrives, this object's fields are
    refreshed from the finished interaction, which is also recorded in the chat.
    """

    _chat: "Chat"
    _formatter: InlineAttributionFormatter
    _response: Optional[Iterator] = None

    def __init__(self, _chat, _formatter, **values):
        """Create a streaming interaction.

        Args:
            _chat: The Chat whose API connection is used and to which the
                finished interaction is appended.
            _formatter: InlineAttributionFormatter used to render references.
            **values: Regular ChatInteraction fields.
        """
        super().__init__(**values)
        self._chat = _chat
        self._formatter = _formatter
        self._response = None

    def __iter__(self):
        """Send the question and return ``self`` as the iterator.

        Every call starts a new request to the chat service.
        """
        self._response = self._chat._api_connection.ask_chat(
            self._chat.id, self.question, scope=self.scope
        )
        return self

    def __next__(self):
        """Return the next piece of answer text.

        Returns:
            str: A text chunk with references rendered, or ``""`` for
            non-text messages.

        Raises:
            StopIteration: When the underlying response stream is exhausted.
        """
        from bigdata_client.api.chat import ChatInteraction as ApiChatInteraction
        from bigdata_client.api.chat import (
            ChatWSAuditTraceResponse,
            ChatWSCompleteResponse,
        )

        if self._response is None:
            # next() was called without iter(): start the request lazily
            # instead of failing on ``next(None)``.
            self.__iter__()

        # StopIteration from the underlying stream propagates and ends iteration.
        chunk = next(self._response)

        if isinstance(chunk, str):
            return ApiChatInteraction._parse_references(
                chunk, self.sources, self._formatter
            )

        if isinstance(chunk, ChatWSCompleteResponse):
            interaction = handle_ws_chat_response(
                chat=self._chat,
                question=self.question,
                scope=self.scope,
                complete_message=chunk,
                sources=self.sources,
                formatter=self._formatter,
            )
            # Mirror the finished interaction onto this streaming object.
            self.question = interaction.question
            self.answer = interaction.answer
            self.interaction_timestamp = interaction.interaction_timestamp
            self.date_created = interaction.date_created
            self.last_updated = interaction.last_updated
            self.scope = interaction.scope
            self.sources = interaction.sources
            return ""

        if isinstance(chunk, ChatWSAuditTraceResponse):
            # Collect sources so later text chunks can resolve references.
            self.sources.extend(chunk.to_chat_source())

        return ""
class Chat(BaseModel):
    """A chat session accessed through a Bigdata API connection.

    Interactions are fetched from the server the first time ``interactions``
    is read, unless the chat was constructed with ``_loaded=True``.
    """

    id: str
    name: str
    user_id: str
    date_created: datetime
    last_updated: datetime

    @computed_field
    @property
    def interactions(self) -> list[ChatInteraction]:
        """Interactions of this chat, reloaded from the server if not yet loaded."""
        if not self._loaded:
            self.reload_from_server()
        return self._interactions

    _api_connection: BigdataConnectionProtocol
    _interactions: list[ChatInteraction]
    _formatter: InlineAttributionFormatter
    _loaded: bool

    def __init__(
        self,
        _api_connection: BigdataConnectionProtocol,
        _interactions: Optional[list[ChatInteraction]] = None,
        _formatter: Optional[InlineAttributionFormatter] = None,
        _loaded: bool = False,
        **values,
    ):
        """Create a chat bound to an API connection.

        Args:
            _api_connection: Connection used for all server calls.
            _interactions: Known interactions; defaults to an empty list.
            _formatter: Default formatter for answers; ``DefaultFormatter()``
                when omitted.
            _loaded: Whether ``_interactions`` already reflects the server.
            **values: The model fields (id, name, user_id, ...).
        """
        super().__init__(**values)
        self._api_connection = _api_connection
        self._loaded = _loaded
        # Always initialise the list: ask() appends to it even before the
        # chat has been loaded, which previously raised AttributeError.
        self._interactions = _interactions if _interactions is not None else []
        self._formatter = _formatter or DefaultFormatter()

    def ask(
        self,
        question: str,
        *,
        scope: Optional[ChatScope] = None,
        formatter: Optional[InlineAttributionFormatter] = None,
        streaming: bool = False,
    ) -> Union[ChatInteraction, StreamingChatInteraction]:
        """Ask a question in the chat.

        Args:
            question: The question text.
            scope: Optional scope restricting the question.
            formatter: Formatter for inline references; defaults to the
                chat's formatter.
            streaming: If True, return a StreamingChatInteraction that sends
                the request when iterated; otherwise block until the answer
                is complete.

        Returns:
            The recorded ChatInteraction, or a StreamingChatInteraction.

        Raises:
            BigdataClientChatInvalidQuestion: If the question length is
                outside the allowed range.
            RuntimeError: If the response ends without a complete answer.
        """
        self._validate_question(question)
        formatter = formatter or self._formatter
        chat_scope = scope.value if scope else None

        if streaming:
            # Same naive UTC value as the deprecated ``datetime.utcnow()``.
            now = datetime.now(timezone.utc).replace(tzinfo=None)
            return StreamingChatInteraction(
                question=question,
                answer="",
                interaction_timestamp=now.isoformat() + "Z",
                date_created=now,
                last_updated=now,
                scope=chat_scope,
                _chat=self,
                _formatter=formatter,
            )

        response = self._api_connection.ask_chat(self.id, question, scope=chat_scope)

        from bigdata_client.api.chat import (
            ChatWSAuditTraceResponse,
            ChatWSCompleteResponse,
        )

        complete_message = None
        sources = []
        for chunk in response:
            # Plain-text chunks are ignored: the complete message carries the
            # full answer.
            if isinstance(chunk, ChatWSCompleteResponse):
                complete_message = chunk
            elif isinstance(chunk, ChatWSAuditTraceResponse):
                sources.extend(chunk.to_chat_source())

        if complete_message is None:
            raise RuntimeError(
                "Chat response ended without a complete answer message"
            )

        return handle_ws_chat_response(
            chat=self,
            question=question,
            scope=scope,
            complete_message=complete_message,
            sources=sources,
            formatter=formatter,
        )

    def reload_from_server(self):
        """Refresh this chat's metadata and interactions from the server."""
        chat = self._api_connection.get_chat(self.id).to_chat_model(
            self._api_connection, self._formatter
        )
        self.name = chat.name
        self.user_id = chat.user_id
        self.date_created = chat.date_created
        self.last_updated = chat.last_updated
        self._interactions = chat._interactions
        self._loaded = True

    def delete(self):
        """Delete the chat"""
        self._api_connection.delete_chat(self.id)

    @staticmethod
    def _validate_question(question: str):
        """Raise BigdataClientChatInvalidQuestion unless the length is in range.

        Both bounds are exclusive; ``None`` is treated as length 0.
        """
        message_length = len(question or "")
        if not (MIN_CHAT_QUESTION_LENGTH < message_length < MAX_CHAT_QUESTION_LENGTH):
            raise BigdataClientChatInvalidQuestion(message_length)