AndersonConforto commited on
Commit
39f30cb
·
1 Parent(s): dbb647b

first commit

Browse files
Files changed (3) hide show
  1. app.py +25 -11
  2. requirements.txt +1 -1
  3. surya_model.pt +0 -0
app.py CHANGED
@@ -1,13 +1,31 @@
1
  import gradio as gr
2
- from transformers import AutoModel
3
  import torch
4
  from PIL import Image
5
  import numpy as np
6
  import matplotlib.pyplot as plt
 
 
7
 
8
- # Carregar modelo Surya
9
- model_name = "nasa-ibm-ai4science/Surya-1.0"
10
- model = AutoModel.from_pretrained(model_name)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  model.eval()
12
 
13
  # Função para gerar heatmap
@@ -19,19 +37,15 @@ def infer_solar_image_heatmap(img):
19
  with torch.no_grad():
20
  outputs = model(img_tensor)
21
 
22
- # Pegar os primeiros canais e reshapar para visualização
23
- emb = outputs[0].squeeze().numpy()
24
  heatmap = emb - emb.min()
25
  heatmap /= heatmap.max() + 1e-8 # normalização 0-1
26
 
27
- # Criar figura
28
  plt.imshow(heatmap, cmap='hot')
29
  plt.axis('off')
30
  plt.tight_layout()
31
-
32
- # Salvar figura em buffer
33
- fig = plt.gcf()
34
- return fig
35
 
36
  # Interface Gradio
37
  interface = gr.Interface(
 
1
  import gradio as gr
 
2
  import torch
3
  from PIL import Image
4
  import numpy as np
5
  import matplotlib.pyplot as plt
6
+ import requests
7
+ import os
8
 
9
+ # URLs dos arquivos do modelo
10
+ MODEL_URL = "https://huggingface.co/nasa-ibm-ai4science/Surya-1.0/resolve/main/surya.366m.v1.pt"
11
+
12
+ # Nome local do arquivo
13
+ MODEL_FILE = "surya.366m.v1.pt"
14
+
15
+ # Função para baixar o modelo se não existir
16
+ def download_model():
17
+ if not os.path.exists(MODEL_FILE):
18
+ print("Baixando pesos do Surya-1.0...")
19
+ r = requests.get(MODEL_URL)
20
+ with open(MODEL_FILE, "wb") as f:
21
+ f.write(r.content)
22
+ print("Download concluído!")
23
+
24
+ # Baixar modelo
25
+ download_model()
26
+
27
+ # Carregar modelo PyTorch
28
+ model = torch.load(MODEL_FILE)
29
  model.eval()
30
 
31
  # Função para gerar heatmap
 
37
  with torch.no_grad():
38
  outputs = model(img_tensor)
39
 
40
+ # Criar heatmap
41
+ emb = outputs.squeeze().numpy()
42
  heatmap = emb - emb.min()
43
  heatmap /= heatmap.max() + 1e-8 # normalização 0-1
44
 
 
45
  plt.imshow(heatmap, cmap='hot')
46
  plt.axis('off')
47
  plt.tight_layout()
48
+ return plt.gcf()
 
 
 
49
 
50
  # Interface Gradio
51
  interface = gr.Interface(
requirements.txt CHANGED
@@ -1,6 +1,6 @@
1
  torch
2
- transformers
3
  pillow
4
  numpy
5
  matplotlib
6
  gradio
 
 
1
  torch
 
2
  pillow
3
  numpy
4
  matplotlib
5
  gradio
6
+ requests
surya_model.pt ADDED
File without changes