#!/usr/bin/env python
import json
import sys

# pip install websocket-client
import websocket


class ModelClient:
    """Minimal client for the Petals WebSocket inference API."""

    def __init__(self, endpoint_url):
        self.endpoint_url = endpoint_url
        self.ws = None
        self.model = None

    def open_session(self, model, max_length):
        # Open the websocket and start an inference session for the given model.
        self.ws = websocket.create_connection(self.endpoint_url, enable_multithread=True)
        self.model = model
        payload = {
            "type": "open_inference_session",
            "model": self.model,
            "max_length": max_length,
        }
        self.ws.send(json.dumps(payload))
        assert json.loads(self.ws.recv())["ok"]

    def is_session(self):
        return self.ws is not None

    def close_session(self):
        if self.ws:
            self.ws.close()
            self.ws = None

    def generate(self, prompt, **kwargs):
        # Wrap the low-level generator so a mid-stream failure closes the
        # websocket before the exception propagates to the caller.
        try:
            yield from self._generate(prompt, **kwargs)
        except Exception:
            self.close_session()
            raise

    def _generate(self, prompt, **kwargs):
        # Defaults request one token at a time so output can be streamed;
        # callers may override any of these via **kwargs.
        payload = {
            "type": "generate",
            "inputs": prompt,
            "max_new_tokens": 1,
            "do_sample": 0,
            "temperature": 1,
            "stop_sequence": "" if "bloomz" in self.model else "\n\n",
        }
        payload = {**payload, **kwargs}
        self.ws.send(json.dumps(payload))

        # Yield partial outputs until the server signals the stop condition.
        while True:
            data = json.loads(self.ws.recv())
            if not data["ok"]:
                raise Exception(data["traceback"])
            yield data["outputs"]
            if data["stop"]:
                break


def main():
    # client = ModelClient("ws://localhost:8000/api/v2/generate")
    client = ModelClient("wss://chat.petals.dev/api/v2/generate")
    client.open_session("stabilityai/StableBeluga2", 128)

    if len(sys.argv) > 1:
        prompt = sys.argv[1]
        # Non-Bloomz models use "\n\n" as the stop sequence, so make sure the
        # prompt ends with it; the Bloomz variant uses its own EOS token instead.
        if not prompt.endswith("\n\n"):
            prompt += "\n\n"
    else:
        prompt = "The SQL command to extract all the users whose name starts with A is: \n\n"

    print(f"Prompt: {prompt}")

    # May raise petals.client.routing.sequence_manager.MissingBlocksError
    # if no servers currently host the required model blocks.
    for out in client.generate(prompt, do_sample=True, temperature=0.75, top_p=0.9):
        print(out, end="", flush=True)

    client.close_session()


if __name__ == '__main__':
    main()