Spaces:
zzzzred
/
Runtime error

Reiner4 commited on
Commit
40faca4
Β·
1 Parent(s): 3aa87fd

Upload 4 files

Browse files
Files changed (3) hide show
  1. constants.py +7 -7
  2. requirements-complete.txt +19 -0
  3. server.py +158 -36
constants.py CHANGED
@@ -1,18 +1,18 @@
1
  # Constants
2
- # Also try: 'slauw87/bart-large-cnn-samsum'
3
- DEFAULT_SUMMARIZATION_MODEL = "Qiliang/bart-large-cnn-samsum-ElectrifAi_v14"
4
- # Also try: 'nateraw/bert-base-uncased-emotion'
5
- DEFAULT_CLASSIFICATION_MODEL = "joeddav/distilbert-base-uncased-go-emotions-student"
 
6
  # Also try: 'Salesforce/blip-image-captioning-base'
7
  DEFAULT_CAPTIONING_MODEL = "Salesforce/blip-image-captioning-large"
8
- # Also try: 'ckpt/anything-v4.5-vae-swapped'
9
- DEFAULT_SD_MODEL = "sinkinai/MeinaHentai-v3-baked-vae"
10
  DEFAULT_EMBEDDING_MODEL = "sentence-transformers/all-mpnet-base-v2"
11
  DEFAULT_REMOTE_SD_HOST = "127.0.0.1"
12
  DEFAULT_REMOTE_SD_PORT = 7860
13
  DEFAULT_CHROMA_PORT = 8000
14
  SILERO_SAMPLES_PATH = "tts_samples"
15
- SILERO_SAMPLE_TEXT = "Doctor is your lord and savior"
16
  # ALL_MODULES = ['caption', 'summarize', 'classify', 'keywords', 'prompt', 'sd']
