Commit
·
b305b72
1
Parent(s):
d6ffbb8
Add guideline to use BiRefNet. Remove codes of model.
Browse files
README.md
CHANGED
|
@@ -31,9 +31,12 @@ import matplotlib.pyplot as plt
|
|
| 31 |
import torch
|
| 32 |
from torchvision import transforms
|
| 33 |
|
|
|
|
|
|
|
|
|
|
| 34 |
# Input Data
|
| 35 |
transform_image = transforms.Compose([
|
| 36 |
-
transforms.Resize((
|
| 37 |
transforms.ToTensor(),
|
| 38 |
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
| 39 |
])
|
|
@@ -42,7 +45,7 @@ image = Image.open(imagepath)
|
|
| 42 |
input_images = transform_image(image).unsqueeze(0).to('cuda')
|
| 43 |
|
| 44 |
# Load Model
|
| 45 |
-
device = '
|
| 46 |
torch.set_float32_matmul_precision(['high', 'highest'][0])
|
| 47 |
model = BiRefNet.from_pretrained('zhengpeng7/birefnet')
|
| 48 |
model.to(device)
|
|
@@ -55,7 +58,7 @@ with torch.no_grad():
|
|
| 55 |
pred = preds[0].squeeze()
|
| 56 |
|
| 57 |
# Show Results
|
| 58 |
-
plt.imshow(pred, cmap='gray')
|
| 59 |
plt.show()
|
| 60 |
|
| 61 |
```
|
|
|
|
| 31 |
import torch
|
| 32 |
from torchvision import transforms
|
| 33 |
|
| 34 |
+
from models.birefnet import BiRefNet
|
| 35 |
+
|
| 36 |
+
|
| 37 |
# Input Data
|
| 38 |
transform_image = transforms.Compose([
|
| 39 |
+
transforms.Resize((1024, 1024)),
|
| 40 |
transforms.ToTensor(),
|
| 41 |
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
| 42 |
])
|
|
|
|
| 45 |
input_images = transform_image(image).unsqueeze(0).to('cuda')
|
| 46 |
|
| 47 |
# Load Model
|
| 48 |
+
device = 'cuda'
|
| 49 |
torch.set_float32_matmul_precision(['high', 'highest'][0])
|
| 50 |
model = BiRefNet.from_pretrained('zhengpeng7/birefnet')
|
| 51 |
model.to(device)
|
|
|
|
| 58 |
pred = preds[0].squeeze()
|
| 59 |
|
| 60 |
# Show Results
|
| 61 |
+
plt.imshow(transforms.ToPILImage()(pred).resize(image.size), cmap='gray')
|
| 62 |
plt.show()
|
| 63 |
|
| 64 |
```
|