ShAnSantosh's picture
Update app.py
8a41460
Raw
History Blame Contribute Delete
1.74 kB
import functools
import os
import random

import albumentations
import cv2
import gradio as gr
import numpy as np
import timm
import torch
# All tensors and the model are placed on the CPU.
device = torch.device('cpu')

# Maps a classifier output index to its disease name; predict() reads
# score i of the model output as the score for labels[i].
labels = dict(enumerate([
    'bacterial_leaf_blight',
    'bacterial_leaf_streak',
    'bacterial_panicle_blight',
    'blast',
    'brown_spot',
    'dead_heart',
    'downy_mildew',
    'hispa',
    'normal',
    'tungro',
]))
def inference_fn(model, image=None):
    """Score one image with *model* and return sigmoid-activated outputs.

    Args:
        model: A torch module. It is put into eval mode before the forward pass.
        image: A single image tensor without a batch axis (presumably
            (C, H, W), as built by predict() — TODO confirm for other callers).
            A batch axis of size 1 is prepended before the forward pass.

    Returns:
        A 1-D NumPy array of the model outputs passed through a sigmoid,
        copied to host memory.
    """
    model.eval()
    batch = image.to(device).unsqueeze(0)  # add the batch dimension
    with torch.no_grad():
        logits = model(batch)
        scores = logits.sigmoid()
    return scores.detach().cpu().numpy().flatten()
# ImageNet channel statistics used to normalize inputs for the EfficientNet.
_MEAN = (0.485, 0.456, 0.406)
_STD = (0.229, 0.224, 0.225)
_MODEL_PATH = "paddy_model.pth"

# Deterministic inference preprocessing, built once at import time.
# Random flips are deliberately NOT applied here: they made repeated
# predictions on the same image differ from call to call.
# Normalize already defaults to p=1.0, so the deprecated `always_apply`
# flag is unnecessary.
_TRANSFORM = albumentations.Compose(
    [
        albumentations.Resize(256, 256),
        albumentations.Normalize(mean=_MEAN, std=_STD, max_pixel_value=255.0),
    ]
)


@functools.lru_cache(maxsize=1)
def _load_model():
    """Build EfficientNet-B0, load the trained weights, and cache the result.

    Loading from disk on every request was the dominant per-call cost; the
    cache makes every call after the first reuse the same model instance.
    """
    model = timm.create_model('efficientnet_b0', pretrained=False, num_classes=len(labels))
    model.load_state_dict(torch.load(_MODEL_PATH, map_location=device))
    model.to(device)
    model.eval()
    return model


def _to_tensor(image):
    """Resize and normalize *image*, then convert HWC -> CHW float32 tensor."""
    processed = _TRANSFORM(image=image)["image"]
    chw = np.transpose(processed, (2, 0, 1))
    return torch.tensor(chw, dtype=torch.float32)


def predict(image=None) -> dict:
    """Classify an image and return a score for every disease label.

    Args:
        image: The input image as a NumPy array (presumably H x W x 3 RGB
            uint8, as delivered by the Gradio image input — TODO confirm).

    Returns:
        A dict mapping each name in ``labels`` to its float score from
        ``inference_fn`` (sigmoid outputs, so scores need not sum to 1).

    Raises:
        ValueError: If no image is supplied.
    """
    if image is None:
        raise ValueError("predict() requires an image, got None")
    tensor = _to_tensor(image)
    scores = inference_fn(_load_model(), tensor)
    return {labels[idx]: float(score) for idx, score in enumerate(scores)}
# Web UI: upload (or pick an example) image -> ranked disease scores.
# gr.Image / gr.Label replace the deprecated gr.inputs / gr.outputs
# namespaces; gr.Image's defaults (numpy array, RGB) match the old input.
# NOTE(review): `interpretation` is accepted by Gradio 3.x only; it must be
# dropped before upgrading to Gradio 4.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="numpy"),
    outputs=gr.Label(num_top_classes=len(labels)),
    examples=["200005.jpg", "200006.jpg"],
    interpretation='default',
)

# Launch only when run as a script (e.g. `python app.py`), so importing this
# module does not start a server.
if __name__ == "__main__":
    demo.launch()