17
  DEFAULT_SUMMARIZE_PARAMS = {
18
  "temperature": 1.0,
 
1
  # Constants
2
+ DEFAULT_CUDA_DEVICE = "cuda:0"
3
+ # Also try: 'Qiliang/bart-large-cnn-samsum-ElectrifAi_v10'
4
+ DEFAULT_SUMMARIZATION_MODEL = "Qiliang/bart-large-cnn-samsum-ChatGPT_v3"
5
+ # Also try: 'joeddav/distilbert-base-uncased-go-emotions-student'
6
+ DEFAULT_CLASSIFICATION_MODEL = "nateraw/bert-base-uncased-emotion"
7
  # Also try: 'Salesforce/blip-image-captioning-base'
8
  DEFAULT_CAPTIONING_MODEL = "Salesforce/blip-image-captioning-large"
9
+ DEFAULT_SD_MODEL = "ckpt/anything-v4.5-vae-swapped"
 
10
  DEFAULT_EMBEDDING_MODEL = "sentence-transformers/all-mpnet-base-v2"
11
  DEFAULT_REMOTE_SD_HOST = "127.0.0.1"
12
  DEFAULT_REMOTE_SD_PORT = 7860
13
  DEFAULT_CHROMA_PORT = 8000
14
  SILERO_SAMPLES_PATH = "tts_samples"
15
+ SILERO_SAMPLE_TEXT = "The quick brown fox jumps over the lazy dog"
16
  # ALL_MODULES = ['caption', 'summarize', 'classify', 'keywords', 'prompt', 'sd']
17
  DEFAULT_SUMMARIZE_PARAMS = {
18
  "temperature": 1.0,
requirements-complete.txt ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ flask
2
+ flask-cloudflared
3
+ flask-cors
4
+ flask-compress
5
+ markdown
6
+ Pillow
7
+ colorama
8
+ webuiapi
9
+ --extra-index-url https://download.pytorch.org/whl/cu117
10
+ torch==2.0.0+cu117
11
+ torchvision==0.15.1
12
+ torchaudio==2.0.1+cu117
13
+ accelerate
14
+ transformers==4.28.1
15
+ diffusers==0.16.1
16
+ silero-api-server
17
+ chromadb
18
+ sentence_transformers
19
+ edge-tts
server.py CHANGED
@@ -21,6 +21,7 @@ import torch
21
  import time
22
  import os
23
  import gc
 
24
  import secrets
25
  from PIL import Image
26
  import base64
@@ -33,6 +34,9 @@ from colorama import Fore, Style, init as colorama_init
33
 
34
  colorama_init()
35
 
 
 
 
36
 
37
  class SplitArgs(argparse.Action):
38
  def __call__(self, parser, namespace, values, option_string=None):
@@ -40,6 +44,16 @@ class SplitArgs(argparse.Action):
40
  namespace, self.dest, values.replace('"', "").replace("'", "").split(",")
41
  )
42
 
 
 
 
 
 
 
 
 
 
 
43
 
44
  # Script arguments
45
  parser = argparse.ArgumentParser(
@@ -56,6 +70,8 @@ parser.add_argument(
56
  )
57
  parser.add_argument("--cpu", action="store_true", help="Run the models on the CPU")
58
  parser.add_argument("--cuda", action="store_false", dest="cpu", help="Run the models on the GPU")
 
 
59
  parser.set_defaults(cpu=True)
60
  parser.add_argument("--summarization-model", help="Load a custom summarization model")
61
  parser.add_argument(
@@ -66,11 +82,10 @@ parser.add_argument("--embedding-model", help="Load a custom text embedding mode
66
  parser.add_argument("--chroma-host", help="Host IP for a remote ChromaDB instance")
67
  parser.add_argument("--chroma-port", help="HTTP port for a remote ChromaDB instance (defaults to 8000)")
68
  parser.add_argument("--chroma-folder", help="Path for chromadb persistence folder", default='.chroma_db')
69
- parser.add_argument('--chroma-persist', help="Chromadb persistence", default=True, action=argparse.BooleanOptionalAction)
70
  parser.add_argument(
71
  "--secure", action="store_true", help="Enforces the use of an API key"
72
  )
73
-
74
  sd_group = parser.add_mutually_exclusive_group()
75
 
76
  local_sd = sd_group.add_argument_group("sd-local")
@@ -105,8 +120,8 @@ parser.add_argument(
105
 
106
  args = parser.parse_args()
107
 
108
- port = 7860
109
- host = "0.0.0.0"
110
  summarization_model = (
111
  args.summarization_model
112
  if args.summarization_model
@@ -142,12 +157,16 @@ if len(modules) == 0:
142
  print(f"Example: --enable-modules=caption,summarize{Style.RESET_ALL}")
143
 
144
  # Models init
145
- device_string = "cuda:0" if torch.cuda.is_available() and not args.cpu else "cpu"
 
146
  device = torch.device(device_string)
147
- torch_dtype = torch.float32 if device_string == "cpu" else torch.float16
148
 
149
  if not torch.cuda.is_available() and not args.cpu:
150
- print(f"{Fore.YELLOW}{Style.BRIGHT}torch-cuda is not supported on this device. Defaulting to CPU mode.{Style.RESET_ALL}")
 
 
 
151
 
152
  print(f"{Fore.GREEN}{Style.BRIGHT}Using torch device: {device_string}{Style.RESET_ALL}")
153
 
@@ -184,12 +203,10 @@ if "sd" in modules and not sd_use_remote:
184
  from diffusers import StableDiffusionPipeline
185
  from diffusers import EulerAncestralDiscreteScheduler
186
 
187
- print("Initializing Stable Diffusion pipeline")
188
- sd_device_string = (
189
- "cuda" if torch.cuda.is_available() and not args.sd_cpu else "cpu"
190
- )
191
  sd_device = torch.device(sd_device_string)
192
- sd_torch_dtype = torch.float32 if sd_device_string == "cpu" else torch.float16
193
  sd_pipe = StableDiffusionPipeline.from_pretrained(
194
  sd_model, custom_pipeline="lpw_stable_diffusion", torch_dtype=sd_torch_dtype
195
  ).to(sd_device)
@@ -252,26 +269,19 @@ if "chromadb" in modules:
252
  posthog.capture = lambda *args, **kwargs: None
253
  if args.chroma_host is None:
254
  if args.chroma_persist:
255
- chromadb_client = chromadb.Client(Settings(anonymized_telemetry=False, persist_directory=args.chroma_folder, chroma_db_impl='duckdb+parquet'))
256
  print(f"ChromaDB is running in-memory with persistence. Persistence is stored in {args.chroma_folder}. Can be cleared by deleting the folder or purging db.")
257
  else:
258
- chromadb_client = chromadb.Client(Settings(anonymized_telemetry=False))
259
  print(f"ChromaDB is running in-memory without persistence.")
260
  else:
261
  chroma_port=(
262
  args.chroma_port if args.chroma_port else DEFAULT_CHROMA_PORT
263
  )
264
- chromadb_client = chromadb.Client(
265
- Settings(
266
- anonymized_telemetry=False,
267
- chroma_api_impl="rest",
268
- chroma_server_host=args.chroma_host,
269
- chroma_server_http_port=chroma_port
270
- )
271
- )
272
  print(f"ChromaDB is remotely configured at {args.chroma_host}:{chroma_port}")
273
 
274
- chromadb_embedder = SentenceTransformer(embedding_model)
275
  chromadb_embed_fn = lambda *args, **kwargs: chromadb_embedder.encode(*args, **kwargs).tolist()
276
 
277
  # Check if the db is connected and running, otherwise tell the user
@@ -405,10 +415,24 @@ def image_to_base64(image: Image, quality: int = 75) -> str:
405
  image.save(buffer, format="JPEG", quality=quality)
406
  img_str = base64.b64encode(buffer.getvalue()).decode("utf-8")
407
  return img_str
408
-
409
- ignore_auth = []
410
 
411
- api_key = os.environ.get("password")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
412
 
413
  def is_authorize_ignored(request):
414
  view_func = app.view_functions.get(request.endpoint)
@@ -418,6 +442,7 @@ def is_authorize_ignored(request):
418
  return True
419
  return False
420
 
 
421
  @app.before_request
422
  def before_request():
423
  # Request time measuring
@@ -426,14 +451,14 @@ def before_request():
426
  # Checks if an API key is present and valid, otherwise return unauthorized
427
  # The options check is required so CORS doesn't get angry
428
  try:
429
- if request.method != 'OPTIONS' and is_authorize_ignored(request) == False and getattr(request.authorization, 'token', '') != api_key:
430
  print(f"WARNING: Unauthorized API key access from {request.remote_addr}")
431
  response = jsonify({ 'error': '401: Invalid API key' })
432
  response.status_code = 401
433
- return "this space is only for doctord98 but you can duplicate it and enjoy"
434
  except Exception as e:
435
  print(f"API key check error: {e}")
436
- return "this space is only for doctord98 but you can duplicate it and enjoy"
437
 
438
 
439
  @app.after_request
@@ -645,7 +670,7 @@ def tts_speakers():
645
  ]
646
  return jsonify(voices)
647
 
648
-
649
  @app.route("/api/tts/generate", methods=["POST"])
650
  @require_module("silero-tts")
651
  def tts_generate():
@@ -657,8 +682,15 @@ def tts_generate():
657
  # Remove asterisks
658
  voice["text"] = voice["text"].replace("*", "")
659
  try:
 
 
 
 
660
  audio = tts_service.generate(voice["speaker"], voice["text"])
661
- return send_file(audio, mimetype="audio/x-wav")
 
 
 
662
  except Exception as e:
663
  print(e)
664
  abort(500, voice["speaker"])
@@ -743,8 +775,6 @@ def chromadb_purge():
743
 
744
  count = collection.count()
745
  collection.delete()
746
- #Write deletion to persistent folder
747
- chromadb_client.persist()
748
  print("ChromaDB embeddings deleted", count)
749
  return 'Ok', 200
750
 
@@ -768,6 +798,11 @@ def chromadb_query():
768
  name=f"chat-{chat_id_md5}", embedding_function=chromadb_embed_fn
769
  )
770
 
 
 
 
 
 
771
  n_results = min(collection.count(), n_results)
772
  query_result = collection.query(
773
  query_texts=[data["query"]],
@@ -793,6 +828,69 @@ def chromadb_query():
793
 
794
  return jsonify(messages)
795
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
796
 
797
  @app.route("/api/chromadb/export", methods=["POST"])
798
  @require_module("chromadb")
@@ -802,9 +900,14 @@ def chromadb_export():
802
  abort(400, '"chat_id" is required')
803
 
804
  chat_id_md5 = hashlib.md5(data["chat_id"].encode()).hexdigest()
805
- collection = chromadb_client.get_or_create_collection(
806
- name=f"chat-{chat_id_md5}", embedding_function=chromadb_embed_fn
807
- )
 
 
 
 
 
808
  collection_content = collection.get()
809
  documents = collection_content.get('documents', [])
810
  ids = collection_content.get('ids', [])
@@ -847,8 +950,27 @@ def chromadb_import():
847
 
848
 
849
  collection.upsert(documents=documents, metadatas=metadatas, ids=ids)
 
850
 
851
  return jsonify({"count": len(ids)})
852
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
853
  ignore_auth.append(tts_play_sample)
854
  app.run(host=host, port=port)
 
21
  import time
22
  import os
23
  import gc
24
+ import sys
25
  import secrets
26
  from PIL import Image
27
  import base64
 
34
 
35
  colorama_init()
36
 
37
+ if sys.hexversion < 0x030b0000:
38
+ print(f"{Fore.BLUE}{Style.BRIGHT}Python 3.11 or newer is recommended to run this program.{Style.RESET_ALL}")
39
+ time.sleep(2)
40
 
41
  class SplitArgs(argparse.Action):
42
  def __call__(self, parser, namespace, values, option_string=None):
 
44
  namespace, self.dest, values.replace('"', "").replace("'", "").split(",")
45
  )
46
 
47
+ #Setting Root Folders for Silero Generations so it is compatible with STSL, should not effect regular runs. - Rolyat
48
+ parent_dir = os.path.dirname(os.path.abspath(__file__))
49
+ SILERO_SAMPLES_PATH = os.path.join(parent_dir, "tts_samples")
50
+ SILERO_SAMPLE_TEXT = os.path.join(parent_dir)
51
+
52
+ # Create directories if they don't exist
53
+ if not os.path.exists(SILERO_SAMPLES_PATH):
54
+ os.makedirs(SILERO_SAMPLES_PATH)
55
+ if not os.path.exists(SILERO_SAMPLE_TEXT):
56
+ os.makedirs(SILERO_SAMPLE_TEXT)
57
 
58
  # Script arguments
59
  parser = argparse.ArgumentParser(
 
70
  )
71
  parser.add_argument("--cpu", action="store_true", help="Run the models on the CPU")
72
  parser.add_argument("--cuda", action="store_false", dest="cpu", help="Run the models on the GPU")
73
+ parser.add_argument("--cuda-device", help="Specify the CUDA device to use")
74
+ parser.add_argument("--mps", "--apple", "--m1", "--m2", action="store_false", dest="cpu", help="Run the models on Apple Silicon")
75
  parser.set_defaults(cpu=True)
76
  parser.add_argument("--summarization-model", help="Load a custom summarization model")
77
  parser.add_argument(
 
82
  parser.add_argument("--chroma-host", help="Host IP for a remote ChromaDB instance")
83
  parser.add_argument("--chroma-port", help="HTTP port for a remote ChromaDB instance (defaults to 8000)")
84
  parser.add_argument("--chroma-folder", help="Path for chromadb persistence folder", default='.chroma_db')
85
+ parser.add_argument('--chroma-persist', help="ChromaDB persistence", default=True, action=argparse.BooleanOptionalAction)
86
  parser.add_argument(
87
  "--secure", action="store_true", help="Enforces the use of an API key"
88
  )
 
89
  sd_group = parser.add_mutually_exclusive_group()
90
 
91
  local_sd = sd_group.add_argument_group("sd-local")
 
120
 
121
  args = parser.parse_args()
122
 
123
+ port = args.port if args.port else 5100
124
+ host = "0.0.0.0" if args.listen else "localhost"
125
  summarization_model = (
126
  args.summarization_model
127
  if args.summarization_model
 
157
  print(f"Example: --enable-modules=caption,summarize{Style.RESET_ALL}")
158
 
159
  # Models init
160
+ cuda_device = DEFAULT_CUDA_DEVICE if not args.cuda_device else args.cuda_device
161
+ device_string = cuda_device if torch.cuda.is_available() and not args.cpu else 'mps' if torch.backends.mps.is_available() and not args.cpu else 'cpu'
162
  device = torch.device(device_string)
163
+ torch_dtype = torch.float32 if device_string != cuda_device else torch.float16
164
 
165
  if not torch.cuda.is_available() and not args.cpu:
166
+ print(f"{Fore.YELLOW}{Style.BRIGHT}torch-cuda is not supported on this device.{Style.RESET_ALL}")
167
+ if not torch.backends.mps.is_available() and not args.cpu:
168
+ print(f"{Fore.YELLOW}{Style.BRIGHT}torch-mps is not supported on this device.{Style.RESET_ALL}")
169
+
170
 
171
  print(f"{Fore.GREEN}{Style.BRIGHT}Using torch device: {device_string}{Style.RESET_ALL}")
172
 
 
203
  from diffusers import StableDiffusionPipeline
204
  from diffusers import EulerAncestralDiscreteScheduler
205
 
206
+ print("Initializing Stable Diffusion pipeline...")
207
+ sd_device_string = cuda_device if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
 
 
208
  sd_device = torch.device(sd_device_string)
209
+ sd_torch_dtype = torch.float32 if sd_device_string != cuda_device else torch.float16
210
  sd_pipe = StableDiffusionPipeline.from_pretrained(
211
  sd_model, custom_pipeline="lpw_stable_diffusion", torch_dtype=sd_torch_dtype
212
  ).to(sd_device)
 
269
  posthog.capture = lambda *args, **kwargs: None
270
  if args.chroma_host is None:
271
  if args.chroma_persist:
272
+ chromadb_client = chromadb.PersistentClient(path=args.chroma_folder, settings=Settings(anonymized_telemetry=False))
273
  print(f"ChromaDB is running in-memory with persistence. Persistence is stored in {args.chroma_folder}. Can be cleared by deleting the folder or purging db.")
274
  else:
275
+ chromadb_client = chromadb.EphemeralClient(Settings(anonymized_telemetry=False))
276
  print(f"ChromaDB is running in-memory without persistence.")
277
  else:
278
  chroma_port=(
279
  args.chroma_port if args.chroma_port else DEFAULT_CHROMA_PORT
280
  )
281
+ chromadb_client = chromadb.HttpClient(host=args.chroma_host, port=chroma_port, settings=Settings(anonymized_telemetry=False))
 
 
 
 
 
 
 
282
  print(f"ChromaDB is remotely configured at {args.chroma_host}:{chroma_port}")
283
 
284
+ chromadb_embedder = SentenceTransformer(embedding_model, device=device_string)
285
  chromadb_embed_fn = lambda *args, **kwargs: chromadb_embedder.encode(*args, **kwargs).tolist()
286
 
287
  # Check if the db is connected and running, otherwise tell the user
 
415
  image.save(buffer, format="JPEG", quality=quality)
416
  img_str = base64.b64encode(buffer.getvalue()).decode("utf-8")
417
  return img_str
 
 
418
 
419
+ ignore_auth = []
420
+ # Reads an API key from an already existing file. If that file doesn't exist, create it.
421
+ if args.secure:
422
+ try:
423
+ with open("api_key.txt", "r") as txt:
424
+ api_key = txt.read().replace('\n', '')
425
+ except:
426
+ api_key = secrets.token_hex(5)
427
+ with open("api_key.txt", "w") as txt:
428
+ txt.write(api_key)
429
+
430
+ print(f"Your API key is {api_key}")
431
+ elif args.share and args.secure != True:
432
+ print("WARNING: This instance is publicly exposed without an API key! It is highly recommended to restart with the \"--secure\" argument!")
433
+ else:
434
+ print("No API key given because you are running locally.")
435
+
436
 
437
  def is_authorize_ignored(request):
438
  view_func = app.view_functions.get(request.endpoint)
 
442
  return True
443
  return False
444
 
445
+
446
  @app.before_request
447
  def before_request():
448
  # Request time measuring
 
451
  # Checks if an API key is present and valid, otherwise return unauthorized
452
  # The options check is required so CORS doesn't get angry
453
  try:
454
+ if request.method != 'OPTIONS' and args.secure and is_authorize_ignored(request) == False and getattr(request.authorization, 'token', '') != api_key:
455
  print(f"WARNING: Unauthorized API key access from {request.remote_addr}")
456
  response = jsonify({ 'error': '401: Invalid API key' })
457
  response.status_code = 401
458
+ return response
459
  except Exception as e:
460
  print(f"API key check error: {e}")
461
+ return "401 Unauthorized\n{}\n\n".format(e), 401
462
 
463
 
464
  @app.after_request
 
670
  ]
671
  return jsonify(voices)
672
 
673
+ # Added fix for Silero not working as new files were unable to be created if one already existed. - Rolyat 7/7/23
674
  @app.route("/api/tts/generate", methods=["POST"])
675
  @require_module("silero-tts")
676
  def tts_generate():
 
682
  # Remove asterisks
683
  voice["text"] = voice["text"].replace("*", "")
684
  try:
685
+ # Remove the destination file if it already exists
686
+ if os.path.exists('test.wav'):
687
+ os.remove('test.wav')
688
+
689
  audio = tts_service.generate(voice["speaker"], voice["text"])
690
+ audio_file_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), os.path.basename(audio))
691
+
692
+ os.rename(audio, audio_file_path)
693
+ return send_file(audio_file_path, mimetype="audio/x-wav")
694
  except Exception as e:
695
  print(e)
696
  abort(500, voice["speaker"])
 
775
 
776
  count = collection.count()
777
  collection.delete()
 
 
778
  print("ChromaDB embeddings deleted", count)
779
  return 'Ok', 200
780
 
 
798
  name=f"chat-{chat_id_md5}", embedding_function=chromadb_embed_fn
799
  )
800
 
801
+ if collection.count() == 0:
802
+ print(f"Queried empty/missing collection for {repr(data['chat_id'])}.")
803
+ return jsonify([])
804
+
805
+
806
  n_results = min(collection.count(), n_results)
807
  query_result = collection.query(
808
  query_texts=[data["query"]],
 
828
 
829
  return jsonify(messages)
830
 
831
+ @app.route("/api/chromadb/multiquery", methods=["POST"])
832
+ @require_module("chromadb")
833
+ def chromadb_multiquery():
834
+ data = request.get_json()
835
+ if "chat_list" not in data or not isinstance(data["chat_list"], list):
836
+ abort(400, '"chat_list" is required and should be a list')
837
+ if "query" not in data or not isinstance(data["query"], str):
838
+ abort(400, '"query" is required')
839
+
840
+ if "n_results" not in data or not isinstance(data["n_results"], int):
841
+ n_results = 1
842
+ else:
843
+ n_results = data["n_results"]
844
+
845
+ messages = []
846
+
847
+ for chat_id in data["chat_list"]:
848
+ if not isinstance(chat_id, str):
849
+ continue
850
+
851
+ try:
852
+ chat_id_md5 = hashlib.md5(chat_id.encode()).hexdigest()
853
+ collection = chromadb_client.get_collection(
854
+ name=f"chat-{chat_id_md5}", embedding_function=chromadb_embed_fn
855
+ )
856
+
857
+ # Skip this chat if the collection is empty
858
+ if collection.count() == 0:
859
+ continue
860
+
861
+ n_results_per_chat = min(collection.count(), n_results)
862
+ query_result = collection.query(
863
+ query_texts=[data["query"]],
864
+ n_results=n_results_per_chat,
865
+ )
866
+ documents = query_result["documents"][0]
867
+ ids = query_result["ids"][0]
868
+ metadatas = query_result["metadatas"][0]
869
+ distances = query_result["distances"][0]
870
+
871
+ chat_messages = [
872
+ {
873
+ "id": ids[i],
874
+ "date": metadatas[i]["date"],
875
+ "role": metadatas[i]["role"],
876
+ "meta": metadatas[i]["meta"],
877
+ "content": documents[i],
878
+ "distance": distances[i],
879
+ }
880
+ for i in range(len(ids))
881
+ ]
882
+
883
+ messages.extend(chat_messages)
884
+ except Exception as e:
885
+ print(e)
886
+
887
+ #remove duplicate msgs, filter down to the right number
888
+ seen = set()
889
+ messages = [d for d in messages if not (d['content'] in seen or seen.add(d['content']))]
890
+ messages = sorted(messages, key=lambda x: x['distance'])[0:n_results]
891
+
892
+ return jsonify(messages)
893
+
894
 
895
  @app.route("/api/chromadb/export", methods=["POST"])
896
  @require_module("chromadb")
 
900
  abort(400, '"chat_id" is required')
901
 
902
  chat_id_md5 = hashlib.md5(data["chat_id"].encode()).hexdigest()
903
+ try:
904
+ collection = chromadb_client.get_collection(
905
+ name=f"chat-{chat_id_md5}", embedding_function=chromadb_embed_fn
906
+ )
907
+ except Exception as e:
908
+ print(e)
909
+ abort(400, "Chat collection not found in chromadb")
910
+
911
  collection_content = collection.get()
912
  documents = collection_content.get('documents', [])
913
  ids = collection_content.get('ids', [])
 
950
 
951
 
952
  collection.upsert(documents=documents, metadatas=metadatas, ids=ids)
953
+ print(f"Imported {len(ids)} (total {collection.count()}) content entries into {repr(data['chat_id'])}")
954
 
955
  return jsonify({"count": len(ids)})
956
 
957
+
958
+ if args.share:
959
+ from flask_cloudflared import _run_cloudflared
960
+ import inspect
961
+
962
+ sig = inspect.signature(_run_cloudflared)
963
+ sum = sum(
964
+ 1
965
+ for param in sig.parameters.values()
966
+ if param.kind == param.POSITIONAL_OR_KEYWORD
967
+ )
968
+ if sum > 1:
969
+ metrics_port = randint(8100, 9000)
970
+ cloudflare = _run_cloudflared(port, metrics_port)
971
+ else:
972
+ cloudflare = _run_cloudflared(port)
973
+ print("Running on", cloudflare)
974
+
975
  ignore_auth.append(tts_play_sample)
976
  app.run(host=host, port=port)