phk0 committed
Commit 611c4ac · verified · 1 Parent(s): b0cf41e

Update main.py

Files changed (1): main.py (+31 -31)
main.py CHANGED
@@ -7,59 +7,59 @@
  # "history": [],
  # "system_prompt": "You are a very powerful AI assistant."
  # }' \
- # https://iiced-mixtral-46-7b-fastapi.hf.space/generate/
-
+ # https://phk0-bai.hf.space/generate/

  from fastapi import FastAPI
  from pydantic import BaseModel
- from huggingface_hub import InferenceClient
+ from transformers import AutoModelForCausalLM, AutoTokenizer
  import uvicorn
+ import torch


  app = FastAPI()

- client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.3")
-
  class Item(BaseModel):
      prompt: str
      history: list
      system_prompt: str
      temperature: float = 0.0
-     max_new_tokens: int = 1048
+     max_new_tokens: int = 900
      top_p: float = 0.15
      repetition_penalty: float = 1.0

- def format_prompt(message, history):
-     prompt = "<s>"
+ def format_prompt(system, message, history):
+     prompt = [{"role": "system", "content": system}]
      for user_prompt, bot_response in history:
-         prompt += f"[INST] {user_prompt} [/INST]"
-         prompt += f" {bot_response}</s> "
-     prompt += f"[INST] {message} [/INST]"
+         prompt += {"role": "user", "content": user_prompt}
+         prompt += {"role": "assistant", "content": bot_response}
+     prompt += {"role": "user", "content": message}
      return prompt

  def generate(item: Item):
-     temperature = float(item.temperature)
-     if temperature < 1e-2:
-         temperature = 1e-2
-     top_p = float(item.top_p)
-
-     generate_kwargs = dict(
-         temperature=temperature,
-         max_new_tokens=item.max_new_tokens,
-         top_p=top_p,
-         repetition_penalty=item.repetition_penalty,
-         do_sample=True,
-         seed=42,
-     )
-
-     formatted_prompt = format_prompt(f"{item.system_prompt}, {item.prompt}", item.history)
-     stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
-     output = ""
-
-     for response in stream:
-         output += response.token.text
-     return output
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     model_path = "ibm-granite/granite-34b-code-instruct-8k"
+     tokenizer = AutoTokenizer.from_pretrained(model_path)
+     # drop device_map if running on CPU
+     model = AutoModelForCausalLM.from_pretrained(model_path, device_map=device)
+     model.eval()
+     # change input text as desired
+     chat = format_prompt(item.system_prompt, item.prompt, item.history)
+     chat = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
+     # tokenize the text
+     input_tokens = tokenizer(chat, return_tensors="pt")
+     # transfer tokenized inputs to the device
+     for i in input_tokens:
+         input_tokens[i] = input_tokens[i].to(device)
+     # generate output tokens
+     output = model.generate(**input_tokens, max_new_tokens=900)
+     output_text = tokenizer.batch_decode(output, skip_special_tokens=True)[0]
+     return output_text

  @app.post("/generate/")
  async def generate_text(item: Item):
      return {"response": generate(item)}
+
+ @app.get("/")
+ async def generate_text_root(item: Item):
+     return {"response": "try entry point: /generate/"}
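For reference, a client-side call against the updated Space could look like the sketch below. The endpoint URL comes from the commented-out curl example at the top of main.py and the field names mirror the Item model; the prompt text and timeout are placeholders.

    import requests

    payload = {
        "prompt": "Write a function that checks whether a number is prime.",
        "history": [],
        "system_prompt": "You are a very powerful AI assistant.",
    }
    # POST the payload to the /generate/ endpoint and print the model's reply
    resp = requests.post("https://phk0-bai.hf.space/generate/", json=payload, timeout=600)
    print(resp.json()["response"])

Note that tokenizer.apply_chat_template consumes a list of {"role": ..., "content": ...} dicts. Appending each message with list.append keeps that structure, whereas += with a bare dict extends the list with the dict's keys. A minimal sketch of that assembly, using the hypothetical helper name build_chat:

    def build_chat(system, message, history):
        # assemble the role/content messages that apply_chat_template consumes
        chat = [{"role": "system", "content": system}]
        for user_prompt, bot_response in history:
            chat.append({"role": "user", "content": user_prompt})
            chat.append({"role": "assistant", "content": bot_response})
        chat.append({"role": "user", "content": message})
        return chat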