# Source code for evennia.contrib.rpg.llm.llm_client

"""
LLM (Large Language Model) client, for communicating with an LLM backend. This can be used
for generating texts for AI npcs, or for fine-tuning the LLM on a given prompt.

Note that running an LLM locally requires a lot of computing power, ideally a powerful GPU. Testing
this in CPU mode on a beefy laptop still takes some 4s, even with a very small model.

The client defaults to settings suitable for a local text-generation-webui server
(https://github.com/oobabooga/text-generation-webui), but it can be configured for other LLM
servers too.

See the LLM instructions on that page for how to set up the server. You'll also need
a model file - there are thousands to try out on https://huggingface.co/models (you want Text
Generation models specifically).

# Optional Evennia settings (shown with their default values, which are used if not given)

LLM_HOST = "http://127.0.0.1:5000"
LLM_PATH = "/api/v1/generate"
LLM_HEADERS = {"Content-Type": ["application/json"]}
LLM_PROMPT_KEYNAME = "prompt"
LLM_REQUEST_BODY = {...}   # see DEFAULT_LLM_REQUEST_BODY below; this controls how to prompt the LLM server.

"""

import json

from django.conf import settings
from twisted.internet import defer, protocol, reactor
from twisted.internet.defer import inlineCallbacks
from twisted.web.client import Agent, HTTPConnectionPool, _HTTP11ClientFactory
from twisted.web.http_headers import Headers
from twisted.web.iweb import IBodyProducer
from zope.interface import implementer

from evennia import logger
from evennia.utils.utils import make_iter

# Fallback values, used by LLMClient when the matching LLM_* setting is not
# defined in the Evennia settings.

# The request URL is formed by concatenating host and path.
DEFAULT_LLM_HOST = "http://127.0.0.1:5000"
DEFAULT_LLM_PATH = "/api/v1/generate"

# Each header name maps to a *list* of values.
DEFAULT_LLM_HEADERS = {"Content-Type": ["application/json"]}

# Key in the request body under which the prompt text is placed.
DEFAULT_LLM_PROMPT_KEYNAME = "prompt"

# "" by default; the original notes "openai" as an alternative value.
DEFAULT_LLM_API_TYPE = ""

# Base request body; the prompt is added to a copy of this for every request.
DEFAULT_LLM_REQUEST_BODY = dict(
    max_new_tokens=250,  # upper bound on the number of tokens generated
    temperature=0.7,  # higher = more random output, lower = more predictable
)


@implementer(IBodyProducer)
class StringProducer:
    """
    Used for feeding a request body to the HTTP client.

    Args:
        body (str): The request body. It is encoded to UTF-8 bytes up front.

    """

    def __init__(self, body):
        self.body = bytes(body, "utf-8")
        # IBodyProducer.length is the number of *bytes* that will be written
        # (it becomes the Content-Length). For non-ASCII text the encoded byte
        # count differs from the character count of `body`, so measure the
        # encoded bytes rather than the original string.
        self.length = len(self.body)

    def startProducing(self, consumer):
        """
        Write the whole body to `consumer` in one call.

        Returns:
            Deferred: Already fired with None, since all data has been written.

        """
        consumer.write(self.body)
        return defer.succeed(None)

    def pauseProducing(self):
        """No-op: the body is written in a single call, so there is nothing to pause."""
        pass

    def stopProducing(self):
        """No-op: the body is written in a single call, so there is nothing to stop."""
        pass
class SimpleResponseReceiver(protocol.Protocol):
    """
    Collects an HTTP response body and hands it over when the connection closes.

    Args:
        status_code (int): Status code of the response; passed through unchanged.
        d (Deferred): Fired with a `(status_code, body_bytes)` tuple once
            `connectionLost` is called.

    """

    def __init__(self, status_code, d):
        self.d = d
        self.status_code = status_code
        self.buf = b""

    def dataReceived(self, data):
        """Append one received chunk of body bytes to the buffer."""
        self.buf = self.buf + data

    def connectionLost(self, reason=protocol.connectionDone):
        """
        Fire the Deferred with the status code and the accumulated body.

        `reason` is not inspected: the Deferred is fired with whatever has been
        buffered, whatever the reason the connection ended.

        """
        result = (self.status_code, self.buf)
        self.d.callback(result)
class QuietHTTP11ClientFactory(_HTTP11ClientFactory):
    """
    HTTP/1.1 client factory with `noisy` switched off, which silences the
    factory start/stop log messages emitted by the default client factory.
    """

    noisy: bool = False
