Ajay12345678980 committed
Commit 989a272 (verified) · Parent: 2e257e4

Update app.py

Files changed (1)
1. app.py (+11, -16)
app.py CHANGED
@@ -7,18 +7,16 @@ from transformers import GPT2LMHeadModel, GPT2Tokenizer
 model_repo_id = "Ajay12345678980/QA_bot" # Replace with your model repository ID
 
 # Initialize the model and tokenizer
-model = GPT2LMHeadModel.from_pretrained(model_repo_id)
+device = "cuda" if torch.cuda.is_available() else "cpu"
+model = GPT2LMHeadModel.from_pretrained(model_repo_id).to(device)
 tokenizer = GPT2Tokenizer.from_pretrained(model_repo_id)
 
 # Define the prediction function
 def generate_answer(question):
-    input_ids = tokenizer.encode(question, return_tensors="pt").to("cuda")
-
-    # Create the attention mask and pad token id
-    attention_mask = torch.ones_like(input_ids).to("cuda")
+    input_ids = tokenizer.encode(question, return_tensors="pt").to(device)
+    attention_mask = torch.ones_like(input_ids).to(device)
     pad_token_id = tokenizer.eos_token_id
 
-    #output = model[0].generate(
     output = model.generate(
         input_ids,
         max_new_tokens=100,
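The first hunk replaces the hard-coded "cuda" strings with a runtime device check, so the Space also runs on CPU-only hardware, and drops the leftover commented-out model[0].generate( experiment. A minimal, self-contained sketch of the device pattern, assuming the same model repository; the prompt string is illustrative:

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Use the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

model = GPT2LMHeadModel.from_pretrained("Ajay12345678980/QA_bot").to(device)
tokenizer = GPT2Tokenizer.from_pretrained("Ajay12345678980/QA_bot")

# Inputs must live on the same device as the model weights; the old
# hard-coded .to("cuda") raised a RuntimeError on machines without a GPU.
input_ids = tokenizer.encode("What is GPT-2?", return_tensors="pt").to(device)
attention_mask = torch.ones_like(input_ids)  # ones_like inherits input_ids' device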
@@ -30,17 +28,14 @@ def generate_answer(question):
     start_index = decoded_output.find("Answer")
     end_index = decoded_output.find("<ANSWER_ENDED>")
 
-    if end_index != -1:
-        # Extract the text between "Answer" and "<ANSWER_ENDED>"
-        answer_text = decoded_output[start_index + len("Answer"):end_index].strip()
+    if start_index != -1:
+        if end_index != -1:
+            answer_text = decoded_output[start_index + len("Answer"):end_index].strip()
+        else:
+            answer_text = decoded_output[start_index + len("Answer"):].strip()
         return answer_text
     else:
-        # If "<ANSWER_ENDED>" is not found, return the text following "Answer"
-        answer_text = decoded_output[start_index + len("Answer"):].strip()
-        return answer_text
-
-    #return tokenizer.decode(output[0], skip_special_tokens=True)
-    #return tokenizer.decode(output, skip_special_tokens=True)
+        return "Sorry, I couldn't generate an answer."
 
 # Gradio interface setup
 interface = gr.Interface(
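The second hunk fixes a real bug in the answer parsing: when the model never emitted "Answer", find() returned -1, so the old slice started at index -1 + len("Answer") = 5 and returned arbitrary text. The new guard on start_index returns an apology string instead. A self-contained sketch of the committed logic (the helper name extract_answer is mine, not the commit's):

def extract_answer(decoded_output):
    start_index = decoded_output.find("Answer")
    end_index = decoded_output.find("<ANSWER_ENDED>")
    if start_index != -1:
        if end_index != -1:
            # Keep only the text between "Answer" and "<ANSWER_ENDED>".
            answer_text = decoded_output[start_index + len("Answer"):end_index].strip()
        else:
            # End marker missing: take everything after "Answer".
            answer_text = decoded_output[start_index + len("Answer"):].strip()
        return answer_text
    else:
        return "Sorry, I couldn't generate an answer."

assert extract_answer("Answer 42 <ANSWER_ENDED>") == "42"
assert extract_answer("no marker here") == "Sorry, I couldn't generate an answer."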
@@ -48,7 +43,7 @@ interface = gr.Interface(
     inputs="text",
     outputs="text",
     title="GPT-2 Text Generation",
-    description="Enter some text and see what the model generates!"
+    description="Enter a question and see what the model generates!"
 )
 
 # Launch the Gradio app
 
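The diff elides the middle of the generate() call (new lines 23 to 27) and everything after "# Launch the Gradio app". A hedged guess at how those elided pieces fit together, based on the variables computed above and on the decode call that appears commented out in the pre-commit code; the actual committed lines are not visible here:

# Assumed shape of the elided region; every kwarg below is an assumption.
output = model.generate(
    input_ids,
    max_new_tokens=100,
    attention_mask=attention_mask,  # computed earlier, presumably passed here
    pad_token_id=pad_token_id,      # computed earlier, presumably passed here
)
decoded_output = tokenizer.decode(output[0], skip_special_tokens=True)

# The file presumably ends by starting the app:
interface.launch()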