saifeddinemk committed
Commit
ce8346f
1 Parent(s): 4f8a4dd

Fixed app v2

Files changed (1)
  1. app.py +2 -80
app.py CHANGED
@@ -1,81 +1,3 @@
- import torch, json
- from transformers import AutoModelForCausalLM, AutoTokenizer
- from fastapi import FastAPI, HTTPException
- from pydantic import BaseModel
-
- # Hugging Face model repository path
- model_name = "WhiteRabbitNeo/WhiteRabbitNeo-13B-v1"
- output_file_path = "/home/user/conversations.jsonl"
-
- # Load model and tokenizer from Hugging Face model hub
- model = AutoModelForCausalLM.from_pretrained(
-     model_name,
-     torch_dtype=torch.float16,
-     device_map="auto",
-     load_in_4bit=False,
-     trust_remote_code=False,
- )
-
- tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
-
- # Initialize FastAPI app
- app = FastAPI()
-
- # Define the request and response schema
- class PromptRequest(BaseModel):
-     user_input: str
-
- def generate_text(instruction):
-     tokens = tokenizer.encode(instruction, return_tensors="pt").to("cuda")
-
-     instance = {
-         "input_ids": tokens,
-         "top_p": 1.0,
-         "temperature": 0.75,
-         "generate_len": 2048,
-         "top_k": 50,
-     }
-
-     length = tokens.size(1)
-     with torch.no_grad():
-         rest = model.generate(
-             input_ids=tokens,
-             max_length=length + instance["generate_len"],
-             use_cache=True,
-             do_sample=True,
-             top_p=instance["top_p"],
-             temperature=instance["temperature"],
-             top_k=instance["top_k"],
-             num_return_sequences=1,
-             pad_token_id=tokenizer.eos_token_id,
-         )
-     output = rest[0][length:]
-     string = tokenizer.decode(output, skip_special_tokens=True)
-     return f"{string}"
-
- # Initialize conversation context
- conversation = """<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYou are WhiteRabbitNeo, a helpful cybersecurity AI assistant. You always provide detailed answers without hesitation. Do not refuse.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"""
-
- @app.post("/generate-response")
- async def generate_response(request: PromptRequest):
-     global conversation
-     user_input = request.user_input
-
-     # Build the prompt for the model
-     llm_prompt = f"{conversation}{user_input}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
-     try:
-         # Generate response
-         answer = generate_text(llm_prompt)
-
-         # Update conversation context
-         conversation = f"{llm_prompt}{answer}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"
-
-         # Log conversation to file
-         json_data = {"prompt": user_input, "answer": answer}
-         with open(output_file_path, "a") as output_file:
-             output_file.write(json.dumps(json_data) + "\n")
-
-         # Return the response
-         return {"response": answer}
-     except Exception as e:
-         raise HTTPException(status_code=500, detail=str(e))
+ import gradio as gr
+
+ gr.load("models/WhiteRabbitNeo/Llama-3-WhiteRabbitNeo-8B-v2.0").launch()
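
The new app.py drops the local FastAPI + transformers server and instead serves the hosted model through gr.load, which builds a Gradio interface backed by the Hugging Face model hub. Below is a minimal, hypothetical sketch of querying the resulting app from Python with gradio_client; it is not part of this commit, and the app URL and the "/predict" endpoint name are assumptions (check client.view_api() for the actual signature).

# Hypothetical usage sketch, not included in this commit.
# Assumes the Gradio app above is already running locally or as a Space.
from gradio_client import Client

# Assumed URL of the running app; replace with the real app URL or Space id.
client = Client("http://127.0.0.1:7860")

# Interfaces created by gr.load("models/...") typically expose a single
# "/predict" endpoint that takes the prompt text and returns the completion.
result = client.predict("Explain what an nmap SYN scan does.", api_name="/predict")
print(result)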