Skip to content

Commit

Permalink
Logit Transmission via gRPC: Relocating scores normalization to local…
Browse files Browse the repository at this point in the history
… client
  • Loading branch information
dotpyu authored Jul 5, 2024
1 parent fffd7a8 commit cbaed9e
Showing 1 changed file with 11 additions and 2 deletions.
13 changes: 11 additions & 2 deletions alfred/fm/remote/grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import json
import logging
import socket
import torch.nn.functional as F
from concurrent import futures
from typing import Optional, Union, Iterable, Tuple, Any, List

Expand Down Expand Up @@ -122,11 +123,19 @@ def _run_req_gen():
output = []
for response in self.stub.Run(_run_req_gen()):
if response.ranked:
logits = ast.literal_eval(response.logit)
candidates = list(logits.keys())
logit_values = torch.tensor(list(logits.values()))
probabilities = F.softmax(logit_values, dim=0)
scores = {
candidate: prob.item() for candidate, prob in zip(candidates, probabilities)
}
output.append(
RankedResponse(
**{
"prediction": response.message,
"scores": ast.literal_eval(response.logit),
"scores": scores,
"logit": logits,
"embeddings": bytes_to_tensor(response.embedding),
}
)
Expand Down Expand Up @@ -241,7 +250,7 @@ def Run(self, request_iterator, context):
yield query_pb2.RunResponse(
message=response.prediction,
ranked=True,
logit=str(response.scores),
logit=str(response.logits),
embedding=tensor_to_bytes(response.embeddings),
)
else:
Expand Down

0 comments on commit cbaed9e

Please sign in to comment.