class LLMClient:
    """
    A client for communicating with an LLM server.

    Connection details are read from Evennia settings when the client is
    created, falling back to this module's DEFAULT_LLM_* values:

    - `LLM_HOST` + `LLM_PATH`: concatenated to form the URL requests are POSTed to.
    - `LLM_HEADERS`: request headers. Each value may be a single string or a
      list of strings.
    - `LLM_PROMPT_KEYNAME`: key under which the prompt is placed in the request body.
    - `LLM_REQUEST_BODY`: base request body that the prompt is added to.
    - `LLM_API_TYPE`: stored as `self.api_type`.

    Args:
        on_bad_request (callable, optional): Accepted for backwards
            compatibility; it is not used by this class.

    """

    def __init__(self, on_bad_request=None):
        self._conn_pool = HTTPConnectionPool(reactor)
        # Swap in the quiet factory so new pooled connections don't log
        # factory start/stop messages. `_factory` is a private attribute of
        # HTTPConnectionPool.
        self._conn_pool._factory = QuietHTTP11ClientFactory

        self.prompt_keyname = getattr(settings, "LLM_PROMPT_KEYNAME", DEFAULT_LLM_PROMPT_KEYNAME)
        self.hostname = getattr(settings, "LLM_HOST", DEFAULT_LLM_HOST)
        self.pathname = getattr(settings, "LLM_PATH", DEFAULT_LLM_PATH)
        self.headers = self._normalize_headers(
            getattr(settings, "LLM_HEADERS", DEFAULT_LLM_HEADERS)
        )
        self.request_body = getattr(settings, "LLM_REQUEST_BODY", DEFAULT_LLM_REQUEST_BODY)
        # NOTE(review): api_type is stored but not consulted anywhere in this class.
        self.api_type = getattr(settings, "LLM_API_TYPE", DEFAULT_LLM_API_TYPE)

        self.agent = Agent(reactor, pool=self._conn_pool)

    @staticmethod
    def _normalize_headers(headers):
        """
        Return a new header dict in which every value is a list.

        Twisted's `Headers` expects a list of values per header name, but a
        plain string (e.g. `{"Content-Type": "application/json"}`) is the
        natural way to write a single-valued header in settings.

        Args:
            headers (dict or None): Header names mapped either to a single
                str/bytes value or to an iterable of values.

        Returns:
            dict: A fresh dict, so the settings/default dict is never shared
            with (or mutated through) the client.

        """
        return {
            name: [value] if isinstance(value, (str, bytes)) else list(value)
            for name, value in (headers or {}).items()
        }

    def _format_request_body(self, prompt):
        """
        Structure the request body for the LLM server.

        Args:
            prompt (str or list): The prompt. List entries are joined with
                newlines into a single string.

        Returns:
            dict: A shallow copy of `self.request_body` with the prompt stored
            under `self.prompt_keyname`.

        """
        # Copy so the configured base body is never modified between requests.
        request_body = self.request_body.copy()
        request_body[self.prompt_keyname] = "\n".join(make_iter(prompt))
        return request_body

    def _handle_llm_response_body(self, response):
        """
        Start collecting the body of an HTTP response.

        Returns:
            Deferred: Fires with `(status_code, body_bytes)` once the body has
            been received.

        """
        d = defer.Deferred()
        response.deliverBody(SimpleResponseReceiver(response.code, d))
        return d

    def _handle_llm_error(self, failure):
        """
        Turn a failed request (e.g. server unreachable) into an error result.

        Returns:
            tuple: A synthetic `(500, error_message)` pair, so callers get the
            same shape as for a real server response.

        """
        failure.trap(Exception)
        return (500, failure.getErrorMessage())

    def _get_response_from_llm_server(self, prompt):
        """
        POST the prompt to the LLM server.

        Returns:
            Deferred: Fires with a `(status_code, body)` tuple; `body` is bytes
            for a server response or an error message string if the request
            itself failed.

        """
        request_body = self._format_request_body(prompt)
        if settings.DEBUG:
            logger.log_info(f"LLM request body: {request_body}")
        d = self.agent.request(
            b"POST",
            bytes(self.hostname + self.pathname, "utf-8"),
            headers=Headers(self.headers),
            bodyProducer=StringProducer(json.dumps(request_body)),
        )
        d.addCallbacks(self._handle_llm_response_body, self._handle_llm_error)
        return d

    @staticmethod
    def _parse_llm_response(response):
        """
        Extract the generated text from a successful response body.

        Args:
            response (bytes or str): The raw JSON response body.

        Returns:
            str: The value at `["results"][0]["text"]` of the decoded JSON.

        Raises:
            ValueError: If the body is not valid JSON.
            KeyError, IndexError, TypeError: If the JSON lacks the expected structure.

        """
        return json.loads(response)["results"][0]["text"]

    @inlineCallbacks
    def get_response(self, prompt):
        """
        Get a response from the LLM server for the given prompt.

        Args:
            prompt (str or list): The prompt to send to the LLM server. If a
                list, its entries are joined with newlines into one prompt.

        Returns:
            Deferred: Fires with the generated text (str). The text is an empty
            string if there is an issue with the server (error status or an
            unexpected response format). In that case the caller is expected
            to handle this gracefully.

        """
        status_code, response = yield self._get_response_from_llm_server(prompt)
        if status_code != 200:
            logger.log_err(f"LLM API error (status {status_code}): {response}")
            return ""

        if settings.DEBUG:
            logger.log_info(f"LLM response: {response}")
        try:
            return self._parse_llm_response(response)
        except (ValueError, KeyError, IndexError, TypeError) as err:
            # A 200 reply that is not the expected JSON shape is still a server
            # issue: keep the documented contract and return "" rather than raise.
            logger.log_err(f"LLM API error (unexpected response format: {err!r}): {response}")
            return ""