Skip to content

Commit

Permalink
Add program_id to op profile Node.xla
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 733965194
  • Loading branch information
zzzaries authored and copybara-github committed Mar 6, 2025
1 parent 10721c7 commit 5f3249b
Show file tree
Hide file tree
Showing 8 changed files with 69 additions and 56 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ export namespace Node {
category?: string;
layout?: Node.XLAInstruction.LayoutAnalysis;
computationPrimitiveSize?: /* uint32 */ number;
programId?: /* uint64 */ string;
}
export namespace XLAInstruction {
export interface LayoutAnalysis {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
XLA Modules ({{ moduleList.length }})
</div>
</div>
<mat-form-field class="full-width" *ngIf="!isHloOssTool && !useProgramId" id="xla-module-list-selector" appearance="outline">
<mat-form-field class="full-width" *ngIf="!isHloOssTool" id="xla-module-list-selector" appearance="outline">
<mat-select
name="selectedModule"
panelClass="panel-override"
Expand All @@ -17,17 +17,6 @@
</mat-option>
</mat-select>
</mat-form-field>
<mat-form-field class="full-width" *ngIf="useProgramId">
<mat-label>XLA Program Id</mat-label>
<input
matInput
aria-label="program-id"
autocomplete="off"
name="programId"
[disabled]="true"
[value]="params.programId"
/>
</mat-form-field>
<mat-form-field *ngIf="graphTypes.length > 0" class="full-width" id="graph-type-selector" appearance="outline">
<mat-label>Graph Type</mat-label>
<mat-select
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ export class GraphConfig implements OnDestroy, OnChanges {
@Input() moduleList: string[] = [];
// Temparary indicator to hide the module name selection for 1vm graph viewer
@Input() isHloOssTool = false;
@Input() useProgramId = false;
@Input() graphTypes: GraphTypeObject[] = [];

inputsInited = false;
Expand All @@ -47,8 +46,10 @@ export class GraphConfig implements OnDestroy, OnChanges {
if (changes.hasOwnProperty('moduleList') &&
changes['moduleList'].currentValue.length > 0 &&
!this.params.selectedModule) {
this.params.selectedModule =
this.params.selectedModule || changes['moduleList'].currentValue[0];
const moduleList = changes['moduleList'].currentValue;
this.params.selectedModule = moduleList.find(
(module: string) => module.includes(this.params.programId || ''),
) || moduleList[0];
}

if (changes.hasOwnProperty('graphTypes') &&
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
</div>
</div>
<div class="info-content">
<div class="info" id="graph-viewer-link" [hidden]="!hasValidGraphViewerLink()">
<a [href]="getGraphViewerLinkWrapper()" target="_blank" style="text-decoration:none;">
<div class="info" id="graph-viewer-link" [hidden]="!graphViewerLink">
<a [href]="graphViewerLink" target="_blank" style="text-decoration:none;">
<button
mat-stroked-button
extended
Expand Down
56 changes: 23 additions & 33 deletions frontend/app/components/op_profile/op_details/op_details.ts
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ export class OpDetails {
bwColors: string[] =
Array.from<string>({length: utils.MemBwType.MEM_BW_TYPE_MAX + 1})
.fill('');
programId = '';
expression: string = '';
provenance: string = '';
rawTimeMs = '';
Expand Down Expand Up @@ -115,41 +116,38 @@ export class OpDetails {
}
}

hasValidGraphViewerLink() {
const aggregatedBy = this.selectedOpNodeChain[0];
if (aggregatedBy === 'by_category' && this.moduleList.length > 1) {
return false;
}
// Condition for both 'by_program' and 'by_category'
return this.selectedOpNodeChain.length >= 2 && this.expression;
}
// hasValidGraphViewerLink() {
// return this.selectedOpNodeChain.length >= 2 && this.selectedOpName &&
// (this.programId || this.selectedModuleName);
// }

// expression format assumption: '%<op_name> = ...'
getSelectedOpName() {
get selectedOpName() {
return this.expression.split('=')[0].trim().slice(1);
}

getSelectedModuleName() {
get selectedModuleName() {
const aggregatedBy = this.selectedOpNodeChain[0];
// 'by_program' or 'by_category'
return aggregatedBy === 'by_program' ? this.selectedOpNodeChain[1] :
this.moduleList[0];
return aggregatedBy === 'by_program' ? this.selectedOpNodeChain[1] : '';
}

getGraphViewerLinkWrapper() {
const moduleName = this.getSelectedModuleName();
const opName = this.getSelectedOpName();
return this.dataService.getGraphViewerLink(
this.sessionId,
moduleName,
opName,
);
get graphViewerLink() {
if (this.selectedModuleName) {
return this.dataService.getGraphViewerLink(
this.sessionId, this.selectedModuleName, this.selectedOpName, '');
}
if (this.programId) {
return this.dataService.getGraphViewerLink(
this.sessionId, '', this.selectedOpName, this.programId);
}
return '';
}

getCustomCallTextLink() {
return `/graph_viewer.json?session_id=${this.sessionId}&module_name=${
this.getSelectedModuleName()}&node_name=${
this.getSelectedOpName()}&type=custom_call`;
this.selectedModuleName}&node_name=${
this.selectedOpName}&type=custom_call`;
}

dimensionColor(dimension?: Node.XLAInstruction.LayoutAnalysis.Dimension):
Expand Down Expand Up @@ -245,17 +243,9 @@ export class OpDetails {
}
}

if (this.node.xla && this.node.xla.expression) {
this.expression = this.node.xla.expression;
} else {
this.expression = '';
}

if (this.node.xla && this.node.xla.provenance) {
this.provenance = this.node.xla.provenance;
} else {
this.provenance = '';
}
this.programId = this.node.xla?.programId || '';
this.expression = this.node.xla?.expression || '';
this.provenance = this.node.xla?.provenance || '';

if (this.node.metrics && this.node.metrics.rawTime) {
this.rawTimeMs = utils.humanReadableText(
Expand Down
12 changes: 8 additions & 4 deletions frontend/app/services/data_service_v2/data_service_v2.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,14 @@ export class DataServiceV2 implements DataServiceV2Interface {
Observable<DataTable>;
}

getGraphViewerLink(sessionId: string, moduleName: string, opName: string) {
if (!moduleName || !opName) return '';
return `${window.parent.location.origin}?tool=graph_viewer&host=${
moduleName}&opName=${opName}&run=${sessionId}#profile`;
// Get graph with program id and op name is not implemented yet.
getGraphViewerLink(
sessionId: string, moduleName: string, opName: string, programId = '') {
if (moduleName && opName) {
return `${window.parent.location.origin}?tool=graph_viewer&host=${
moduleName}&opName=${opName}&run=${sessionId}#profile`;
}
return '';
}

getOpProfileSummary(data: OpProfileData): OpProfileSummary[] {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,12 @@ export interface DataServiceV2Interface {
ignoreError?: boolean,
): Observable<DataTable|null>;

getGraphViewerLink(sessionId: string, moduleName: string, opName: string):
string;
getGraphViewerLink(
sessionId: string,
moduleName: string,
opName: string,
programId: string,
): string;

getOpProfileSummary(data: OpProfileData): OpProfileSummary[];
}
Expand Down
24 changes: 24 additions & 0 deletions plugin/tensorboard_plugin_profile/profile_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
RUNS_ROUTE = '/runs'
RUN_TOOLS_ROUTE = '/run_tools'
HOSTS_ROUTE = '/hosts'
HLO_MODULE_LIST_ROUTE = '/module_list'
CAPTURE_ROUTE = '/capture_profile'

# Suffixes of "^, #, @" symbols represent different input data formats for the
Expand Down Expand Up @@ -455,6 +456,7 @@ def get_plugin_apps(
RUN_TOOLS_ROUTE: self.run_tools_route,
HOSTS_ROUTE: self.hosts_route,
DATA_ROUTE: self.data_route,
HLO_MODULE_LIST_ROUTE: self.hlo_module_list_route,
CAPTURE_ROUTE: self.capture_route,
}

Expand Down Expand Up @@ -694,6 +696,28 @@ def data_impl(

return get_data_content_encoding(raw_data, tool, params)

def hlo_module_list_route(
self, request: wrappers.Request
) -> wrappers.Response:
"""Returns a list of HLO module names for the given run."""
run = request.args.get('run')
run_dir = self._run_dir(run)
module_list = []
if run_dir:
module_pattern = make_filename('*', 'hlo_proto.pb')
try:
path = epath.Path(run_dir)
module_paths = path.glob(module_pattern)
except OSError as e:
logger.warning(
'Cannot read asset directory: %s, OpError %s', run_dir, e
)
raise IOError(
'Cannot read asset directory: %s, OpError %s' % (run_dir, e)
) from e
module_list = [os.fspath(os.path.basename(f)) for f in module_paths]
return respond(module_list, 'application/json')

# pytype: disable=wrong-arg-types
@wrappers.Request.application
def data_route(self, request: wrappers.Request) -> wrappers.Response:
Expand Down

0 comments on commit 5f3249b

Please sign in to comment.