from abc import ABC, abstractmethod
from datetime import datetime, timezone
from typing import Iterator, Optional, Union

from pydantic import BaseModel, Field, computed_field, field_validator

from bigdata_client.connection_protocol import BigdataConnectionProtocol
from bigdata_client.constants import MAX_CHAT_QUESTION_LENGTH, MIN_CHAT_QUESTION_LENGTH
from bigdata_client.enum_utils import StrEnum
from bigdata_client.exceptions import BigdataClientChatInvalidQuestion
[docs]
class ChatScope(StrEnum):
EARNING_CALLS = "transcripts"
FILES = "files"
NEWS = "news"
REGULATORY_FILINGS = "filings"
[docs]
class ChatSource(BaseModel):
    """A source document attached to a chat message.

    ``url``, ``document_scope`` and ``rp_provider_id`` accept ``None`` but
    declare no default, so they must still be supplied when constructing.
    """

    id: str
    headline: str
    url: Union[str, None]
    document_scope: Union[str, None]
    rp_provider_id: Union[str, None]
[docs]
class ChatInteraction(BaseModel):
    """One question/answer exchange within a chat."""

    question: str
    answer: str
    interaction_timestamp: str
    date_created: datetime
    last_updated: datetime
    scope: Optional[ChatScope] = None
    sources: list[ChatSource] = Field(default_factory=list)

    @field_validator("scope", mode="before")
    @classmethod
    def validate_scope(cls, value):
        """Coerce a raw scope string into a ``ChatScope`` before validation.

        A string matching no ``ChatScope`` value becomes ``None`` instead of
        failing validation; any non-string value is returned unchanged.
        """
        if not isinstance(value, str):
            return value
        try:
            scope = ChatScope(value)
        except ValueError:
            # Unknown scope strings are tolerated and simply dropped.
            scope = None
        return scope
def handle_ws_chat_response(
    chat, question, scope, complete_message, sources, formatter
):
    """Build a ``ChatInteraction`` from a completed answer and record it on ``chat``.

    Args:
        chat: The ``Chat`` the interaction belongs to; the new interaction is
            appended to its ``_interactions`` list.
        question: The question text that was asked.
        scope: The ``ChatScope`` the question was asked with, or ``None``.
        complete_message: The final ``ChatWSCompleteResponse``; the answer text
            is read from ``content_block["value"]`` (empty string if absent).
        sources: Sources collected from audit-trace messages. They are used both
            to resolve inline references in the answer and as the interaction's
            ``sources``.
        formatter: Formatter used to render inline references.

    Returns:
        The newly created ``ChatInteraction``.

    Raises:
        RuntimeError: If ``complete_message`` is ``None``, i.e. the response
            stream ended without a completion message.
    """
    # Local import, presumably to avoid a circular import with the API module.
    from bigdata_client.api.chat import ChatInteraction as ApiChatInteraction

    if complete_message is None:
        # Without this guard the failure surfaces as an opaque AttributeError
        # on ``None.content_block``.
        raise RuntimeError(
            "Chat response ended without a completion message; no answer was received."
        )

    answer = complete_message.content_block.get("value", "")
    parsed_answer = ApiChatInteraction._parse_references(answer, sources, formatter)
    # Naive UTC timestamp, identical to the deprecated ``datetime.utcnow()``.
    now = datetime.now(timezone.utc).replace(tzinfo=None)
    interaction = ChatInteraction(
        question=question,
        answer=parsed_answer,
        interaction_timestamp=complete_message.interaction_timestamp,
        date_created=now,
        last_updated=now,
        scope=scope,
        sources=sources,
    )
    chat._interactions.append(interaction)
    return interaction
[docs]
class StreamingChatInteraction(ChatInteraction):
    """
    Represents a streaming interaction in a chat session.

    This class handles live interactions where the response is obtained in chunks,
    allowing for real-time processing while the interaction is ongoing.

    Iterating the object sends the question and yields answer text chunks with
    inline references rendered. Non-text messages yield ``""``. When the
    completion message arrives, the final interaction is recorded on the chat
    and its field values are copied onto this object.
    """

    _chat: "Chat"
    _formatter: InlineAttributionFormatter
    # Live response stream returned by ``ask_chat``; ``None`` until iteration starts.
    _response: Optional[Iterator] = None

    def __init__(self, _chat, _formatter, **values):
        """Create the interaction without contacting the server.

        Args:
            _chat: The ``Chat`` the question is asked in.
            _formatter: Formatter used to render inline references.
            **values: ``ChatInteraction`` field values.
        """
        super().__init__(**values)
        self._chat = _chat
        self._formatter = _formatter
        self._response = None

    def __iter__(self):
        """Send the question to the server and return ``self`` as the iterator.

        Note: every call sends the question again.
        """
        self._response = self._chat._api_connection.ask_chat(
            self._chat.id, self.question, scope=self.scope
        )
        return self

    def __next__(self):
        """Return the next piece of the streamed answer.

        Returns:
            The reference-formatted text for a text chunk, otherwise ``""``.

        Raises:
            StopIteration: When the underlying response stream is exhausted.
        """
        from bigdata_client.api.chat import ChatInteraction as ApiChatInteraction
        from bigdata_client.api.chat import (
            ChatWSAuditTraceResponse,
            ChatWSCompleteResponse,
        )

        if self._response is None:
            # next() was called without iter(); start the request instead of
            # failing with ``TypeError`` on ``next(None)``.
            self.__iter__()

        # StopIteration from the stream propagates to the caller unchanged.
        chunk = next(self._response)

        if isinstance(chunk, str):
            return ApiChatInteraction._parse_references(
                chunk, self.sources, self._formatter
            )

        if isinstance(chunk, ChatWSCompleteResponse):
            interaction = handle_ws_chat_response(
                chat=self._chat,
                question=self.question,
                scope=self.scope,
                complete_message=chunk,
                sources=self.sources,
                formatter=self._formatter,
            )
            # Mirror the recorded interaction so this object holds the final state.
            for field_name in ChatInteraction.model_fields:
                setattr(self, field_name, getattr(interaction, field_name))
        elif isinstance(chunk, ChatWSAuditTraceResponse):
            # Sources arrive while streaming; keep them for later reference parsing.
            self.sources.extend(chunk.to_chat_source())
        return ""
