cheesexuebao commited on
Commit
74b913c
1 Parent(s): 8185892

Modify tables

Browse files
Files changed (3) hide show
  1. Prediction.py +2 -1
  2. app.py +5 -6
  3. assets/Prediction.py.bak +4 -1
Prediction.py CHANGED
@@ -41,7 +41,7 @@ def predict_csv(data, text_col, tokenizer, model, device, text_bs=16, max_token_
41
  y_inten = final_pred.numpy().T
42
 
43
  for i in range(len(LABEL_COLUMNS)):
44
- data[LABEL_COLUMNS[i]] = y_inten[i].tolist()
45
  return data
46
 
47
  @torch.no_grad()
@@ -63,6 +63,7 @@ def predict_single(sentence, tokenizer, model, device, max_token_len=128):
63
  ).logits
64
  prediction = torch.softmax(logits, dim=1)
65
  y_inten = prediction.flatten().cpu().numpy().T.tolist()
 
66
  return y_inten
67
 
68
  def model_factory(local_path, device):
 
41
  y_inten = final_pred.numpy().T
42
 
43
  for i in range(len(LABEL_COLUMNS)):
44
+ data[LABEL_COLUMNS[i]] = [round(i, 8) for i in y_inten[i].tolist()]
45
  return data
46
 
47
  @torch.no_grad()
 
63
  ).logits
64
  prediction = torch.softmax(logits, dim=1)
65
  y_inten = prediction.flatten().cpu().numpy().T.tolist()
66
+ y_inten = [round(i, 8) for i in y_inten]
67
  return y_inten
68
 
69
  def model_factory(local_path, device):
app.py CHANGED
@@ -24,13 +24,12 @@ manager = model_factory("./models", device)
24
 
25
 
26
  def single_sentence(sentence):
27
- df = []
28
  model_name = 'All_Data'
29
  dct = manager[model_name]
30
  model, tokenizer = dct['model'], dct['tokenizer']
31
  predictions = predict_single(sentence, tokenizer, model, device)
32
- df.append(predictions)
33
- return df
34
 
35
  def csv_process(csv_file, attr="content"):
36
  current_time = datetime.now()
@@ -76,9 +75,9 @@ with gr.Blocks(theme=my_theme, title='Murphy') as demo:
76
  # Detailed information about our model:
77
  ...
78
  """)
79
- tab_output = gr.DataFrame(label='Probability Predictions:',
80
- headers=LABEL_COLUMNS,
81
- datatype=["str"] * (len(LABEL_COLUMNS)),
82
  interactive=False)
83
  with gr.Row():
84
  button_ss = gr.Button("Submit", variant="primary")
 
24
 
25
 
26
  def single_sentence(sentence):
 
27
  model_name = 'All_Data'
28
  dct = manager[model_name]
29
  model, tokenizer = dct['model'], dct['tokenizer']
30
  predictions = predict_single(sentence, tokenizer, model, device)
31
+ predictions.sort(reverse=True)
32
+ return list(zip(LABEL_COLUMNS, predictions))
33
 
34
  def csv_process(csv_file, attr="content"):
35
  current_time = datetime.now()
 
75
  # Detailed information about our model:
76
  ...
77
  """)
78
+ tab_output = gr.DataFrame(label='Predictions:',
79
+ headers=["Label", "Probability"],
80
+ datatype=["str", "number"],
81
  interactive=False)
82
  with gr.Row():
83
  button_ss = gr.Button("Submit", variant="primary")
assets/Prediction.py.bak CHANGED
@@ -126,4 +126,7 @@ kick_fk_doc_result = predict(Data,"content", tokenizer,kick_model, device, LABEL
126
 
127
  fk_result = get_result(Data, kick_fk_doc_result, LABEL_COLUMNS)
128
 
129
- fk_result.to_csv("output/prediction_origin_Kickstarter.csv")
 
 
 
 
126
 
127
  fk_result = get_result(Data, kick_fk_doc_result, LABEL_COLUMNS)
128
 
129
+ fk_result.to_csv("output/prediction_origin_Kickstarter.csv")
130
+
131
+
132
+ # tab_output = gr.Label(label='Probability Predictions:', value=dict(zip(LABEL_COLUMNS, [0]*len(LABEL_COLUMNS))))