StoneSeller commited on
Commit
62fa3ff
·
verified ·
1 Parent(s): dd8683c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -10
app.py CHANGED
@@ -1,28 +1,27 @@
1
  import subprocess
2
  import sys
3
 
4
- # Function to install a package
5
  def install(package):
6
  subprocess.check_call([sys.executable, "-m", "pip", "install", package])
7
 
8
- # Ensure required libraries are installed
9
  try:
10
- import torch
11
  except ImportError:
12
- install("torch==2.0.1")
13
- install("torchvision==0.15.2")
14
 
15
  try:
16
- import numpy as np
17
  except ImportError:
18
- install("numpy<2")
 
19
 
20
  try:
21
  from PIL import Image
22
  except ImportError:
23
  install("Pillow==9.5.0")
24
 
25
- # Imports
26
  import torch
27
  import torch.nn as nn
28
  import torch.nn.functional as F
@@ -30,7 +29,6 @@ import torchvision.transforms as transforms
30
  from PIL import Image
31
  import gradio as gr
32
 
33
-
34
  # Define the model
35
  class ModifiedLargeNet(nn.Module):
36
  def __init__(self):
@@ -50,7 +48,6 @@ class ModifiedLargeNet(nn.Module):
50
  x = self.fc2(x)
51
  return x
52
 
53
-
54
  # Load the trained model
55
  model = ModifiedLargeNet()
56
  model.load_state_dict(torch.load("modified_large_net.pt", map_location=torch.device("cpu")))
 
1
  import subprocess
2
  import sys
3
 
4
+ # Ensure required packages are installed
5
  def install(package):
6
  subprocess.check_call([sys.executable, "-m", "pip", "install", package])
7
 
 
8
  try:
9
+ import numpy as np
10
  except ImportError:
11
+ install("numpy<2")
 
12
 
13
  try:
14
+ import torch
15
  except ImportError:
16
+ install("torch==2.0.1")
17
+ install("torchvision==0.15.2")
18
 
19
  try:
20
  from PIL import Image
21
  except ImportError:
22
  install("Pillow==9.5.0")
23
 
24
+ # Import necessary libraries
25
  import torch
26
  import torch.nn as nn
27
  import torch.nn.functional as F
 
29
  from PIL import Image
30
  import gradio as gr
31
 
 
32
  # Define the model
33
  class ModifiedLargeNet(nn.Module):
34
  def __init__(self):
 
48
  x = self.fc2(x)
49
  return x
50
 
 
51
  # Load the trained model
52
  model = ModifiedLargeNet()
53
  model.load_state_dict(torch.load("modified_large_net.pt", map_location=torch.device("cpu")))