[docs]
class Chat(BaseModel):
    """A chat session accessed through an API connection.

    Interactions are loaded lazily: the ``interactions`` property reloads the
    chat from the server on first access unless it was created as loaded.
    """

    id: str
    name: str
    user_id: str
    date_created: datetime
    last_updated: datetime

    @computed_field
    @property
    def interactions(self) -> list[ChatInteraction]:
        """All interactions in this chat, reloading from the server if not yet loaded."""
        if not self._loaded:
            self.reload_from_server()
        return self._interactions

    _api_connection: BigdataConnectionProtocol
    _interactions: list[ChatInteraction]
    _formatter: InlineAttributionFormatter
    _loaded: bool

    def __init__(
        self,
        _api_connection: BigdataConnectionProtocol,
        _interactions: Optional[list[ChatInteraction]],
        _formatter: Optional[InlineAttributionFormatter],
        _loaded: bool = False,
        **values,
    ):
        """Create a chat bound to an API connection.

        Args:
            _api_connection: Connection used for all server calls.
            _interactions: Already-known interactions, or ``None`` for none yet.
            _formatter: Default reference formatter; ``DefaultFormatter()`` if falsy.
            _loaded: Whether ``_interactions`` already reflects the server state.
            **values: Model field values (``id``, ``name``, ...).
        """
        super().__init__(**values)
        self._api_connection = _api_connection
        self._loaded = _loaded
        # Always initialise the list: ``ask()`` appends to it even before the
        # chat has been loaded, which previously raised AttributeError.
        self._interactions = _interactions if _interactions is not None else []
        self._formatter = _formatter or DefaultFormatter()

    def ask(
        self,
        question: str,
        *,
        scope: Optional[ChatScope] = None,
        formatter: Optional[InlineAttributionFormatter] = None,
        streaming: bool = False,
    ) -> Union[ChatInteraction, StreamingChatInteraction]:
        """Ask a question in the chat.

        Args:
            question: The question text; its length is validated first.
            scope: Optional content scope to restrict the answer to.
            formatter: Reference formatter; defaults to the chat's formatter.
            streaming: If true, return a ``StreamingChatInteraction``. Nothing is
                sent until it is iterated.

        Returns:
            With ``streaming`` true, an unsent ``StreamingChatInteraction``.
            Otherwise the completed ``ChatInteraction``, which is also appended
            to this chat's interactions.

        Raises:
            BigdataClientChatInvalidQuestion: If the question length is out of bounds.
        """
        self._validate_question(question)
        formatter = formatter or self._formatter
        chat_scope = scope.value if scope else None

        if streaming:
            # Naive UTC timestamp, identical to the deprecated ``datetime.utcnow()``.
            now = datetime.now(timezone.utc).replace(tzinfo=None)
            return StreamingChatInteraction(
                question=question,
                answer="",
                interaction_timestamp=now.isoformat() + "Z",
                date_created=now,
                last_updated=now,
                scope=chat_scope,
                _chat=self,
                _formatter=formatter,
            )

        response = self._api_connection.ask_chat(self.id, question, scope=chat_scope)
        # Local import, presumably to avoid a circular import with the API module.
        from bigdata_client.api.chat import (
            ChatWSAuditTraceResponse,
            ChatWSCompleteResponse,
        )

        complete_message = None
        sources = []
        for chunk in response:
            # Text chunks are skipped: the answer is taken from the completion message.
            if isinstance(chunk, ChatWSCompleteResponse):
                complete_message = chunk
            elif isinstance(chunk, ChatWSAuditTraceResponse):
                sources.extend(chunk.to_chat_source())

        return handle_ws_chat_response(
            chat=self,
            question=question,
            scope=scope,
            complete_message=complete_message,
            sources=sources,
            formatter=formatter,
        )

    def reload_from_server(self):
        """Refresh this chat's metadata and interactions from the server and mark it loaded."""
        chat = self._api_connection.get_chat(self.id).to_chat_model(
            self._api_connection, self._formatter
        )
        self.name = chat.name
        self.user_id = chat.user_id
        self.date_created = chat.date_created
        self.last_updated = chat.last_updated
        self._interactions = chat._interactions
        self._loaded = True

    def delete(self):
        """Delete the chat"""
        self._api_connection.delete_chat(self.id)

    @staticmethod
    def _validate_question(question: str):
        """Raise ``BigdataClientChatInvalidQuestion`` if the question length is out of range.

        Both bounds are exclusive, and ``None`` counts as length 0.
        NOTE(review): confirm the exclusive bounds match the constants' intent.
        """
        message_length = len(question or "")
        if not (MIN_CHAT_QUESTION_LENGTH < message_length < MAX_CHAT_QUESTION_LENGTH):
            raise BigdataClientChatInvalidQuestion(message_length)