Hjgugugjhuhjggg committed on
Commit f5bef42 · verified · 1 Parent(s): 2784732

Update app.py

Files changed (1)
  1. app.py +108 -54
app.py CHANGED
@@ -1,78 +1,132 @@
  from llama_cpp import Llama
  from concurrent.futures import ThreadPoolExecutor, as_completed
- import re
- import uvicorn
- from fastapi import FastAPI
- from fastapi.middleware.cors import CORSMiddleware
- import os
  from dotenv import load_dotenv
  from pydantic import BaseModel

  load_dotenv()
  HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")

- global_data = {'models': {}, 'tokens': {k: k + '_token' for k in ['eos', 'pad', 'padding', 'unk', 'bos', 'sep', 'cls', 'mask']}}

- model_configs = [{"repo_id": "Hjgugugjhuhjggg/mergekit-ties-tzamfyy-Q2_K-GGUF", "filename": "mergekit-ties-tzamfyy-q2_k.gguf", "name": "my_model"}]

- models = {}

- def load_model(model_config):
-     model_name = model_config['name']
-     try:
-         model = Llama.from_pretrained(repo_id=model_config['repo_id'], filename=model_config['filename'], use_auth_token=HUGGINGFACE_TOKEN)
-         models[model_name] = model
-         global_data['models'] = models
-         return model
-     except Exception as e:
-         print(f"Error loading model {model_name}: {e}")
-         return None

- for config in model_configs:
-     model = load_model(config)
-     if model is None:
-         exit(1)

  class ChatRequest(BaseModel):
      message: str

- def normalize_input(input_text):
-     return input_text.strip()

- def remove_duplicates(text):
-     lines = [line.strip() for line in text.split('\n') if line.strip()]
-     return '\n'.join(dict.fromkeys(lines))

- def generate_model_response(model, inputs):
      try:
-         if model is None:
-             return "Model loading failed."
-         response = model(inputs, max_tokens=512)
-         return remove_duplicates(response['choices'][0]['text'])
      except Exception as e:
-         print(f"Error generating response: {e}")
-         return f"Error: {e}"

- app = FastAPI()
- origins = ["*"]
- app.add_middleware(
-     CORSMiddleware, allow_origins=origins, allow_credentials=True, allow_methods=["*"], allow_headers=["*"]
- )
-
- @app.post("/generate")
- async def generate(request: ChatRequest):
-     inputs = normalize_input(request.message)
-     chunk_size = 400
-     chunks = [inputs[i:i + chunk_size] for i in range(0, len(inputs), chunk_size)]
-     overall_response = ""
-     for chunk in chunks:
-         with ThreadPoolExecutor() as executor:
-             futures = [executor.submit(generate_model_response, model, chunk) for model in models.values()]
-             responses = [{'model': name, 'response': future.result()} for name, future in zip(models, as_completed(futures))]
-             for response in responses:
-                 overall_response += f"**{response['model']}:**\n{response['response']}\n\n"
-     return {"response": overall_response}

  if __name__ == "__main__":
-     port = int(os.environ.get("PORT", 7860))
      uvicorn.run(app, host="0.0.0.0", port=port)
+ import os
+ import gc
+ import io
  from llama_cpp import Llama
  from concurrent.futures import ThreadPoolExecutor, as_completed
+ from fastapi import FastAPI, Request, HTTPException
+ from fastapi.responses import JSONResponse
+ from tqdm import tqdm
  from dotenv import load_dotenv
  from pydantic import BaseModel
+ from huggingface_hub import hf_hub_download, login
+ from nltk.tokenize import word_tokenize
+ from nltk.corpus import stopwords
+ from sklearn.feature_extraction.text import TfidfVectorizer
+ from sklearn.metrics.pairwise import cosine_similarity
+ import nltk
+
+ nltk.download('punkt')
+ nltk.download('stopwords')

  load_dotenv()
+
+ app = FastAPI()
  HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
+ if HUGGINGFACE_TOKEN:
+     login(token=HUGGINGFACE_TOKEN)

+ global_data = {
+     'model_configs': [
+         {"repo_id": "Ffftdtd5dtft/gpt2-xl-Q2_K-GGUF", "name": "GPT-2 XL"},
+         {"repo_id": "Ffftdtd5dtft/gemma-2-27b-Q2_K-GGUF", "name": "Gemma 2-27B"},
+         {"repo_id": "Ffftdtd5dtft/Phi-3-mini-128k-instruct-Q2_K-GGUF", "name": "Phi-3 Mini 128K Instruct"},
+         {"repo_id": "Ffftdtd5dtft/starcoder2-3b-Q2_K-GGUF", "name": "Starcoder2 3B"},
+         {"repo_id": "Ffftdtd5dtft/Qwen2-1.5B-Instruct-Q2_K-GGUF", "name": "Qwen2 1.5B Instruct"},
+         {"repo_id": "Ffftdtd5dtft/Mistral-Nemo-Instruct-2407-Q2_K-GGUF", "name": "Mistral Nemo Instruct 2407"},
+         {"repo_id": "Ffftdtd5dtft/Phi-3-mini-128k-instruct-IQ2_XXS-GGUF", "name": "Phi 3 Mini 128K Instruct XXS"},
+         {"repo_id": "Ffftdtd5dtft/TinyLlama-1.1B-Chat-v1.0-IQ1_S-GGUF", "name": "TinyLlama 1.1B Chat"},
+         {"repo_id": "Ffftdtd5dtft/Meta-Llama-3.1-8B-Q2_K-GGUF", "name": "Meta Llama 3.1-8B"},
+         {"repo_id": "Ffftdtd5dtft/codegemma-2b-IQ1_S-GGUF", "name": "Codegemma 2B"},
+     ],
+     'training_data': io.StringIO(),
+ }

+ class ModelManager:
+     def __init__(self):
+         self.models = {}
+         self.load_models()

+     def load_models(self):
+         for config in tqdm(global_data['model_configs'], desc="Loading models"):
+             model_name = config['name']
+             if model_name not in self.models:
+                 try:
+                     model_path = hf_hub_download(repo_id=config['repo_id'], use_auth_token=HUGGINGFACE_TOKEN)
+                     model = Llama.from_file(model_path)
+                     self.models[model_name] = model
+                 except Exception as e:
+                     self.models[model_name] = None
+                 finally:
+                     gc.collect()

+     def get_model(self, model_name: str):
+         return self.models.get(model_name)

+
+ model_manager = ModelManager()

  class ChatRequest(BaseModel):
      message: str

+ async def generate_model_response(model, inputs: str) -> str:
+     try:
+         if model:
+             response = model(inputs, max_tokens=150)
+             return response['choices'][0]['text'].strip()
+         else:
+             return "Model not loaded"
+     except Exception as e:
+         return f"Error: Could not generate a response. Details: {e}"
+
+ async def process_message(message: str) -> dict:
+     inputs = message.strip()
+     responses = {}
+
+     with ThreadPoolExecutor(max_workers=len(global_data['model_configs'])) as executor:
+         futures = [executor.submit(generate_model_response, model_manager.get_model(config['name']), inputs) for config in global_data['model_configs'] if model_manager.get_model(config['name'])]
+         for i, future in enumerate(tqdm(as_completed(futures), total=len(futures), desc="Generating responses")):
+             try:
+                 model_name = global_data['model_configs'][i]['name']
+                 responses[model_name] = future.result()
+             except Exception as e:
+                 responses[model_name] = f"Error processing {model_name}: {e}"
+
+     stop_words = set(stopwords.words('english'))
+     vectorizer = TfidfVectorizer(tokenizer=word_tokenize, stop_words=stop_words)
+     reference_text = message
+     response_texts = list(responses.values())
+     tfidf_matrix = vectorizer.fit_transform([reference_text] + response_texts)
+     similarities = cosine_similarity(tfidf_matrix[0:1], tfidf_matrix[1:])
+     best_response_index = similarities.argmax()
+     best_response_model = list(responses.keys())[best_response_index]
+     best_response_text = response_texts[best_response_index]

+     return {"best_response": {"model": best_response_model, "text": best_response_text}, "all_responses": responses}

+
+ @app.post("/generate_multimodel")
+ async def api_generate_multimodel(request: Request):
      try:
+         data = await request.json()
+         message = data.get("message")
+         if not message:
+             raise HTTPException(status_code=400, detail="Missing message")
+         response = await process_message(message)
+         return JSONResponse(response)
+     except HTTPException as e:
+         raise e
      except Exception as e:
+         return JSONResponse({"error": str(e)}, status_code=500)

+
+ @app.on_event("startup")
+ async def startup_event():
+     pass
+
+ @app.on_event("shutdown")
+ async def shutdown_event():
+     gc.collect()

  if __name__ == "__main__":
+     port = int(os.environ.get("PORT", 8000))
      uvicorn.run(app, host="0.0.0.0", port=port)
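
Once the updated app.py is running, the new /generate_multimodel route introduced by this commit accepts a JSON body with a single "message" field and returns every model's output together with the response ranked most similar to the prompt by TF-IDF cosine similarity. Below is a minimal client sketch, not part of the commit: the base URL http://localhost:8000 (the code's default when PORT is unset), the use of the requests package, and the example prompt are all assumptions.

# Sketch: exercise the /generate_multimodel endpoint added in this commit.
# Assumes the updated app.py is running at http://localhost:8000 (default PORT 8000)
# and that the `requests` package is installed; the prompt text is a placeholder.
import requests

payload = {"message": "Explain what Q2_K GGUF quantization is in one paragraph."}
resp = requests.post("http://localhost:8000/generate_multimodel", json=payload, timeout=600)
resp.raise_for_status()

result = resp.json()
# "best_response" holds the output judged closest to the prompt by cosine similarity;
# "all_responses" maps each loaded model name to its raw completion.
print(result["best_response"]["model"], "->", result["best_response"]["text"])
for name, text in result["all_responses"].items():
    print(f"{name}: {text}")

On a hosted deployment (for example a Hugging Face Space) the base URL would be the service's public hostname rather than localhost, and omitting the "message" field yields the 400 error raised by the handler.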